diff --git a/api/environments/managers.py b/api/environments/managers.py index d191316ff236..f5f30dc9587e 100644 --- a/api/environments/managers.py +++ b/api/environments/managers.py @@ -1,10 +1,16 @@ +import typing + from django.db.models import Prefetch from softdelete.models import SoftDeleteManager # type: ignore[import-untyped] +from environments.constants import IDENTITY_INTEGRATIONS_RELATION_NAMES from features.models import FeatureSegment, FeatureState from features.multivariate.models import MultivariateFeatureStateValue from segments.models import Segment +if typing.TYPE_CHECKING: + from environments.models import Environment + class EnvironmentManager(SoftDeleteManager): # type: ignore[misc] def filter_for_document_builder( # type: ignore[no-untyped-def] @@ -52,6 +58,20 @@ def filter_for_document_builder( # type: ignore[no-untyped-def] .filter(*args, **kwargs) ) + def get_for_cache(self, api_key: str) -> "Environment": + select_related_args = ( + "project", + "project__organisation", + *IDENTITY_INTEGRATIONS_RELATION_NAMES, + ) + base_qs = self.select_related(*select_related_args).defer("description") + qs_for_embedded_api_key = base_qs.filter(api_key=api_key) + qs_for_fk_api_key = base_qs.filter(api_keys__key=api_key) + environment: Environment = qs_for_embedded_api_key.union( + qs_for_fk_api_key + ).get() + return environment + def get_queryset(self): # type: ignore[no-untyped-def] return super().get_queryset().select_related("project", "project__organisation") diff --git a/api/environments/models.py b/api/environments/models.py index d872645c6d14..2f30691e2f66 100644 --- a/api/environments/models.py +++ b/api/environments/models.py @@ -164,8 +164,11 @@ def create_feature_states(self) -> None: @hook(AFTER_UPDATE) # type: ignore[misc] def clear_environment_cache(self) -> None: - # TODO: this could rebuild the cache itself (using an async task) - environment_cache.delete(self.initial_value("api_key")) + from environments.tasks import update_environment_caches + + update_environment_caches.delay( + kwargs={"environment_api_key": self.initial_value("api_key")} + ) @hook(AFTER_UPDATE, when="api_key", has_changed=True) # type: ignore[misc] def update_environment_document_cache(self) -> None: @@ -246,19 +249,8 @@ def get_from_cache(cls, api_key: str | None) -> "Environment | None": environment: "Environment" = environment_cache.get(api_key) if not environment: - select_related_args = ( - "project", - "project__organisation", - *IDENTITY_INTEGRATIONS_RELATION_NAMES, - ) - base_qs = cls.objects.select_related(*select_related_args).defer( - "description" - ) - qs_for_embedded_api_key = base_qs.filter(api_key=api_key) - qs_for_fk_api_key = base_qs.filter(api_keys__key=api_key) - try: - environment = qs_for_embedded_api_key.union(qs_for_fk_api_key).get() + environment = cls.objects.get_for_cache(api_key) except cls.DoesNotExist: cls.set_bad_key(api_key) logger.info("Environment with api_key %s does not exist" % api_key) diff --git a/api/environments/tasks.py b/api/environments/tasks.py index 1a30f5c51411..b55e454db7b9 100644 --- a/api/environments/tasks.py +++ b/api/environments/tasks.py @@ -1,3 +1,5 @@ +from django.conf import settings +from django.core.cache import caches from task_processor.decorators import ( register_task_handler, ) @@ -16,12 +18,36 @@ send_environment_update_message_for_project, ) +environment_cache = caches[settings.ENVIRONMENT_CACHE_NAME] + @register_task_handler(priority=TaskPriority.HIGH) def rebuild_environment_document(environment_id: int) -> None: Environment.write_environment_documents(environment_id=environment_id) +@register_task_handler(priority=TaskPriority.HIGH) +def update_environment_caches(environment_api_key: str) -> None: + try: + environment = Environment.objects.get_for_cache(api_key=environment_api_key) + + # only rebuild the caches for those that previously existed to avoid + # unnecessarily caching data for unused keys. + cached_environments_by_api_key = environment_cache.get_many( + [eak.key for eak in environment.api_keys.all()] + ) + environment_cache.set_many( + { + environment.api_key: environment, + **{key: environment for key in cached_environments_by_api_key}, + } + ) + except Environment.DoesNotExist: + # unfortunately, since the EnvironmentAPIKey model is not soft-deleted + # we cannot clear those caches here and instead rely on the cache timeout + environment_cache.delete(environment_api_key) + + @register_task_handler(priority=TaskPriority.HIGHEST) def process_environment_update(audit_log_id: int): # type: ignore[no-untyped-def] audit_log = AuditLog.objects.get(id=audit_log_id) diff --git a/api/tests/unit/environments/test_unit_environments_models.py b/api/tests/unit/environments/test_unit_environments_models.py index 55c068fcb621..ac45c34482df 100644 --- a/api/tests/unit/environments/test_unit_environments_models.py +++ b/api/tests/unit/environments/test_unit_environments_models.py @@ -392,14 +392,10 @@ def test_change_request_audit_logs_does_not_update_updated_at(environment): # t def test_save_environment_clears_environment_cache(mocker, project): # type: ignore[no-untyped-def] # Given - mock_environment_cache = mocker.patch("environments.models.environment_cache") + mock_environment_cache = mocker.patch("environments.tasks.environment_cache") environment = Environment.objects.create(name="test environment", project=project) - # perform an update of the name to verify basic functionality - environment.name = "updated" - environment.save() - - # and update the api key to verify that the original api key is used to clear cache + # update the api key to verify that the original api key is used to clear cache old_key = copy(environment.api_key) new_key = "some-new-key" environment.api_key = new_key @@ -409,8 +405,8 @@ def test_save_environment_clears_environment_cache(mocker, project): # type: ig # Then mock_calls = mock_environment_cache.delete.mock_calls - assert len(mock_calls) == 2 - assert mock_calls[0][1][0] == mock_calls[1][1][0] == old_key + assert len(mock_calls) == 1 + assert mock_calls[0][1][0] == old_key @pytest.mark.parametrize( diff --git a/api/tests/unit/integrations/amplitude/test_unit_amplitude_models.py b/api/tests/unit/integrations/amplitude/test_unit_amplitude_models.py index 86c6214ab42b..ec2ad822a03b 100644 --- a/api/tests/unit/integrations/amplitude/test_unit_amplitude_models.py +++ b/api/tests/unit/integrations/amplitude/test_unit_amplitude_models.py @@ -48,9 +48,9 @@ def test_amplitude_configuration_delete_writes_environment_to_dynamodb( # type: ) -def test_amplitude_configuration_update_clears_environment_cache(environment, mocker): # type: ignore[no-untyped-def] +def test_amplitude_configuration_update_updates_environment_cache(environment, mocker): # type: ignore[no-untyped-def] # Given - mock_environment_cache = mocker.patch("environments.models.environment_cache") + mock_environment_cache = mocker.patch("environments.tasks.environment_cache") amplitude_config = AmplitudeConfiguration.objects.create( environment=environment, api_key="api-key", base_url="https://base.url.com" ) @@ -60,4 +60,8 @@ def test_amplitude_configuration_update_clears_environment_cache(environment, mo amplitude_config.save() # Then - mock_environment_cache.delete.assert_called_once_with(environment.api_key) + mock_environment_cache.set_many.assert_called_once() + + call_args = mock_environment_cache.set_many.call_args + assert len(call_args.args[0]) == 1 + assert call_args.args[0][environment.api_key].id == environment.id