diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9c545fd5e..143d2961c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,12 +1,14 @@ name: CI on: - push: - branches: [ master ] pull_request: - branches: - - '**' + push: + branches: [master] + workflow_dispatch: +concurrency: # https://stackoverflow.com/questions/66335225#comment133398800_72408109 + group: ${{ github.workflow }}-${{ github.ref || github.run_id }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} jobs: run_unittest_tests: diff --git a/docs/backends/amazon-S3.rst b/docs/backends/amazon-S3.rst index 641a1c830..b832a26d9 100644 --- a/docs/backends/amazon-S3.rst +++ b/docs/backends/amazon-S3.rst @@ -259,6 +259,17 @@ Settings Setting this overrides the settings for ``addressing_style``, ``signature_version`` and ``proxies``. Include them as arguments to your ``botocore.config.Config`` class if you need them. +``client_ttl`` or ``AWS_S3_CLIENT_TTL`` + + Default: ``3600`` + + The amount of seconds to store a boto3 client resource in an S3Storage instance's time-to-live cache. + + .. note:: + + Long-lived boto3 clients have a known `memory leak`_, which is why the client is + periodically recreated to avoid excessive memory consumption. + .. _AWS Signature Version 4: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html .. _S3 region list: https://docs.aws.amazon.com/general/latest/gr/s3.html#s3_region .. _list of canned ACLs: https://docs.aws.amazon.com/AmazonS3/latest/dev/acl-overview.html#canned-acl @@ -266,6 +277,7 @@ Settings .. _Boto3 docs for TransferConfig: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.TransferConfig .. _ManifestStaticFilesStorage: https://docs.djangoproject.com/en/3.1/ref/contrib/staticfiles/#manifeststaticfilesstorage .. _Botocore docs: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html#botocore.config.Config +.. _memory leak: https://github.com/boto/boto3/issues/1670 .. _cloudfront-signed-url-header: diff --git a/storages/backends/s3.py b/storages/backends/s3.py index fa2d39f16..2c8a98d75 100644 --- a/storages/backends/s3.py +++ b/storages/backends/s3.py @@ -4,6 +4,7 @@ import posixpath import tempfile import threading +import time import warnings from datetime import datetime from datetime import timedelta @@ -38,8 +39,9 @@ from botocore.config import Config from botocore.exceptions import ClientError from botocore.signers import CloudFrontSigner -except ImportError as e: - raise ImproperlyConfigured("Could not load Boto3's S3 bindings. %s" % e) +except (ImportError, ModuleNotFoundError) as e: + msg = "Could not import boto3. Did you run 'pip install django-storages[s3]'?" + raise ImproperlyConfigured(msg) from e # NOTE: these are defined as functions so both can be tested @@ -329,9 +331,15 @@ def __init__(self, **settings): "AWS_S3_SECRET_ACCESS_KEY/secret_key" ) - self._bucket = None - self._connections = threading.local() - self._unsigned_connections = threading.local() + # These variables are used for a boto3 client time-to-live caching mechanism. + # We want to avoid storing a resource for too long to avoid their memory leak + # ref https://github.com/boto/boto3/issues/1670. + self._connection_lock = threading.Lock() + self._connection_expiry = None + self._connection = None + self._unsigned_connection_lock = threading.Lock() + self._unsigned_connection_expiry = None + self._unsigned_connection = None if self.config is not None: warnings.warn( @@ -347,6 +355,9 @@ def __init__(self, **settings): s3={"addressing_style": self.addressing_style}, signature_version=self.signature_version, proxies=self.proxies, + max_pool_connections=64, # shared between threads + tcp_keepalive=True, + retries={"max_attempts": 6, "mode": "adaptive"}, ) if self.use_threads is False: @@ -441,58 +452,79 @@ def get_default_settings(self): "use_threads": setting("AWS_S3_USE_THREADS", True), "transfer_config": setting("AWS_S3_TRANSFER_CONFIG", None), "client_config": setting("AWS_S3_CLIENT_CONFIG", None), + "client_ttl": setting("AWS_S3_CLIENT_TTL", 3600), } def __getstate__(self): state = self.__dict__.copy() - state.pop("_connections", None) - state.pop("_unsigned_connections", None) - state.pop("_bucket", None) + state.pop("_connection_lock", None) + state.pop("_connection_expiry", None) + state.pop("_connection", None) + state.pop("_unsigned_connection_lock", None) + state.pop("_unsigned_connection_expiry", None) + state.pop("_unsigned_connection", None) return state def __setstate__(self, state): - state["_connections"] = threading.local() - state["_unsigned_connections"] = threading.local() - state["_bucket"] = None + state["_connection_lock"] = threading.Lock() + state["_connection_expiry"] = None + state["_connection"] = None + state["_unsigned_connection_lock"] = threading.Lock() + state["_unsigned_connection_expiry"] = None + state["_unsigned_connection"] = None self.__dict__ = state @property def connection(self): - connection = getattr(self._connections, "connection", None) - if connection is None: - session = self._create_session() - self._connections.connection = session.resource( - "s3", - region_name=self.region_name, - use_ssl=self.use_ssl, - endpoint_url=self.endpoint_url, - config=self.client_config, - verify=self.verify, - ) - return self._connections.connection + """ + Get the (cached) thread-safe boto3 s3 resource. + """ + with self._connection_lock: + if ( + self._connection is None # fresh instance + or time.monotonic() > self._connection_expiry # TTL expired + ): + self._connection_expiry = time.monotonic() + self.client_ttl + self._connection = self._create_connection() + return self._connection @property def unsigned_connection(self): - unsigned_connection = getattr(self._unsigned_connections, "connection", None) - if unsigned_connection is None: - session = self._create_session() - config = self.client_config.merge( - Config(signature_version=botocore.UNSIGNED) - ) - self._unsigned_connections.connection = session.resource( - "s3", - region_name=self.region_name, - use_ssl=self.use_ssl, - endpoint_url=self.endpoint_url, - config=config, - verify=self.verify, - ) - return self._unsigned_connections.connection + """ + Get the (cached) thread-safe boto3 s3 resource (unsigned). + """ + with self._unsigned_connection_lock: + if ( + self._unsigned_connection is None # fresh instance + or time.monotonic() > self._unsigned_connection_expiry # TTL expired + ): + self._unsigned_connection_expiry = time.monotonic() + self.client_ttl + self._unsigned_connection = self._create_connection(unsigned=True) + return self._unsigned_connection + + def _create_connection(self, *, unsigned=False): + """ + Create a new session and thread-safe boto3 s3 resource. + """ + config = self.client_config + if unsigned: + config = config.merge(Config(signature_version=botocore.UNSIGNED)) + session = self._create_session() + # thread-safe boto3 client (wrapped by a boto3 resource) ref: + # https://github.com/boto/boto3/blob/1.38.41/docs/source/guide/clients.rst?plain=1#L111 + return session.resource( + "s3", + region_name=self.region_name, + use_ssl=self.use_ssl, + endpoint_url=self.endpoint_url, + config=config, + verify=self.verify, + ) def _create_session(self): """ If a user specifies a profile name and this class obtains access keys - from another source such as environment variables,we want the profile + from another source such as environment variables, we want the profile name to take precedence. """ if self.session_profile: @@ -511,9 +543,7 @@ def bucket(self): Get the current bucket. If there is no current bucket object create it. """ - if self._bucket is None: - self._bucket = self.connection.Bucket(self.bucket_name) - return self._bucket + return self.connection.Bucket(self.bucket_name) def _normalize_name(self, name): """ diff --git a/tests/test_s3.py b/tests/test_s3.py index e324baf05..d964e3fd3 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -3,10 +3,8 @@ import io import os import pickle -import threading from textwrap import dedent from unittest import mock -from unittest import skipIf from urllib.parse import urlparse import boto3 @@ -35,8 +33,7 @@ def read_manifest(self): class S3StorageTests(TestCase): def setUp(self): self.storage = s3.S3Storage() - self.storage._connections.connection = mock.MagicMock() - self.storage._unsigned_connections.connection = mock.MagicMock() + self.storage._create_connection = mock.MagicMock() @mock.patch("boto3.Session") def test_s3_session(self, session): @@ -57,7 +54,9 @@ def test_client_config(self, resource): @mock.patch("boto3.Session.resource") def test_connection_unsiged(self, resource): - with override_settings(AWS_S3_ADDRESSING_STYLE="virtual"): + with override_settings( + AWS_S3_ADDRESSING_STYLE="virtual", AWS_QUERYSTRING_AUTH=False + ): storage = s3.S3Storage() _ = storage.unsigned_connection resource.assert_called_once() @@ -68,36 +67,17 @@ def test_connection_unsiged(self, resource): "virtual", resource.call_args[1]["config"].s3["addressing_style"] ) - def test_pickle_with_bucket(self): - """ - Test that the storage can be pickled with a bucket attached - """ - # Ensure the bucket has been used - self.storage.bucket - self.assertIsNotNone(self.storage._bucket) - - # Can't pickle MagicMock, but you can't pickle a real Bucket object either - p = pickle.dumps(self.storage) - new_storage = pickle.loads(p) - - self.assertIsInstance(new_storage._connections, threading.local) - # Put the mock connection back in - new_storage._connections.connection = mock.MagicMock() - - self.assertIsNone(new_storage._bucket) - new_storage.bucket - self.assertIsNotNone(new_storage._bucket) - - def test_pickle_without_bucket(self): + def test_pickle(self): """ Test that the storage can be pickled, without a bucket instance """ - - # Can't pickle a threadlocal - p = pickle.dumps(self.storage) + storage = s3.S3Storage() + _ = storage.connection + _ = storage.bucket + p = pickle.dumps(storage) new_storage = pickle.loads(p) - - self.assertIsInstance(new_storage._connections, threading.local) + _ = new_storage.connection + _ = storage.bucket def test_storage_url_slashes(self): """ @@ -580,9 +560,7 @@ def test_storage_listdir_base(self): paginator = mock.MagicMock() paginator.paginate.return_value = pages - self.storage._connections.connection.meta.client.get_paginator.return_value = ( - paginator - ) + self.storage.connection.meta.client.get_paginator.return_value = paginator dirs, files = self.storage.listdir("") paginator.paginate.assert_called_with( @@ -609,9 +587,7 @@ def test_storage_listdir_subdir(self): paginator = mock.MagicMock() paginator.paginate.return_value = pages - self.storage._connections.connection.meta.client.get_paginator.return_value = ( - paginator - ) + self.storage.connection.meta.client.get_paginator.return_value = paginator dirs, files = self.storage.listdir("some/") paginator.paginate.assert_called_with( @@ -634,9 +610,7 @@ def test_storage_listdir_empty(self): paginator = mock.MagicMock() paginator.paginate.return_value = pages - self.storage._connections.connection.meta.client.get_paginator.return_value = ( - paginator - ) + self.storage.connection.meta.client.get_paginator.return_value = paginator dirs, files = self.storage.listdir("dir/") paginator.paginate.assert_called_with( @@ -797,21 +771,6 @@ def test_custom_domain_parameters(self): self.assertEqual(parsed_url.path, "/filename.mp4") self.assertEqual(parsed_url.query, "version=10") - @skipIf(threading is None, "Test requires threading") - def test_connection_threading(self): - connections = [] - - def thread_storage_connection(): - connections.append(self.storage.connection) - - for _ in range(2): - t = threading.Thread(target=thread_storage_connection) - t.start() - t.join() - - # Connection for each thread needs to be unique - self.assertIsNot(connections[0], connections[1]) - def test_location_leading_slash(self): msg = ( "S3Storage.location cannot begin with a leading slash. " @@ -1004,11 +963,32 @@ def test_security_token(self): storage = s3.S3Storage() self.assertEqual(storage.security_token, "baz") + def test_connection_cache(self): + storage1 = s3.S3Storage() + connection1 = storage1.connection + unsigned_connection1 = storage1.unsigned_connection + # different connections + self.assertNotEqual(id(connection1), id(unsigned_connection1)) + # cache hits + self.assertEqual(id(connection1), id(storage1.connection)) + self.assertEqual(id(unsigned_connection1), id(storage1.unsigned_connection)) + + storage2 = s3.S3Storage() + connection2 = storage2.connection + unsigned_connection2 = storage2.unsigned_connection + # different connections + self.assertNotEqual(id(connection2), id(unsigned_connection2)) + self.assertNotEqual(id(connection1), id(connection2)) + self.assertNotEqual(id(unsigned_connection1), id(unsigned_connection2)) + # cache hits + self.assertEqual(id(connection2), id(storage2.connection)) + self.assertEqual(id(unsigned_connection2), id(storage2.unsigned_connection)) + class S3StaticStorageTests(TestCase): def setUp(self): self.storage = s3.S3StaticStorage() - self.storage._connections.connection = mock.MagicMock() + self.storage._create_connection = mock.MagicMock() def test_querystring_auth(self): self.assertFalse(self.storage.querystring_auth) @@ -1017,7 +997,7 @@ def test_querystring_auth(self): class S3ManifestStaticStorageTests(TestCase): def setUp(self): self.storage = S3ManifestStaticStorageTestStorage() - self.storage._connections.connection = mock.MagicMock() + self.storage._create_connection = mock.MagicMock() def test_querystring_auth(self): self.assertFalse(self.storage.querystring_auth) @@ -1029,7 +1009,7 @@ def test_save(self): class S3FileTests(TestCase): def setUp(self) -> None: self.storage = s3.S3Storage() - self.storage._connections.connection = mock.MagicMock() + self.storage._create_connection = mock.MagicMock() def test_loading_ssec(self): params = {"SSECustomerKey": "xyz", "CacheControl": "never"}