10 changes: 6 additions & 4 deletions .github/workflows/ci.yml
@@ -1,12 +1,14 @@
name: CI

on:
push:
branches: [ master ]
pull_request:
branches:
- '**'
push:
branches: [master]
workflow_dispatch:

concurrency: # https://stackoverflow.com/questions/66335225#comment133398800_72408109
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
cancel-in-progress: ${{ github.event_name == 'pull_request' }}

jobs:
run_unittest_tests:
12 changes: 12 additions & 0 deletions docs/backends/amazon-S3.rst
@@ -259,13 +259,25 @@ Settings
Setting this overrides the settings for ``addressing_style``, ``signature_version`` and
``proxies``. Include them as arguments to your ``botocore.config.Config`` class if you need them.
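For example, a Django settings module might supply its own configuration along
these lines (a sketch; the option values are illustrative, not recommendations)::

    from botocore.config import Config

    AWS_S3_CLIENT_CONFIG = Config(
        signature_version="s3v4",
        s3={"addressing_style": "virtual"},
        retries={"max_attempts": 10, "mode": "standard"},
    )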

``client_ttl`` or ``AWS_S3_CLIENT_TTL``

Default: ``3600``

The number of seconds a boto3 client resource is kept in an S3Storage instance's time-to-live cache.

.. note::

Long-lived boto3 clients have a known `memory leak`_, which is why the client is
periodically recreated to avoid excessive memory consumption.
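
For example, to recreate clients every five minutes instead of every hour
(an illustrative value, not a recommendation)::

    AWS_S3_CLIENT_TTL = 300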

.. _AWS Signature Version 4: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html
.. _S3 region list: https://docs.aws.amazon.com/general/latest/gr/s3.html#s3_region
.. _list of canned ACLs: https://docs.aws.amazon.com/AmazonS3/latest/dev/acl-overview.html#canned-acl
.. _Boto3 docs for uploading files: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.put_object
.. _Boto3 docs for TransferConfig: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.TransferConfig
.. _ManifestStaticFilesStorage: https://docs.djangoproject.com/en/3.1/ref/contrib/staticfiles/#manifeststaticfilesstorage
.. _Botocore docs: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html#botocore.config.Config
.. _memory leak: https://github.com/boto/boto3/issues/1670

.. _cloudfront-signed-url-header:

112 changes: 70 additions & 42 deletions storages/backends/s3.py
@@ -4,6 +4,7 @@
import posixpath
import tempfile
import threading
import time
import warnings
from datetime import datetime
from datetime import timedelta
@@ -38,8 +39,9 @@
from botocore.config import Config
from botocore.exceptions import ClientError
from botocore.signers import CloudFrontSigner
except ImportError as e:
raise ImproperlyConfigured("Could not load Boto3's S3 bindings. %s" % e)
except (ImportError, ModuleNotFoundError) as e:
msg = "Could not import boto3. Did you run 'pip install django-storages[s3]'?"
raise ImproperlyConfigured(msg) from e


# NOTE: these are defined as functions so both can be tested
@@ -329,9 +331,15 @@ def __init__(self, **settings):
"AWS_S3_SECRET_ACCESS_KEY/secret_key"
)

self._bucket = None
self._connections = threading.local()
self._unsigned_connections = threading.local()
# These variables implement a time-to-live cache for boto3 clients. We avoid
# holding any one client for too long because of a known boto3 memory leak,
# ref https://github.com/boto/boto3/issues/1670.
self._connection_lock = threading.Lock()
self._connection_expiry = None
self._connection = None
self._unsigned_connection_lock = threading.Lock()
self._unsigned_connection_expiry = None
self._unsigned_connection = None

if self.config is not None:
warnings.warn(
@@ -347,6 +355,9 @@ def __init__(self, **settings):
s3={"addressing_style": self.addressing_style},
signature_version=self.signature_version,
proxies=self.proxies,
max_pool_connections=64, # thread-safe
tcp_keepalive=True,
retries={"max_attempts": 6, "mode": "adaptive"},
)

if self.use_threads is False:
@@ -441,58 +452,77 @@ def get_default_settings(self):
"use_threads": setting("AWS_S3_USE_THREADS", True),
"transfer_config": setting("AWS_S3_TRANSFER_CONFIG", None),
"client_config": setting("AWS_S3_CLIENT_CONFIG", None),
"client_ttl": setting("AWS_S3_CLIENT_TTL", 3600),
}

def __getstate__(self):
state = self.__dict__.copy()
state.pop("_connections", None)
state.pop("_unsigned_connections", None)
state.pop("_bucket", None)
state.pop("_connection_lock", None)
state.pop("_connection_expiry", None)
state.pop("_connection", None)
state.pop("_unsigned_connection_lock", None)
state.pop("_unsigned_connection_expiry", None)
state.pop("_unsigned_connection", None)
return state

def __setstate__(self, state):
state["_connections"] = threading.local()
state["_unsigned_connections"] = threading.local()
state["_bucket"] = None
state["_connection_lock"] = threading.Lock()
state["_connection_expiry"] = None
state["_connection"] = None
state["_unsigned_connection_lock"] = threading.Lock()
state["_unsigned_connection_expiry"] = None
state["_unsigned_connection"] = None
self.__dict__ = state
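
Why these hooks drop the locks and clients: thread locks and live botocore clients are not picklable, so the storage discards them on serialization and rebuilds them lazily on first use afterwards. A minimal round-trip sketch (hypothetical usage; the bucket name is illustrative):

import pickle

from storages.backends.s3 import S3Storage

storage = S3Storage(bucket_name="example-bucket")  # illustrative name
restored = pickle.loads(pickle.dumps(storage))
# The restored instance has fresh locks and empty caches; accessing
# .connection recreates the boto3 client on demand.
client = restored.connection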

@property
def connection(self):
connection = getattr(self._connections, "connection", None)
if connection is None:
session = self._create_session()
self._connections.connection = session.resource(
"s3",
region_name=self.region_name,
use_ssl=self.use_ssl,
endpoint_url=self.endpoint_url,
config=self.client_config,
verify=self.verify,
)
return self._connections.connection
"""
Get the (cached) thread-safe boto3 s3 resource.
"""
with self._connection_lock:
if (
self._connection is None # fresh instance
or time.monotonic() > self._connection_expiry # TTL expired
):
self._connection_expiry = time.monotonic() + self.client_ttl
self._connection = self._create_connection()
return self._connection

@property
def unsigned_connection(self):
unsigned_connection = getattr(self._unsigned_connections, "connection", None)
if unsigned_connection is None:
session = self._create_session()
config = self.client_config.merge(
Config(signature_version=botocore.UNSIGNED)
)
self._unsigned_connections.connection = session.resource(
"s3",
region_name=self.region_name,
use_ssl=self.use_ssl,
endpoint_url=self.endpoint_url,
config=config,
verify=self.verify,
)
return self._unsigned_connections.connection
"""
Get the (cached) thread-safe boto3 s3 resource (unsigned).
"""
with self._unsigned_connection_lock:
if (
self._unsigned_connection is None # fresh instance
or time.monotonic() > self._unsigned_connection_expiry # TTL expired
):
self._unsigned_connection_expiry = time.monotonic() + self.client_ttl
self._unsigned_connection = self._create_connection(unsigned=True)
return self._unsigned_connection

def _create_connection(self, *, unsigned=False):
"""
Create a new session and boto3 s3 resource.
"""
config = self.client_config
if unsigned:
config = config.merge(Config(signature_version=botocore.UNSIGNED))
session = self._create_session()
return session.resource(
"s3",
region_name=self.region_name,
use_ssl=self.use_ssl,
endpoint_url=self.endpoint_url,
config=config,
verify=self.verify,
)

def _create_session(self):
"""
If a user specifies a profile name and this class obtains access keys
from another source such as environment variables,we want the profile
from another source such as environment variables, we want the profile
name to take precedence.
"""
if self.session_profile:
@@ -511,9 +541,7 @@ def bucket(self):
Get the current bucket. If there is no current bucket object
create it.
"""
if self._bucket is None:
self._bucket = self.connection.Bucket(self.bucket_name)
return self._bucket
return self.connection.Bucket(self.bucket_name)

def _normalize_name(self, name):
"""
96 changes: 38 additions & 58 deletions tests/test_s3.py
@@ -3,10 +3,8 @@
import io
import os
import pickle
import threading
from textwrap import dedent
from unittest import mock
from unittest import skipIf
from urllib.parse import urlparse

import boto3
@@ -35,8 +33,7 @@ def read_manifest(self):
class S3StorageTests(TestCase):
def setUp(self):
self.storage = s3.S3Storage()
self.storage._connections.connection = mock.MagicMock()
self.storage._unsigned_connections.connection = mock.MagicMock()
self.storage._create_connection = mock.MagicMock()

@mock.patch("boto3.Session")
def test_s3_session(self, session):
@@ -57,7 +54,9 @@ def test_client_config(self, resource):

@mock.patch("boto3.Session.resource")
def test_connection_unsiged(self, resource):
with override_settings(AWS_S3_ADDRESSING_STYLE="virtual"):
with override_settings(
AWS_S3_ADDRESSING_STYLE="virtual", AWS_QUERYSTRING_AUTH=False
):
storage = s3.S3Storage()
_ = storage.unsigned_connection
resource.assert_called_once()
@@ -68,36 +67,17 @@ def test_connection_unsiged(self, resource):
"virtual", resource.call_args[1]["config"].s3["addressing_style"]
)

def test_pickle_with_bucket(self):
"""
Test that the storage can be pickled with a bucket attached
"""
# Ensure the bucket has been used
self.storage.bucket
self.assertIsNotNone(self.storage._bucket)

# Can't pickle MagicMock, but you can't pickle a real Bucket object either
p = pickle.dumps(self.storage)
new_storage = pickle.loads(p)

self.assertIsInstance(new_storage._connections, threading.local)
# Put the mock connection back in
new_storage._connections.connection = mock.MagicMock()

self.assertIsNone(new_storage._bucket)
new_storage.bucket
self.assertIsNotNone(new_storage._bucket)

def test_pickle_without_bucket(self):
def test_pickle(self):
"""
Test that the storage can be pickled and unpickled
"""

# Can't pickle a threadlocal
p = pickle.dumps(self.storage)
storage = s3.S3Storage()
_ = storage.connection
_ = storage.bucket
p = pickle.dumps(storage)
new_storage = pickle.loads(p)

self.assertIsInstance(new_storage._connections, threading.local)
_ = new_storage.connection
_ = new_storage.bucket

def test_storage_url_slashes(self):
"""
@@ -580,9 +560,7 @@ def test_storage_listdir_base(self):

paginator = mock.MagicMock()
paginator.paginate.return_value = pages
self.storage._connections.connection.meta.client.get_paginator.return_value = (
paginator
)
self.storage.connection.meta.client.get_paginator.return_value = paginator

dirs, files = self.storage.listdir("")
paginator.paginate.assert_called_with(
@@ -609,9 +587,7 @@ def test_storage_listdir_subdir(self):

paginator = mock.MagicMock()
paginator.paginate.return_value = pages
self.storage._connections.connection.meta.client.get_paginator.return_value = (
paginator
)
self.storage.connection.meta.client.get_paginator.return_value = paginator

dirs, files = self.storage.listdir("some/")
paginator.paginate.assert_called_with(
@@ -634,9 +610,7 @@ def test_storage_listdir_empty(self):

paginator = mock.MagicMock()
paginator.paginate.return_value = pages
self.storage._connections.connection.meta.client.get_paginator.return_value = (
paginator
)
self.storage.connection.meta.client.get_paginator.return_value = paginator

dirs, files = self.storage.listdir("dir/")
paginator.paginate.assert_called_with(
@@ -797,21 +771,6 @@ def test_custom_domain_parameters(self):
self.assertEqual(parsed_url.path, "/filename.mp4")
self.assertEqual(parsed_url.query, "version=10")

@skipIf(threading is None, "Test requires threading")
def test_connection_threading(self):
connections = []

def thread_storage_connection():
connections.append(self.storage.connection)

for _ in range(2):
t = threading.Thread(target=thread_storage_connection)
t.start()
t.join()

# Connection for each thread needs to be unique
self.assertIsNot(connections[0], connections[1])

def test_location_leading_slash(self):
msg = (
"S3Storage.location cannot begin with a leading slash. "
@@ -1004,11 +963,32 @@ def test_security_token(self):
storage = s3.S3Storage()
self.assertEqual(storage.security_token, "baz")

def test_connection_cache(self):
storage1 = s3.S3Storage()
connection1 = storage1.connection
unsigned_connection1 = storage1.unsigned_connection
# different connections
self.assertNotEqual(id(connection1), id(unsigned_connection1))
# cache hits
self.assertEqual(id(connection1), id(storage1.connection))
self.assertEqual(id(unsigned_connection1), id(storage1.unsigned_connection))

storage2 = s3.S3Storage()
connection2 = storage2.connection
unsigned_connection2 = storage2.unsigned_connection
# different connections
self.assertNotEqual(id(connection2), id(unsigned_connection2))
self.assertNotEqual(id(connection1), id(connection2))
self.assertNotEqual(id(unsigned_connection1), id(unsigned_connection2))
# cache hits
self.assertEqual(id(connection2), id(storage2.connection))
self.assertEqual(id(unsigned_connection2), id(storage2.unsigned_connection))
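
A follow-up the suite could add (not part of this diff; a hypothetical sketch that patches the module's monotonic clock and stubs _create_connection so each call yields a distinct object) would assert that a client is recreated once the TTL elapses:

def test_connection_ttl_expiry(self):
    storage = s3.S3Storage()
    # Stub out client creation so every call returns a distinct object.
    storage._create_connection = mock.Mock(side_effect=lambda **kwargs: object())
    with mock.patch("storages.backends.s3.time.monotonic") as monotonic:
        monotonic.return_value = 0.0
        connection1 = storage.connection
        # Within the TTL window the cached client is reused.
        monotonic.return_value = storage.client_ttl / 2
        self.assertIs(connection1, storage.connection)
        # Past the TTL window a fresh client is created.
        monotonic.return_value = storage.client_ttl + 1
        self.assertIsNot(connection1, storage.connection)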


class S3StaticStorageTests(TestCase):
def setUp(self):
self.storage = s3.S3StaticStorage()
self.storage._connections.connection = mock.MagicMock()
self.storage._create_connection = mock.MagicMock()

def test_querystring_auth(self):
self.assertFalse(self.storage.querystring_auth)
@@ -1017,7 +997,7 @@ def test_querystring_auth(self):
class S3ManifestStaticStorageTests(TestCase):
def setUp(self):
self.storage = S3ManifestStaticStorageTestStorage()
self.storage._connections.connection = mock.MagicMock()
self.storage._create_connection = mock.MagicMock()

def test_querystring_auth(self):
self.assertFalse(self.storage.querystring_auth)
@@ -1029,7 +1009,7 @@ def test_save(self):
class S3FileTests(TestCase):
def setUp(self) -> None:
self.storage = s3.S3Storage()
self.storage._connections.connection = mock.MagicMock()
self.storage._create_connection = mock.MagicMock()

def test_loading_ssec(self):
params = {"SSECustomerKey": "xyz", "CacheControl": "never"}