|
1 | 1 | from __future__ import annotations
|
2 | 2 | from typing import TYPE_CHECKING
|
| 3 | +import os |
3 | 4 | import logging
|
4 | 5 | from botocore.exceptions import ClientError
|
| 6 | +from contextlib import contextmanager |
| 7 | +import botocore |
| 8 | +import boto3 |
5 | 9 |
|
6 | 10 | if TYPE_CHECKING:
|
7 | 11 | from typing import Any
|
| 12 | + from collections.abc import Iterable, Iterator |
8 | 13 |
|
9 | 14 | logger = logging.getLogger("e3.aws.s3")
|
10 | 15 |
|
11 | 16 |
|
| 17 | +class BucketExistsError(Exception): |
| 18 | + """Exception when a bucket already exists.""" |
| 19 | + |
| 20 | + def __init__(self, bucket: str) -> None: |
| 21 | + """Initialize BucketExistsError. |
| 22 | +
|
| 23 | + :param bucket: name of the bucket |
| 24 | + """ |
| 25 | + self.bucket = bucket |
| 26 | + |
| 27 | + |
12 | 28 | class KeyExistsError(Exception):
|
13 | 29 | """Exception when a key already exists."""
|
14 | 30 |
|
@@ -47,6 +63,42 @@ def __init__(
|
47 | 63 | self.client = client
|
48 | 64 | self.bucket = bucket
|
49 | 65 |
|
| 66 | + def create_bucket(self, *, exist_ok: bool = True) -> None: |
| 67 | + """Create the bucket. |
| 68 | +
|
| 69 | + :param exist_ok: don't raise an exception if the bucket already exists |
| 70 | + :raises BucketExistsError: if the bucket already exists |
| 71 | + """ |
| 72 | + try: |
| 73 | + params: dict[str, Any] = {} |
| 74 | + |
| 75 | + # us-east-1 is the default location |
| 76 | + region = self.client.meta.region_name |
| 77 | + if region != "us-east-1": |
| 78 | + params["CreateBucketConfiguration"] = {"LocationConstraint": region} |
| 79 | + |
| 80 | + self.client.create_bucket(Bucket=self.bucket, **params) |
| 81 | + except ClientError as error: |
| 82 | + # Raise any non already exists error |
| 83 | + if error.response["Error"]["Code"] not in [ |
| 84 | + "BucketAlreadyExists", |
| 85 | + "BucketAlreadyOwnedByYou", |
| 86 | + ]: |
| 87 | + raise |
| 88 | + |
| 89 | + if not exist_ok: |
| 90 | + raise BucketExistsError(self.bucket) from error |
| 91 | + |
| 92 | + def clear_bucket(self) -> None: |
| 93 | + """Clear objects from S3.""" |
| 94 | + for obj in list(self.iterate()): |
| 95 | + self.client.delete_object(Bucket=self.bucket, Key=obj["Key"]) |
| 96 | + |
| 97 | + def delete_bucket(self) -> None: |
| 98 | + """Clear and delete the bucket.""" |
| 99 | + self.clear_bucket() |
| 100 | + self.client.delete_bucket(Bucket=self.bucket) |
| 101 | + |
50 | 102 | def push(self, key: str, content: bytes, exist_ok: bool | None = None) -> None:
|
51 | 103 | """Push content to S3.
|
52 | 104 |
|
@@ -93,9 +145,84 @@ def get(self, key: str, default: bytes | None = None) -> bytes:
|
93 | 145 | raise KeyNotFoundError(key) from e
|
94 | 146 | raise e
|
95 | 147 |
|
| 148 | + def iterate(self, *, prefix: str | None = None) -> Iterable[dict[str, Any]]: |
| 149 | + """Iterate all objects from S3. |
| 150 | +
|
| 151 | + :param prefix: limit to objects with that prefix |
| 152 | + :return: an iterator over objects from S3 |
| 153 | + """ |
| 154 | + params = {"Bucket": self.bucket} |
| 155 | + |
| 156 | + if prefix is not None: |
| 157 | + params["Prefix"] = prefix |
| 158 | + |
| 159 | + paginator = self.client.get_paginator("list_objects_v2") |
| 160 | + for page in paginator.paginate(**params): |
| 161 | + for content in page.get("Contents", []): |
| 162 | + yield content |
| 163 | + |
96 | 164 | def delete(self, key: str) -> None:
|
97 | 165 | """Delete content from S3.
|
98 | 166 |
|
99 | 167 | :param key: object key
|
100 | 168 | """
|
101 | 169 | self.client.delete_object(Bucket=self.bucket, Key=key)
|
| 170 | + |
| 171 | + @property |
| 172 | + def bucket_exists(self) -> bool: |
| 173 | + """Return if the bucket exists.""" |
| 174 | + try: |
| 175 | + self.client.head_bucket(Bucket=self.bucket) |
| 176 | + return True |
| 177 | + except ClientError as e: |
| 178 | + if e.response["Error"]["Code"] == "404": |
| 179 | + return False |
| 180 | + raise |
| 181 | + |
| 182 | + @property |
| 183 | + def key_count(self) -> int: |
| 184 | + """Return the number of keys from S3.""" |
| 185 | + return len(list(self.iterate())) |
| 186 | + |
| 187 | + |
| 188 | +@contextmanager |
| 189 | +def bucket( |
| 190 | + name: str, |
| 191 | + *, |
| 192 | + client: botocore.client.S3 | None = None, |
| 193 | + region: str | None = None, |
| 194 | + auto_create: bool = True, |
| 195 | + auto_delete: bool = False, |
| 196 | + exist_ok: bool = True, |
| 197 | +) -> Iterator[S3]: |
| 198 | + """Context manager to create and make AWS API calls on a bucket. |
| 199 | +
|
| 200 | + If auto_create is True, the bucket is created when entering the |
| 201 | + context. If the bucket already exists and exist_ok is False, an |
| 202 | + exception is raised. |
| 203 | +
|
| 204 | + If auto_delete is True, the bucket is cleared and deleted when |
| 205 | + leaving the context. |
| 206 | +
|
| 207 | + :param name: name of the bucket |
| 208 | + :param client: a client for the S3 API |
| 209 | + :param region: region of the client (default AWS_DEFAULT_REGION) |
| 210 | + :param auto_create: create the bucket when entering the context |
| 211 | + :param auto_delete: delete the bucket when leaving the context |
| 212 | + :param exist_ok: don't raise an exception if the bucket already exists |
| 213 | + :raises BucketExistsError: if the bucket already exists |
| 214 | + """ |
| 215 | + if client is None: |
| 216 | + region = region if region is not None else os.environ["AWS_DEFAULT_REGION"] |
| 217 | + client = boto3.client("s3", region_name=region) |
| 218 | + |
| 219 | + s3 = S3(client=client, bucket=name) |
| 220 | + |
| 221 | + if auto_create: |
| 222 | + s3.create_bucket(exist_ok=exist_ok) |
| 223 | + |
| 224 | + try: |
| 225 | + yield s3 |
| 226 | + finally: |
| 227 | + if auto_delete: |
| 228 | + s3.delete_bucket() |
0 commit comments