diff --git a/CHANGELOG.md b/CHANGELOG.md index 65ed5469f..535a5dc40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,50 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.31.0] - 2025-07-18 +### Adds plugins support +- Adds an `experimental` property (`SuperTokensExperimentalConfig`) to the `SuperTokensConfig` + - Plugins can be configured under using the `plugins` property in the `experimental` config +- Refactors the AccountLinking recipe to be automatically initialized on SuperTokens init +- Adds `is_recipe_initialized` method to check if a recipe has been initialized + +### Breaking Changes +- `AccountLinkingRecipe.get_instance` will now raise an exception if not initialized +- Various config classes renamed for consistency across the codebase, and classes added where they were missing + - Old classes added to the recipe modules as aliases for backward compatibility, but will be removed in future versions. Prefer using the renamed classes. + - `InputOverrideConfig` renamed to `OverrideConfig` + - `OverrideConfig` renamed to `NormalisedOverrideConfig` + - Input config classes like `InputConfig` renamed to `Config` + - Normalised config classes like `Config` renamed to `NormalisedConfig` + - Changed classes: + - AccountLinking `InputOverrideConfig` -> `AccountLinkingOverrideConfig` + - Dashboard `InputOverrideConfig` -> `DashboardOverrideConfig` + - EmailPassword + - `InputOverrideConfig` -> `EmailPasswordOverrideConfig` + - `exceptions` export removed from `__init__`, import the `exceptions` module directly + - EmailVerification + - `InputOverrideConfig` -> `EmailVerificationOverrideConfig` + - `exception` export removed from `__init__`, import the `exceptions` module directly + - JWT `OverrideConfig` -> `JWTOverrideConfig` + - MultiFactorAuth `OverrideConfig` -> `MultiFactorAuthOverrideConfig` + - Multitenancy + - `InputOverrideConfig` -> `MultitenancyOverrideConfig` + - `exceptions` export removed from `__init__`, import the `exceptions` module directly + - OAuth2Provider + - `InputOverrideConfig` -> `OAuth2ProviderOverrideConfig` + - `exceptions` export removed from `__init__`, import the `exceptions` module directly + - OpenId `InputOverrideConfig` -> `OpenIdOverrideConfig` + - Passwordless `InputOverrideConfig` -> `PasswordlessOverrideConfig` + - Session + - `InputOverrideConfig` -> `SessionOverrideConfig` + - `exceptions` export removed from `__init__`, import the `exceptions` module directly + - ThirdParty + - `InputOverrideConfig` -> `ThirdPartyOverrideConfig` + - `exceptions` export removed from `__init__`, import the `exceptions` module directly + - TOTP `OverrideConfig` -> `TOTPOverrideConfig` + - UserMetadata `InputOverrideConfig` -> `UserMetadataOverrideConfig` + - UserRoles `InputOverrideConfig` -> `UserRolesOverrideConfig` + ## [0.30.1] - 2025-07-21 - Adds missing register credential endpoint to the Webauthn recipe diff --git a/dev-requirements.txt b/dev-requirements.txt index e39220586..99ffb6fe8 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -7,11 +7,12 @@ fastapi==0.115.5 Flask==3.0.3 flask-cors==5.0.0 nest-asyncio==1.6.0 +packaging==25.0 pdoc3==0.11.0 pre-commit==3.5.0 pyfakefs==5.7.4 pylint==3.2.7 -pyright==1.1.393 +pyright==1.1.402 python-dotenv==1.0.1 pytest==8.3.3 pytest-asyncio==0.24.0 diff --git a/pyproject.toml b/pyproject.toml index 69efe4a1a..a7e85ceca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,10 @@ line-length = 88 # Match Black's formatting src = ["supertokens_python"] [tool.ruff.lint] -extend-select = ["I"] # enable import sorting +extend-select = [ + "I", # enable import sorting + "RUF022", # Sort __all__ exports +] [tool.ruff.format] quote-style = "double" # Default @@ -18,3 +21,5 @@ include = ["supertokens_python/", "tests/", "examples/"] addopts = " -v -p no:warnings" python_paths = "." xfail_strict = true +# Removes requirement to use `@mark.asyncio` on async tests +asyncio_mode = "auto" diff --git a/setup.py b/setup.py index 95fa79288..fe85d89d7 100644 --- a/setup.py +++ b/setup.py @@ -61,28 +61,28 @@ } exclude_list = [ - "tests", - "examples", - "hooks", - ".gitignore", + ".circleci", ".git", + ".github", + ".gitignore", + ".pylintrc", + "Makefile", "addDevTag", "addReleaseTag", - "frontendDriverInterfaceSupported.json", "coreDriverInterfaceSupported.json", - ".github", - ".circleci", - "html", - "pyrightconfig.json", - "Makefile", - ".pylintrc", "dev-requirements.txt", "docs-templates", + "examples", + "frontendDriverInterfaceSupported.json", + "hooks", + "html", + "pyrightconfig.json", + "tests", ] setup( name="supertokens_python", - version="0.30.1", + version="0.31.0", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", @@ -112,22 +112,23 @@ ], keywords="", install_requires=[ + "Deprecated<1.3.0", # [crypto] ensures that it installs the `cryptography` library as well # based on constraints specified in https://github.com/jpadilla/pyjwt/blob/master/setup.cfg#L50 "PyJWT[crypto]>=2.5.0,<3.0.0", - "httpx>=0.15.0,<1.0.0", - "pycryptodome<3.21.0", - "tldextract<6.0.0", + "aiosmtplib>=1.1.6,<4.0.0", "asgiref>=3.4.1,<4", - "typing_extensions>=4.1.1,<5.0.0", - "Deprecated<1.3.0", + "httpx>=0.15.0,<1.0.0", + "packaging>=25.0,<26.0", "phonenumbers<9", - "twilio<10", - "aiosmtplib>=1.1.6,<4.0.0", "pkce<1.1.0", + "pycryptodome<3.21.0", + "pydantic>=2.10.6,<3.0.0", "pyotp<3", "python-dateutil<3", - "pydantic>=2.10.6,<3.0.0", + "tldextract<6.0.0", + "twilio<10", + "typing_extensions>=4.1.1,<5.0.0", ], python_requires=">=3.8", include_package_data=True, diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index 43d3573b2..7923bb72d 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -12,38 +12,61 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional from typing_extensions import Literal from supertokens_python.framework.request import BaseRequest +from supertokens_python.recipe_module import RecipeModule from supertokens_python.types import RecipeUserId -from . import supertokens -from .recipe_module import RecipeModule +from .plugins import LoadPluginsResponse +from .supertokens import ( + AppInfo, + InputAppInfo, + RecipeInit, + Supertokens, + SupertokensConfig, + SupertokensExperimentalConfig, + SupertokensInputConfig, + SupertokensPublicConfig, +) -InputAppInfo = supertokens.InputAppInfo -Supertokens = supertokens.Supertokens -SupertokensConfig = supertokens.SupertokensConfig -AppInfo = supertokens.AppInfo +# Some Pydantic models need a rebuild to resolve ForwardRefs +# Referencing imports here to prevent lint errors. +# Caveat: These will be available for import from this module directly. +RecipeModule # type: ignore + +# LoadPluginsResponse -> SupertokensPublicConfig +LoadPluginsResponse.model_rebuild() +# SupertokensInputConfig -> RecipeModule +SupertokensInputConfig.model_rebuild() def init( app_info: InputAppInfo, framework: Literal["fastapi", "flask", "django"], supertokens_config: SupertokensConfig, - recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]], + recipe_list: List[RecipeInit], mode: Optional[Literal["asgi", "wsgi"]] = None, telemetry: Optional[bool] = None, debug: Optional[bool] = None, + experimental: Optional[SupertokensExperimentalConfig] = None, ): return Supertokens.init( - app_info, framework, supertokens_config, recipe_list, mode, telemetry, debug + app_info, + framework, + supertokens_config, + recipe_list, + mode, + telemetry, + debug, + experimental=experimental, ) def get_all_cors_headers() -> List[str]: - return supertokens.Supertokens.get_instance().get_all_cors_headers() + return Supertokens.get_instance().get_all_cors_headers() def get_request_from_user_context( @@ -54,3 +77,23 @@ def get_request_from_user_context( def convert_to_recipe_user_id(user_id: str) -> RecipeUserId: return RecipeUserId(user_id) + + +is_recipe_initialized = Supertokens.is_recipe_initialized + + +__all__ = [ + "AppInfo", + "InputAppInfo", + "RecipeInit", + "RecipeUserId", + "Supertokens", + "SupertokensConfig", + "SupertokensExperimentalConfig", + "SupertokensPublicConfig", + "convert_to_recipe_user_id", + "get_all_cors_headers", + "get_request_from_user_context", + "init", + "is_recipe_initialized", +] diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index c83ae6a46..34c8a4118 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -15,7 +15,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["5.3"] -VERSION = "0.30.1" +VERSION = "0.31.0" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" diff --git a/supertokens_python/exceptions.py b/supertokens_python/exceptions.py index 30f4f8514..d7ccae5f4 100644 --- a/supertokens_python/exceptions.py +++ b/supertokens_python/exceptions.py @@ -40,3 +40,7 @@ class GeneralError(SuperTokensError): class BadInputError(SuperTokensError): pass + + +class PluginError(SuperTokensError): + pass diff --git a/supertokens_python/plugins.py b/supertokens_python/plugins.py new file mode 100644 index 000000000..5e615cf39 --- /dev/null +++ b/supertokens_python/plugins.py @@ -0,0 +1,506 @@ +from collections import deque +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Set, + TypeVar, + Union, + cast, + runtime_checkable, +) + +from packaging.specifiers import SpecifierSet +from packaging.version import Version +from typing_extensions import Protocol + +from supertokens_python.constants import VERSION +from supertokens_python.framework.request import BaseRequest +from supertokens_python.framework.response import BaseResponse +from supertokens_python.logger import log_debug_message +from supertokens_python.post_init_callbacks import PostSTInitCallbacks +from supertokens_python.recipe.session.interfaces import ( + SessionClaimValidator, + SessionContainer, +) +from supertokens_python.types import MaybeAwaitable +from supertokens_python.types.base import UserContext +from supertokens_python.types.config import ( + BaseConfig, + BaseConfigWithoutAPIOverride, + BaseOverrideConfig, + BaseOverrideConfigWithoutAPI, +) +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface +from supertokens_python.types.response import CamelCaseBaseModel + +if TYPE_CHECKING: + from supertokens_python.recipe.accountlinking.types import AccountLinkingConfig + from supertokens_python.recipe.dashboard.utils import DashboardConfig + from supertokens_python.recipe.emailpassword.utils import EmailPasswordConfig + from supertokens_python.recipe.emailverification.utils import ( + EmailVerificationConfig, + ) + from supertokens_python.recipe.jwt.utils import JWTConfig + from supertokens_python.recipe.multifactorauth.types import MultiFactorAuthConfig + from supertokens_python.recipe.multitenancy.utils import MultitenancyConfig + from supertokens_python.recipe.oauth2provider.utils import OAuth2ProviderConfig + from supertokens_python.recipe.openid.utils import OpenIdConfig + from supertokens_python.recipe.passwordless.utils import PasswordlessConfig + from supertokens_python.recipe.session.utils import SessionConfig + from supertokens_python.recipe.thirdparty.utils import ThirdPartyConfig + from supertokens_python.recipe.totp.types import TOTPConfig + from supertokens_python.recipe.usermetadata.utils import UserMetadataConfig + from supertokens_python.recipe.userroles.utils import UserRolesConfig + from supertokens_python.recipe.webauthn.types.config import WebauthnConfig + from supertokens_python.supertokens import SupertokensPublicConfig + +RecipeConfigType = TypeVar( + "RecipeConfigType", + bound=Union[ + "AccountLinkingConfig", + "DashboardConfig", + "EmailPasswordConfig", + "EmailVerificationConfig", + "JWTConfig", + "MultiFactorAuthConfig", + "MultitenancyConfig", + "OAuth2ProviderConfig", + "OpenIdConfig", + "PasswordlessConfig", + "SessionConfig", + "ThirdPartyConfig", + "TOTPConfig", + "UserMetadataConfig", + "UserRolesConfig", + "WebauthnConfig", + ], +) + + +RecipeInterfaceType = TypeVar("RecipeInterfaceType", bound=BaseRecipeInterface) +APIInterfaceType = TypeVar("APIInterfaceType", bound=BaseAPIInterface) + + +class RecipeInitRequiredFunction(Protocol): + def __call__(self, sdk_version: str) -> bool: ... + + +class RecipePluginOverride: + functions: Optional[Callable[[BaseRecipeInterface], BaseRecipeInterface]] = None + apis: Optional[Callable[[BaseAPIInterface], BaseAPIInterface]] = None + config: Optional[Callable[[Any], Any]] = None + recipe_init_required: Optional[Union[bool, RecipeInitRequiredFunction]] = None + + +class PluginRouteHandlerResponse(CamelCaseBaseModel): + status: int + body: Any + + +@runtime_checkable +class PluginRouteHandlerHandlerFunction(Protocol): + async def __call__( + self, + request: BaseRequest, + response: BaseResponse, + session: Optional["SessionContainer"], + user_context: UserContext, + ) -> BaseResponse: ... + + +@runtime_checkable +class OverrideGlobalClaimValidatorsFunction(Protocol): + def __call__( + self, + global_claim_validators: List["SessionClaimValidator"], + session: "SessionContainer", + user_context: UserContext, + ) -> MaybeAwaitable[List["SessionClaimValidator"]]: ... + + +class VerifySessionOptions(CamelCaseBaseModel): + session_required: bool + anti_csrf_check: Optional[bool] = None + check_database: bool + override_global_claim_validators: Optional[ + OverrideGlobalClaimValidatorsFunction + ] = None + + +class PluginRouteHandler(CamelCaseBaseModel): + method: str + path: str + handler: PluginRouteHandlerHandlerFunction + verify_session_options: Optional[VerifySessionOptions] + + +class PluginRouteHandlerWithPluginId(PluginRouteHandler): + plugin_id: str + """ + This is useful when multiple plugins handle the same route. + """ + + @classmethod + def from_route_handler( + cls, + route_handler: PluginRouteHandler, + plugin_id: str, + ): + return cls( + **route_handler.model_dump(), + plugin_id=plugin_id, + ) + + +@runtime_checkable +class SuperTokensPluginInit(Protocol): + def __call__( + self, + config: "SupertokensPublicConfig", + all_plugins: List["SuperTokensPublicPlugin"], + sdk_version: str, + ) -> None: ... + + +class PluginDependenciesOkResponse(CamelCaseBaseModel): + status: Literal["OK"] = "OK" + plugins_to_add: List["SuperTokensPlugin"] + + +class PluginDependenciesErrorResponse(CamelCaseBaseModel): + status: Literal["ERROR"] = "ERROR" + message: str + + +@runtime_checkable +class SuperTokensPluginDependencies(Protocol): + def __call__( + self, + config: "SupertokensPublicConfig", + plugins_above: List["SuperTokensPublicPlugin"], + sdk_version: str, + ) -> Union[PluginDependenciesOkResponse, PluginDependenciesErrorResponse]: ... + + +class PluginRouteHandlerFunctionOkResponse(CamelCaseBaseModel): + status: Literal["OK"] = "OK" + route_handlers: List[PluginRouteHandler] + + +class PluginRouteHandlerFunctionErrorResponse(CamelCaseBaseModel): + status: Literal["ERROR"] = "ERROR" + message: str + + +@runtime_checkable +class PluginRouteHandlerFunction(Protocol): + def __call__( + self, + config: "SupertokensPublicConfig", + all_plugins: List["SuperTokensPublicPlugin"], + sdk_version: str, + ) -> Union[ + PluginRouteHandlerFunctionOkResponse, PluginRouteHandlerFunctionErrorResponse + ]: ... + + +@runtime_checkable +class PluginConfig(Protocol): + def __call__( + self, config: "SupertokensPublicConfig" + ) -> "SupertokensPublicConfig": ... + + +class SuperTokensPluginBase(CamelCaseBaseModel): + id: str + version: Optional[str] = None + compatible_sdk_versions: Union[str, List[str]] + exports: Optional[Dict[str, Any]] = None + + +OverrideMap = Dict[str, RecipePluginOverride] + + +class SuperTokensPlugin(SuperTokensPluginBase): + init: Optional[SuperTokensPluginInit] = None + dependencies: Optional[SuperTokensPluginDependencies] = None + override_map: Optional[OverrideMap] = None + route_handlers: Optional[ + Union[List[PluginRouteHandler], PluginRouteHandlerFunction] + ] = None + config: Optional[PluginConfig] = None + + def get_dependencies( + self, + public_config: "SupertokensPublicConfig", + plugins_above: List["SuperTokensPlugin"], + sdk_version: str, + ): + """ + Pre-order DFS traversal to get all dependencies of a plugin. + """ + + def recurse_deps( + plugin: SuperTokensPlugin, + deps: Optional[List[SuperTokensPlugin]] = None, + visited: Optional[Set[str]] = None, + ) -> List[SuperTokensPlugin]: + if deps is None: + deps = [] + + if visited is None: + visited = set() + + if plugin.id in visited: + return deps + visited.add(plugin.id) + + if plugin.dependencies is not None: + # Get all dependencies of the plugin + dep_result = plugin.dependencies( + config=public_config, + plugins_above=[ + SuperTokensPublicPlugin.from_plugin(plugin) + for plugin in plugins_above + ], + sdk_version=sdk_version, + ) + + # Errors fall through + if isinstance(dep_result, PluginDependenciesErrorResponse): + raise Exception(dep_result.message) + + # Recurse through all dependencies and add the resultant plugins to the list + # Pre-order DFS traversal + for dep_plugin in dep_result.plugins_to_add: + recurse_deps(dep_plugin, deps) + + # Add the current plugin and mark it as visited + deps.append(plugin) + + return deps + + return recurse_deps(self) + + +class SuperTokensPublicPlugin(SuperTokensPluginBase): + initialized: bool + + @classmethod + def from_plugin(cls, plugin: SuperTokensPlugin) -> "SuperTokensPublicPlugin": + return cls( + id=plugin.id, + initialized=plugin.init is None, + version=plugin.version, + exports=plugin.exports, + compatible_sdk_versions=plugin.compatible_sdk_versions, + ) + + +def apply_plugins( + recipe_id: str, + config: RecipeConfigType, + plugins: List[OverrideMap], +) -> RecipeConfigType: + if not isinstance(config, (BaseConfig, BaseConfigWithoutAPIOverride)): # type: ignore + raise TypeError( + f"Expected config to be an instance of BaseConfig or BaseConfigWithoutAPIOverride. {recipe_id=} {config=}" + ) + + def default_fn_override( + original_implementation: RecipeInterfaceType, + ) -> RecipeInterfaceType: + return original_implementation + + def default_api_override( + original_implementation: APIInterfaceType, + ) -> APIInterfaceType: + return original_implementation + + if config.override is None: + if isinstance(config, BaseConfigWithoutAPIOverride): + config.override = BaseOverrideConfigWithoutAPI() + else: + config.override = BaseOverrideConfig() # type: ignore - generic type invariance + + function_overrides = getattr(config.override, "functions", default_fn_override) + api_overrides = getattr(config.override, "apis", default_api_override) + + function_layers: deque[Any] = deque() + api_layers: deque[Any] = deque() + + # If we have plugins like 4->3->(2, 1) along with a recipe override, + # we want to load/init them as: override, 2, 1, 3, 4 + # and call them as: override, 4, 3, 2, 1, original + # Order of 1/2 does not matter since they are independent from each other. + + for plugin in plugins: + overrides = plugin.get(recipe_id) + if overrides is not None: + if overrides.config is not None: + config = overrides.config(config) + + if overrides.functions is not None: + function_layers.append(overrides.functions) + if overrides.apis is not None: + api_layers.append(overrides.apis) + + if function_overrides is not None: + function_layers.append(function_overrides) + if api_overrides is not None: + api_layers.append(api_overrides) + + # Apply overrides in reverse order of definition + # Plugins: [plugin1, plugin2] would be applied as [override, plugin2, plugin1, original] + if len(function_layers) > 0: + + def fn_override( + original_implementation: RecipeInterfaceType, + ) -> RecipeInterfaceType: + # The layers will get called in reversed order + for function_layer in function_layers: + original_implementation = function_layer(original_implementation) + return original_implementation + + config.override.functions = fn_override # type: ignore + + if ( + len(api_layers) > 0 + # AccountLinking recipe does not have an API implementation, uses `BaseConfigWithoutAPIOverride` as base + and recipe_id != "accountlinking" + # `BaseConfig` is the base class for all configs with an API override. + and isinstance(config, BaseConfig) + ): + + def api_override(original_implementation: APIInterfaceType) -> APIInterfaceType: + for api_layer in api_layers: + original_implementation = api_layer(original_implementation) + return original_implementation + + config.override.apis = api_override # type: ignore + + return config + + +class LoadPluginsResponse(CamelCaseBaseModel): + public_config: "SupertokensPublicConfig" + processed_plugins: List[SuperTokensPublicPlugin] + plugin_route_handlers: List[PluginRouteHandlerWithPluginId] + override_maps: List[OverrideMap] + + +def load_plugins( + plugins: List[SuperTokensPlugin], + public_config: "SupertokensPublicConfig", +) -> LoadPluginsResponse: + input_plugin_seen_list: Set[str] = set() + final_plugin_list: List[SuperTokensPlugin] = [] + plugin_route_handlers: List[PluginRouteHandlerWithPluginId] = [] + + for plugin in plugins: + if plugin.id in input_plugin_seen_list: + log_debug_message(f"Skipping {plugin.id=} as it has already been added") + continue + + if isinstance(plugin.compatible_sdk_versions, list): + version_constraints = ",".join(plugin.compatible_sdk_versions) + else: + version_constraints = plugin.compatible_sdk_versions + + if not SpecifierSet(version_constraints).contains(Version(VERSION)): + raise Exception( + f"Incompatible SDK version for plugin {plugin.id}. " + f"Version {VERSION} not found in compatible versions {version_constraints}" + ) + + # TODO: Overkill, but could topologically sort the plugins based on dependencies + dependencies = plugin.get_dependencies( + public_config=public_config, + plugins_above=final_plugin_list, + sdk_version=VERSION, + ) + final_plugin_list.extend(dependencies) + input_plugin_seen_list.update({dep.id for dep in dependencies}) + + # Secondary check to ensure no duplicate plugins + # Should ideally be handled in the dependency resolution above. + unique_plugins: Set[str] = set() + duplicate_plugins: List[str] = [] + for plugin in final_plugin_list: + if plugin.id in unique_plugins: + duplicate_plugins.append(plugin.id) + unique_plugins.add(plugin.id) + + if len(duplicate_plugins) > 0: + raise Exception(f"Duplicate plugins found: {', '.join(duplicate_plugins)}") + + processed_plugin_list = [ + SuperTokensPublicPlugin.from_plugin(plugin) for plugin in final_plugin_list + ] + + for plugin_idx, plugin in enumerate(final_plugin_list): + # Override the public supertokens config using the config override defined in the plugin + if plugin.config is not None: + public_config = plugin.config(public_config) + + if plugin.route_handlers is not None: + handlers: List[PluginRouteHandler] = [] + + if callable(plugin.route_handlers): + handler_result = plugin.route_handlers( + config=public_config, + all_plugins=processed_plugin_list, + sdk_version=VERSION, + ) + if isinstance(handler_result, PluginRouteHandlerFunctionErrorResponse): + raise Exception(handler_result.message) + + handlers = handler_result.route_handlers + else: + handlers = plugin.route_handlers + + plugin_route_handlers.extend( + [ + PluginRouteHandlerWithPluginId.from_route_handler( + handler, plugin.id + ) + for handler in handlers + ] + ) + + if plugin.init is not None: + + def callback_factory(): + # This has to be part of the factory to ensure we pick up the correct plugin + init_fn = cast(SuperTokensPluginInit, plugin.init) + idx = plugin_idx + + def callback(): + init_fn( + config=public_config, + all_plugins=processed_plugin_list, + sdk_version=VERSION, + ) + processed_plugin_list[idx].initialized = True + + return callback + + PostSTInitCallbacks.add_post_init_callback(callback_factory()) + + override_maps = [ + plugin.override_map + for plugin in final_plugin_list + if plugin.override_map is not None + ] + + return LoadPluginsResponse( + public_config=public_config, + processed_plugins=processed_plugin_list, + plugin_route_handlers=plugin_route_handlers, + override_maps=override_maps, + ) diff --git a/supertokens_python/recipe/accountlinking/__init__.py b/supertokens_python/recipe/accountlinking/__init__.py index 97d295bd3..c64e1051e 100644 --- a/supertokens_python/recipe/accountlinking/__init__.py +++ b/supertokens_python/recipe/accountlinking/__init__.py @@ -13,18 +13,22 @@ # under the License. from __future__ import annotations -from typing import Any, Awaitable, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union + +from supertokens_python.types import User -from ...types import User -from ..session.interfaces import SessionContainer -from . import types from .recipe import AccountLinkingRecipe +from .types import ( + AccountInfoWithRecipeIdAndUserId, + AccountLinkingOverrideConfig, + InputOverrideConfig, + RecipeLevelUser, + ShouldAutomaticallyLink, + ShouldNotAutomaticallyLink, +) -InputOverrideConfig = types.InputOverrideConfig -RecipeLevelUser = types.RecipeLevelUser -AccountInfoWithRecipeIdAndUserId = types.AccountInfoWithRecipeIdAndUserId -ShouldAutomaticallyLink = types.ShouldAutomaticallyLink -ShouldNotAutomaticallyLink = types.ShouldNotAutomaticallyLink +if TYPE_CHECKING: + from ..session.interfaces import SessionContainer def init( @@ -43,8 +47,20 @@ def init( Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ] ] = None, - override: Optional[InputOverrideConfig] = None, + override: Optional[AccountLinkingOverrideConfig] = None, ): return AccountLinkingRecipe.init( on_account_linked, should_do_automatic_account_linking, override ) + + +__all__ = [ + "AccountInfoWithRecipeIdAndUserId", + "AccountLinkingOverrideConfig", + "AccountLinkingRecipe", + "InputOverrideConfig", # deprecated, use AccountLinkingOverrideConfig instead + "RecipeLevelUser", + "ShouldAutomaticallyLink", + "ShouldNotAutomaticallyLink", + "init", +] diff --git a/supertokens_python/recipe/accountlinking/interfaces.py b/supertokens_python/recipe/accountlinking/interfaces.py index 105b5bda5..04274b6f0 100644 --- a/supertokens_python/recipe/accountlinking/interfaces.py +++ b/supertokens_python/recipe/accountlinking/interfaces.py @@ -13,12 +13,13 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing_extensions import Literal from supertokens_python.types.base import AccountInfoInput +from supertokens_python.types.recipe import BaseRecipeInterface if TYPE_CHECKING: from supertokens_python.types import ( @@ -27,7 +28,7 @@ ) -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def get_users( self, diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index 3b34801ca..cdaf50185 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -26,7 +26,6 @@ from supertokens_python.process_state import PROCESS_STATE, ProcessState from supertokens_python.querier import Querier from supertokens_python.recipe_module import APIHandled, RecipeModule -from supertokens_python.supertokens import Supertokens from supertokens_python.types.base import AccountInfoInput from .interfaces import RecipeInterface @@ -34,7 +33,8 @@ from .types import ( AccountInfoWithRecipeId, AccountInfoWithRecipeIdAndUserId, - InputOverrideConfig, + AccountLinkingConfig, + AccountLinkingOverrideConfig, RecipeLevelUser, ShouldAutomaticallyLink, ShouldNotAutomaticallyLink, @@ -77,35 +77,16 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - on_account_linked: Optional[ - Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] - ] = None, - should_do_automatic_account_linking: Optional[ - Callable[ - [ - AccountInfoWithRecipeIdAndUserId, - Optional[User], - Optional[SessionContainer], - str, - Dict[str, Any], - ], - Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], - ] - ] = None, - override: Optional[InputOverrideConfig] = None, + config: AccountLinkingConfig, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input( - app_info, on_account_linked, should_do_automatic_account_linking, override - ) + self.config = validate_and_normalise_user_input(app_info, config=config) recipe_implementation: RecipeInterface = RecipeImplementation( Querier.get_instance(recipe_id), self, self.config ) - self.recipe_implementation: RecipeInterface = ( + self.recipe_implementation: RecipeInterface = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) self.email_verification_recipe: EmailVerificationRecipe | None = None @@ -162,16 +143,26 @@ def init( Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ] ] = None, - override: Optional[InputOverrideConfig] = None, - ): - def func(app_info: AppInfo): + override: Optional[AccountLinkingOverrideConfig] = None, + ) -> Callable[..., AccountLinkingRecipe]: + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = AccountLinkingConfig( + on_account_linked=on_account_linked, + should_do_automatic_account_linking=should_do_automatic_account_linking, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if AccountLinkingRecipe.__instance is None: AccountLinkingRecipe.__instance = AccountLinkingRecipe( - AccountLinkingRecipe.recipe_id, - app_info, - on_account_linked, - should_do_automatic_account_linking, - override, + recipe_id=AccountLinkingRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=AccountLinkingRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return AccountLinkingRecipe.__instance raise Exception( @@ -183,11 +174,12 @@ def func(app_info: AppInfo): @staticmethod def get_instance() -> AccountLinkingRecipe: - if AccountLinkingRecipe.__instance is None: - AccountLinkingRecipe.init()(Supertokens.get_instance().app_info) + if AccountLinkingRecipe.__instance is not None: + return AccountLinkingRecipe.__instance - assert AccountLinkingRecipe.__instance is not None - return AccountLinkingRecipe.__instance + raise_general_exception( + "Initialisation not done. Did you forget to call the SuperTokens.init function?" + ) @staticmethod def reset(): diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py index 2eb9a89e6..43991ab5c 100644 --- a/supertokens_python/recipe/accountlinking/recipe_implementation.py +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -40,7 +40,7 @@ RecipeInterface, UnlinkAccountOkResult, ) -from .types import AccountLinkingConfig, RecipeLevelUser +from .types import NormalisedAccountLinkingConfig, RecipeLevelUser if TYPE_CHECKING: from supertokens_python.querier import Querier @@ -53,7 +53,7 @@ def __init__( self, querier: Querier, recipe_instance: AccountLinkingRecipe, - config: AccountLinkingConfig, + config: NormalisedAccountLinkingConfig, ): super().__init__() self.querier = querier diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index 22977eabe..74eb752fb 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -21,17 +21,30 @@ from supertokens_python.recipe.accountlinking.interfaces import ( RecipeInterface, ) -from supertokens_python.types import AccountInfo +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import ( + AccountInfo, + LoginMethod, + RecipeUserId, + User, +) +from supertokens_python.types.config import ( + BaseConfigWithoutAPIOverride, + BaseNormalisedConfigWithoutAPIOverride, + BaseNormalisedOverrideConfigWithoutAPI, + BaseOverrideConfigWithoutAPI, +) if TYPE_CHECKING: - from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo from supertokens_python.recipe.webauthn.types.base import WebauthnInfo - from supertokens_python.types import ( - LoginMethod, - RecipeUserId, - User, - ) + +AccountLinkingOverrideConfig = BaseOverrideConfigWithoutAPI[RecipeInterface] +NormalisedAccountLinkingOverrideConfig = BaseNormalisedOverrideConfigWithoutAPI[ + RecipeInterface +] +InputOverrideConfig = AccountLinkingOverrideConfig +"""Deprecated, use `AccountLinkingOverrideConfig` instead.""" class AccountInfoWithRecipeId(AccountInfo): @@ -131,29 +144,12 @@ def __init__(self, should_require_verification: bool): self.should_require_verification = should_require_verification -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - ): - self.functions = functions - - -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - ): - self.functions = functions - - -class AccountLinkingConfig: - def __init__( - self, - on_account_linked: Callable[ - [User, RecipeLevelUser, Dict[str, Any]], Awaitable[None] - ], - should_do_automatic_account_linking: Callable[ +class AccountLinkingConfig(BaseConfigWithoutAPIOverride[RecipeInterface]): + on_account_linked: Optional[ + Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] + ] = None + should_do_automatic_account_linking: Optional[ + Callable[ [ AccountInfoWithRecipeIdAndUserId, Optional[User], @@ -162,9 +158,23 @@ def __init__( Dict[str, Any], ], Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ] + ] = None + + +class NormalisedAccountLinkingConfig( + BaseNormalisedConfigWithoutAPIOverride[RecipeInterface] +): + on_account_linked: Callable[ + [User, RecipeLevelUser, Dict[str, Any]], Awaitable[None] + ] + should_do_automatic_account_linking: Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[User], + Optional[SessionContainer], + str, + Dict[str, Any], ], - override: OverrideConfig, - ): - self.on_account_linked = on_account_linked - self.should_do_automatic_account_linking = should_do_automatic_account_linking - self.override = override + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ] diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py index 763bfd48e..8a496de2a 100644 --- a/supertokens_python/recipe/accountlinking/utils.py +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -13,22 +13,22 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union - -if TYPE_CHECKING: - from .types import ( - AccountInfoWithRecipeIdAndUserId, - AccountLinkingConfig, - InputOverrideConfig, - RecipeLevelUser, - SessionContainer, - ShouldAutomaticallyLink, - ShouldNotAutomaticallyLink, - User, - ) +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from supertokens_python.recipe.accountlinking.types import ( + AccountInfoWithRecipeIdAndUserId, + AccountLinkingConfig, + NormalisedAccountLinkingConfig, + NormalisedAccountLinkingOverrideConfig, + RecipeLevelUser, + ShouldAutomaticallyLink, + ShouldNotAutomaticallyLink, +) if TYPE_CHECKING: + from supertokens_python.recipe.session.interfaces import SessionContainer from supertokens_python.supertokens import AppInfo + from supertokens_python.types.base import User async def default_on_account_linked(_: User, __: RecipeLevelUser, ___: Dict[str, Any]): @@ -58,51 +58,28 @@ def recipe_init_defined_should_do_automatic_account_linking() -> bool: def validate_and_normalise_user_input( _: AppInfo, - on_account_linked: Optional[ - Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] - ] = None, - should_do_automatic_account_linking: Optional[ - Callable[ - [ - AccountInfoWithRecipeIdAndUserId, - Optional[User], - Optional[SessionContainer], - str, - Dict[str, Any], - ], - Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], - ] - ] = None, - override: Union[InputOverrideConfig, None] = None, -) -> AccountLinkingConfig: - from .types import ( - AccountLinkingConfig as ALC, - ) - from .types import ( - InputOverrideConfig as IOC, - ) - from .types import ( - OverrideConfig, - ) - + config: AccountLinkingConfig, +) -> NormalisedAccountLinkingConfig: global _did_use_default_should_do_automatic_account_linking - if override is None: - override = IOC() + + override_config = NormalisedAccountLinkingOverrideConfig.from_input_config( + override_config=config.override + ) _did_use_default_should_do_automatic_account_linking = ( - should_do_automatic_account_linking is None + config.should_do_automatic_account_linking is None ) - return ALC( - override=OverrideConfig(functions=override.functions), + return NormalisedAccountLinkingConfig( + override=override_config, on_account_linked=( default_on_account_linked - if on_account_linked is None - else on_account_linked + if config.on_account_linked is None + else config.on_account_linked ), should_do_automatic_account_linking=( default_should_do_automatic_account_linking - if should_do_automatic_account_linking is None - else should_do_automatic_account_linking + if config.should_do_automatic_account_linking is None + else config.should_do_automatic_account_linking ), ) diff --git a/supertokens_python/recipe/dashboard/__init__.py b/supertokens_python/recipe/dashboard/__init__.py index 46f46f9a2..f09638c68 100644 --- a/supertokens_python/recipe/dashboard/__init__.py +++ b/supertokens_python/recipe/dashboard/__init__.py @@ -14,23 +14,29 @@ from __future__ import annotations -from typing import Callable, List, Optional +from typing import List, Optional -from supertokens_python import AppInfo, RecipeModule -from supertokens_python.recipe.dashboard import utils +from supertokens_python.supertokens import RecipeInit from .recipe import DashboardRecipe - -InputOverrideConfig = utils.InputOverrideConfig +from .utils import DashboardOverrideConfig, InputOverrideConfig def init( api_key: Optional[str] = None, admins: Optional[List[str]] = None, - override: Optional[InputOverrideConfig] = None, -) -> Callable[[AppInfo], RecipeModule]: + override: Optional[DashboardOverrideConfig] = None, +) -> RecipeInit: return DashboardRecipe.init( api_key, admins, override, ) + + +__all__ = [ + "DashboardOverrideConfig", + "DashboardRecipe", + "InputOverrideConfig", # deprecated, use DashboardOverrideConfig instead + "init", +] diff --git a/supertokens_python/recipe/dashboard/api/__init__.py b/supertokens_python/recipe/dashboard/api/__init__.py index 3ae3a869a..527f08fe2 100644 --- a/supertokens_python/recipe/dashboard/api/__init__.py +++ b/supertokens_python/recipe/dashboard/api/__init__.py @@ -33,24 +33,24 @@ from .validate_key import handle_validate_key_api __all__ = [ - "handle_dashboard_api", "api_key_protector", - "handle_users_count_get_api", - "handle_users_get_api", - "handle_validate_key_api", - "handle_user_email_verify_get", - "handle_user_get", + "handle_analytics_post", + "handle_dashboard_api", + "handle_email_verify_token_post", + "handle_emailpassword_signin_api", + "handle_emailpassword_signout_api", + "handle_get_tags", "handle_metadata_get", + "handle_metadata_put", "handle_sessions_get", "handle_user_delete", - "handle_user_put", + "handle_user_email_verify_get", "handle_user_email_verify_put", - "handle_metadata_put", - "handle_user_sessions_post", + "handle_user_get", "handle_user_password_put", - "handle_email_verify_token_post", - "handle_emailpassword_signin_api", - "handle_emailpassword_signout_api", - "handle_get_tags", - "handle_analytics_post", + "handle_user_put", + "handle_user_sessions_post", + "handle_users_count_get_api", + "handle_users_get_api", + "handle_validate_key_api", ] diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index a0e8564a4..59c20ff88 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -13,12 +13,13 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union from typing_extensions import Literal from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from ...types.response import APIResponse @@ -27,7 +28,7 @@ from supertokens_python.recipe.session.interfaces import SessionInformationResult from ...supertokens import AppInfo - from .utils import DashboardConfig, UserWithMetadata + from .utils import NormalisedDashboardConfig, UserWithMetadata class SessionInfo: @@ -41,7 +42,7 @@ def __init__(self, info: SessionInformationResult) -> None: self.tenant_id = info.tenant_id -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -53,7 +54,7 @@ async def get_dashboard_bundle_location(self, user_context: Dict[str, Any]) -> s async def should_allow_access( self, request: BaseRequest, - config: DashboardConfig, + config: NormalisedDashboardConfig, user_context: Dict[str, Any], ) -> bool: pass @@ -65,19 +66,19 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: DashboardConfig, + config: NormalisedDashboardConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, ): self.request: BaseRequest = request self.response: BaseResponse = response self.recipe_id: str = recipe_id - self.config: DashboardConfig = config + self.config: NormalisedDashboardConfig = config self.recipe_implementation: RecipeInterface = recipe_implementation self.app_info = app_info -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): # undefined should be allowed self.dashboard_get: Optional[ diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index f789aea66..aae27c256 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -149,7 +149,8 @@ VALIDATE_KEY_API, ) from .utils import ( - InputOverrideConfig, + DashboardConfig, + DashboardOverrideConfig, validate_and_normalise_user_input, ) @@ -162,29 +163,19 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - api_key: Optional[str], - admins: Optional[List[str]], - override: Optional[InputOverrideConfig] = None, + config: DashboardConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - api_key, - admins, - override, + config=config, ) recipe_implementation = RecipeImplementation() - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, SuperTokensError) and ( @@ -650,16 +641,26 @@ def get_all_cors_headers(self) -> List[str]: def init( api_key: Optional[str], admins: Optional[List[str]] = None, - override: Optional[InputOverrideConfig] = None, + override: Optional[DashboardOverrideConfig] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = DashboardConfig( + api_key=api_key, + admins=admins, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if DashboardRecipe.__instance is None: DashboardRecipe.__instance = DashboardRecipe( - DashboardRecipe.recipe_id, - app_info, - api_key, - admins, - override, + recipe_id=DashboardRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=DashboardRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return DashboardRecipe.__instance raise Exception( diff --git a/supertokens_python/recipe/dashboard/recipe_implementation.py b/supertokens_python/recipe/dashboard/recipe_implementation.py index 881c14ebd..e46615ebd 100644 --- a/supertokens_python/recipe/dashboard/recipe_implementation.py +++ b/supertokens_python/recipe/dashboard/recipe_implementation.py @@ -27,7 +27,7 @@ from .exceptions import DashboardOperationNotAllowedError from .interfaces import RecipeInterface -from .utils import DashboardConfig, validate_api_key +from .utils import NormalisedDashboardConfig, validate_api_key class RecipeImplementation(RecipeInterface): @@ -37,7 +37,7 @@ async def get_dashboard_bundle_location(self, user_context: Dict[str, Any]) -> s async def should_allow_access( self, request: BaseRequest, - config: DashboardConfig, + config: NormalisedDashboardConfig, user_context: Dict[str, Any], ) -> bool: # For cases where we're not using the API key, the JWT is being used; we allow their access by default diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index 50324c0fd..d47e945f0 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -13,11 +13,17 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing_extensions import Literal from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -45,9 +51,7 @@ USERS_LIST_GET_API, VALIDATE_KEY_API, ) - -if TYPE_CHECKING: - from .interfaces import APIInterface, RecipeInterface +from .interfaces import APIInterface, RecipeInterface class UserWithMetadata: @@ -73,64 +77,49 @@ def to_json(self) -> Dict[str, Any]: return user_json -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +DashboardOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedDashboardOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = DashboardOverrideConfig +"""Deprecated, use `DashboardOverrideConfig` instead.""" -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class DashboardConfig(BaseConfig[RecipeInterface, APIInterface]): + api_key: Optional[str] = None + admins: Optional[List[str]] = None -class DashboardConfig: - def __init__( - self, - api_key: Optional[str], - admins: Optional[List[str]], - override: OverrideConfig, - auth_mode: str, - ): - self.api_key = api_key - self.admins = admins - self.override = override - self.auth_mode = auth_mode +class NormalisedDashboardConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + api_key: Optional[str] + admins: Optional[List[str]] + auth_mode: str def validate_and_normalise_user_input( - # app_info: AppInfo, - api_key: Union[str, None], - admins: Optional[List[str]], - override: Optional[InputOverrideConfig] = None, -) -> DashboardConfig: - if override is None: - override = InputOverrideConfig() - - if api_key is not None and admins is not None: + config: DashboardConfig, +) -> NormalisedDashboardConfig: + override_config = NormalisedDashboardOverrideConfig.from_input_config( + override_config=config.override + ) + + if config.api_key is not None and config.admins is not None: log_debug_message( "User Dashboard: Providing 'admins' has no effect when using an api key." ) - admins = [normalise_email(a) for a in admins] if admins is not None else None + admins = ( + [normalise_email(a) for a in config.admins] + if config.admins is not None + else None + ) + auth_mode = "api-key" if config.api_key else "email-password" - return DashboardConfig( - api_key, - admins, - OverrideConfig( - functions=override.functions, - apis=override.apis, - ), - "api-key" if api_key else "email-password", + return NormalisedDashboardConfig( + api_key=config.api_key, + admins=admins, + auth_mode=auth_mode, + override=override_config, ) @@ -262,7 +251,7 @@ async def _get_user_for_recipe_id( async def validate_api_key( - req: BaseRequest, config: DashboardConfig, _user_context: Dict[str, Any] + req: BaseRequest, config: NormalisedDashboardConfig, _user_context: Dict[str, Any] ) -> bool: api_key_header_value = req.get_header("authorization") if not api_key_header_value: diff --git a/supertokens_python/recipe/emailpassword/__init__.py b/supertokens_python/recipe/emailpassword/__init__.py index 731009e43..01d78d4a9 100644 --- a/supertokens_python/recipe/emailpassword/__init__.py +++ b/supertokens_python/recipe/emailpassword/__init__.py @@ -13,37 +13,48 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union -from supertokens_python.ingredients.emaildelivery import types as emaildelivery_types -from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig +from supertokens_python.ingredients.emaildelivery.types import ( + EmailDeliveryConfig, + EmailDeliveryInterface, +) from supertokens_python.recipe.emailpassword.types import EmailTemplateVars -from . import exceptions as ex -from . import utils -from .emaildelivery import services as emaildelivery_services +from .emaildelivery.services import SMTPService from .recipe import EmailPasswordRecipe - -exceptions = ex -InputOverrideConfig = utils.InputOverrideConfig -InputSignUpFeature = utils.InputSignUpFeature -InputFormField = utils.InputFormField -SMTPService = emaildelivery_services.SMTPService -EmailDeliveryInterface = emaildelivery_types.EmailDeliveryInterface +from .utils import ( + EmailPasswordOverrideConfig, + InputFormField, + InputOverrideConfig, + InputSignUpFeature, +) if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( - sign_up_feature: Union[utils.InputSignUpFeature, None] = None, - override: Union[utils.InputOverrideConfig, None] = None, + sign_up_feature: Union[InputSignUpFeature, None] = None, + override: Union[EmailPasswordOverrideConfig, None] = None, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return EmailPasswordRecipe.init( sign_up_feature, override, email_delivery, ) + + +__all__ = [ + "EmailDeliveryConfig", + "EmailDeliveryInterface", + "EmailPasswordOverrideConfig", + "EmailPasswordRecipe", + "EmailTemplateVars", + "InputFormField", + "InputOverrideConfig", # deprecated, use EmailPasswordOverrideConfig instead + "InputSignUpFeature", + "SMTPService", + "init", +] diff --git a/supertokens_python/recipe/emailpassword/api/utils.py b/supertokens_python/recipe/emailpassword/api/utils.py index 06af69e9c..d35d96234 100644 --- a/supertokens_python/recipe/emailpassword/api/utils.py +++ b/supertokens_python/recipe/emailpassword/api/utils.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, cast from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.emailpassword.constants import ( @@ -78,7 +78,9 @@ async def validate_form_fields_or_throw_error( form_fields: List[FormField] = [] - form_fields_list_raw: List[Dict[str, Any]] = form_fields_raw + form_fields_list_raw: List[Dict[str, Any]] = cast( + List[Dict[str, Any]], form_fields_raw + ) for current_form_field in form_fields_list_raw: if ( "id" not in current_form_field diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 96df9426e..b3f669d71 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -13,12 +13,13 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Union from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.recipe.emailpassword.types import EmailTemplateVars from supertokens_python.types.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from ...supertokens import AppInfo from ...types import ( @@ -32,7 +33,7 @@ from ...types import User from .types import FormField - from .utils import EmailPasswordConfig + from .utils import NormalisedEmailPasswordConfig class SignUpOkResult: @@ -118,7 +119,7 @@ def to_json(self) -> Dict[str, Any]: } -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -218,7 +219,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: EmailPasswordConfig, + config: NormalisedEmailPasswordConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, email_delivery: EmailDeliveryIngredient[EmailTemplateVars], @@ -226,7 +227,7 @@ def __init__( self.request: BaseRequest = request self.response: BaseResponse = response self.recipe_id: str = recipe_id - self.config: EmailPasswordConfig = config + self.config: NormalisedEmailPasswordConfig = config self.recipe_implementation: RecipeInterface = recipe_implementation self.app_info = app_info self.email_delivery = email_delivery @@ -315,7 +316,7 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status, "reason": self.reason} -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_email_exists_get = False self.disable_generate_password_reset_token_post = False diff --git a/supertokens_python/recipe/emailpassword/recipe.py b/supertokens_python/recipe/emailpassword/recipe.py index 9a01031e3..5c506986d 100644 --- a/supertokens_python/recipe/emailpassword/recipe.py +++ b/supertokens_python/recipe/emailpassword/recipe.py @@ -14,7 +14,7 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.auth_utils import is_fake_email from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient @@ -68,7 +68,8 @@ USER_PASSWORD_RESET_TOKEN, ) from .utils import ( - InputOverrideConfig, + EmailPasswordConfig, + EmailPasswordOverrideConfig, InputSignUpFeature, validate_and_normalise_user_input, ) @@ -84,25 +85,19 @@ def __init__( recipe_id: str, app_info: AppInfo, ingredients: EmailPasswordIngredients, - sign_up_feature: Union[InputSignUpFeature, None] = None, - override: Union[InputOverrideConfig, None] = None, - email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, + config: EmailPasswordConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( app_info, - sign_up_feature, - override, - email_delivery, + config=config, ) recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) email_delivery_ingredient = ingredients.email_delivery @@ -114,11 +109,7 @@ def __init__( self.email_delivery = email_delivery_ingredient api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) def callback(): mfa_instance = MultiFactorAuthRecipe.get_instance() @@ -370,20 +361,30 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( - sign_up_feature: Union[InputSignUpFeature, None] = None, - override: Union[InputOverrideConfig, None] = None, - email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, + sign_up_feature: Optional[InputSignUpFeature] = None, + override: Optional[EmailPasswordOverrideConfig] = None, + email_delivery: Optional[EmailDeliveryConfig[EmailTemplateVars]] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = EmailPasswordConfig( + sign_up_feature=sign_up_feature, + email_delivery=email_delivery, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if EmailPasswordRecipe.__instance is None: ingredients = EmailPasswordIngredients(None) EmailPasswordRecipe.__instance = EmailPasswordRecipe( - EmailPasswordRecipe.recipe_id, - app_info, - ingredients, - sign_up_feature, - override, - email_delivery=email_delivery, + recipe_id=EmailPasswordRecipe.recipe_id, + app_info=app_info, + ingredients=ingredients, + config=apply_plugins( + recipe_id=EmailPasswordRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return EmailPasswordRecipe.__instance raise Exception( diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index 453506ef1..d21fb9eed 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -42,7 +42,7 @@ UpdateEmailOrPasswordOkResult, WrongCredentialsError, ) -from .utils import EmailPasswordConfig +from .utils import NormalisedEmailPasswordConfig if TYPE_CHECKING: from supertokens_python.querier import Querier @@ -52,7 +52,7 @@ class RecipeImplementation(RecipeInterface): def __init__( self, querier: Querier, - ep_config: EmailPasswordConfig, + ep_config: NormalisedEmailPasswordConfig, ): super().__init__() self.querier = querier diff --git a/supertokens_python/recipe/emailpassword/utils.py b/supertokens_python/recipe/emailpassword/utils.py index fa267eae7..62d4c72e3 100644 --- a/supertokens_python/recipe/emailpassword/utils.py +++ b/supertokens_python/recipe/emailpassword/utils.py @@ -24,17 +24,21 @@ from supertokens_python.recipe.emailpassword.emaildelivery.services.backward_compatibility import ( BackwardCompatibilityService, ) +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) +from supertokens_python.utils import get_filtered_list +from .constants import FORM_FIELD_EMAIL_ID, FORM_FIELD_PASSWORD_ID from .interfaces import APIInterface, RecipeInterface from .types import EmailTemplateVars, InputFormField, NormalisedFormField if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo -from supertokens_python.utils import get_filtered_list - -from .constants import FORM_FIELD_EMAIL_ID, FORM_FIELD_PASSWORD_ID - async def default_validator(_: str, __: str) -> Union[str, None]: return None @@ -213,82 +217,76 @@ def validate_and_normalise_reset_password_using_token_config( ) -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +EmailPasswordOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedEmailPasswordOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = EmailPasswordOverrideConfig +"""Deprecated, use `EmailPasswordOverrideConfig` instead.""" -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class EmailPasswordConfig(BaseConfig[RecipeInterface, APIInterface]): + sign_up_feature: Union[InputSignUpFeature, None] = None + email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None -class EmailPasswordConfig: - def __init__( - self, - sign_up_feature: SignUpFeature, - sign_in_feature: SignInFeature, - reset_password_using_token_feature: ResetPasswordUsingTokenFeature, - override: OverrideConfig, - get_email_delivery_config: Callable[ - [RecipeInterface], EmailDeliveryConfigWithService[EmailTemplateVars] - ], - ): - self.sign_up_feature = sign_up_feature - self.sign_in_feature = sign_in_feature - self.reset_password_using_token_feature = reset_password_using_token_feature - self.override = override - self.get_email_delivery_config = get_email_delivery_config +class NormalisedEmailPasswordConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): + sign_up_feature: SignUpFeature + sign_in_feature: SignInFeature + reset_password_using_token_feature: ResetPasswordUsingTokenFeature + get_email_delivery_config: Callable[ + [RecipeInterface], EmailDeliveryConfigWithService[EmailTemplateVars] + ] def validate_and_normalise_user_input( app_info: AppInfo, - sign_up_feature: Union[InputSignUpFeature, None] = None, - override: Union[InputOverrideConfig, None] = None, - email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, -) -> EmailPasswordConfig: + config: EmailPasswordConfig, +) -> NormalisedEmailPasswordConfig: # NOTE: We don't need to check the instance of sign_up_feature and override # as they will always be either None or the specified type. - if override is None: - override = InputOverrideConfig() + override_config = NormalisedEmailPasswordOverrideConfig.from_input_config( + override_config=config.override + ) + sign_up_feature = config.sign_up_feature if sign_up_feature is None: sign_up_feature = InputSignUpFeature() def get_email_delivery_config( ep_recipe: RecipeInterface, ) -> EmailDeliveryConfigWithService[EmailTemplateVars]: - if email_delivery and email_delivery.service: + if config.email_delivery and config.email_delivery.service: return EmailDeliveryConfigWithService( - service=email_delivery.service, override=email_delivery.override + service=config.email_delivery.service, + override=config.email_delivery.override, ) email_service = BackwardCompatibilityService( app_info=app_info, recipe_interface_impl=ep_recipe, ) - if email_delivery is not None and email_delivery.override is not None: - override = email_delivery.override + if ( + config.email_delivery is not None + and config.email_delivery.override is not None + ): + override = config.email_delivery.override else: override = None return EmailDeliveryConfigWithService(email_service, override=override) - return EmailPasswordConfig( - SignUpFeature(sign_up_feature.form_fields), - SignInFeature(normalise_sign_in_form_fields(sign_up_feature.form_fields)), - validate_and_normalise_reset_password_using_token_config(sign_up_feature), - OverrideConfig(functions=override.functions, apis=override.apis), + return NormalisedEmailPasswordConfig( + sign_up_feature=SignUpFeature(sign_up_feature.form_fields), + sign_in_feature=SignInFeature( + normalise_sign_in_form_fields(sign_up_feature.form_fields) + ), + reset_password_using_token_feature=validate_and_normalise_reset_password_using_token_config( + sign_up_feature + ), + override=override_config, get_email_delivery_config=get_email_delivery_config, ) diff --git a/supertokens_python/recipe/emailverification/__init__.py b/supertokens_python/recipe/emailverification/__init__.py index 4836bb3e0..6d53c2940 100644 --- a/supertokens_python/recipe/emailverification/__init__.py +++ b/supertokens_python/recipe/emailverification/__init__.py @@ -13,39 +13,42 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional, Union -from ...ingredients.emaildelivery.types import EmailDeliveryConfig -from . import exceptions as ex -from . import recipe, types, utils -from .emaildelivery import services as emaildelivery_services -from .interfaces import TypeGetEmailForUserIdFunction -from .recipe import EmailVerificationRecipe -from .types import EmailTemplateVars -from .utils import MODE_TYPE, OverrideConfig - -InputOverrideConfig = utils.OverrideConfig -exception = ex -SMTPService = emaildelivery_services.SMTPService -EmailVerificationClaim = recipe.EmailVerificationClaim -EmailDeliveryInterface = types.EmailDeliveryInterface +from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig +from .emaildelivery.services import SMTPService +from .interfaces import TypeGetEmailForUserIdFunction +from .recipe import EmailVerificationClaim, EmailVerificationRecipe +from .types import EmailDeliveryInterface, EmailTemplateVars +from .utils import MODE_TYPE, EmailVerificationOverrideConfig, InputOverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, - override: Union[OverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: + override: Union[EmailVerificationOverrideConfig, None] = None, +) -> RecipeInit: return EmailVerificationRecipe.init( mode, email_delivery, get_email_for_recipe_user_id, override, ) + + +__all__ = [ + "EmailDeliveryInterface", + "EmailTemplateVars", + "EmailVerificationClaim", + "EmailVerificationOverrideConfig", + "EmailVerificationRecipe", + "InputOverrideConfig", # deprecated, use EmailVerificationOverrideConfig instead + "SMTPService", + "TypeGetEmailForUserIdFunction", + "init", +] diff --git a/supertokens_python/recipe/emailverification/interfaces.py b/supertokens_python/recipe/emailverification/interfaces.py index bbf829047..88f1f4522 100644 --- a/supertokens_python/recipe/emailverification/interfaces.py +++ b/supertokens_python/recipe/emailverification/interfaces.py @@ -13,13 +13,14 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union from typing_extensions import Literal from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.types import RecipeUserId +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse from ...supertokens import AppInfo @@ -29,7 +30,7 @@ from supertokens_python.framework import BaseRequest, BaseResponse from .types import EmailVerificationUser, VerificationEmailTemplateVars - from .utils import EmailVerificationConfig + from .utils import NormalisedEmailVerificationConfig class CreateEmailVerificationTokenOkResult: @@ -84,7 +85,7 @@ class UnverifyEmailOkResult: pass -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -140,7 +141,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: EmailVerificationConfig, + config: NormalisedEmailVerificationConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, email_delivery: EmailDeliveryIngredient[VerificationEmailTemplateVars], @@ -201,7 +202,7 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status} -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_email_verify_post = False self.disable_is_email_verified_get = False diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index a4abfb47c..2db012b64 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -18,6 +18,8 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.querier import Querier from supertokens_python.recipe.emailverification.exceptions import ( EmailVerificationInvalidTokenError, ) @@ -27,6 +29,7 @@ VerificationEmailTemplateVars, VerificationEmailTemplateVarsUser, ) +from supertokens_python.recipe.emailverification.utils import get_email_verify_link from supertokens_python.recipe_module import APIHandled, RecipeModule from ...ingredients.emaildelivery.types import EmailDeliveryConfig @@ -46,6 +49,9 @@ SessionClaimValidator, SessionContainer, ) +from .api import handle_email_verify_api, handle_generate_email_verify_token_api +from .constants import USER_EMAIL_VERIFY, USER_EMAIL_VERIFY_TOKEN +from .exceptions import SuperTokensEmailVerificationError from .interfaces import ( APIInterface, APIOptions, @@ -62,6 +68,12 @@ VerifyEmailUsingTokenOkResult, ) from .recipe_implementation import RecipeImplementation +from .utils import ( + MODE_TYPE, + EmailVerificationConfig, + EmailVerificationOverrideConfig, + validate_and_normalise_user_input, +) if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -71,15 +83,6 @@ from ...types import MaybeAwaitable, User -from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.querier import Querier -from supertokens_python.recipe.emailverification.utils import get_email_verify_link - -from .api import handle_email_verify_api, handle_generate_email_verify_token_api -from .constants import USER_EMAIL_VERIFY, USER_EMAIL_VERIFY_TOKEN -from .exceptions import SuperTokensEmailVerificationError -from .utils import MODE_TYPE, OverrideConfig, validate_and_normalise_user_input - class EmailVerificationRecipe(RecipeModule): recipe_id = "emailverification" @@ -91,36 +94,24 @@ def __init__( recipe_id: str, app_info: AppInfo, ingredients: EmailVerificationIngredients, - mode: MODE_TYPE, - email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, - override: Union[OverrideConfig, None] = None, + config: EmailVerificationConfig, ) -> None: super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - app_info, - mode, - email_delivery, - get_email_for_recipe_user_id, - override, + app_info=app_info, + config=config, ) recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.get_email_for_recipe_user_id, ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) email_delivery_ingredient = ingredients.email_delivery if email_delivery_ingredient is None: @@ -207,19 +198,31 @@ def init( mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, - override: Union[OverrideConfig, None] = None, + override: Optional[EmailVerificationOverrideConfig] = None, ): - def func(app_info: AppInfo) -> EmailVerificationRecipe: + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = EmailVerificationConfig( + mode=mode, + email_delivery=email_delivery, + get_email_for_recipe_user_id=get_email_for_recipe_user_id, + override=override, + ) + + def func( + app_info: AppInfo, plugins: List[OverrideMap] + ) -> EmailVerificationRecipe: if EmailVerificationRecipe.__instance is None: ingredients = EmailVerificationIngredients(email_delivery=None) EmailVerificationRecipe.__instance = EmailVerificationRecipe( EmailVerificationRecipe.recipe_id, app_info, ingredients, - mode, - email_delivery, - get_email_for_recipe_user_id, - override, + config=apply_plugins( + recipe_id=EmailVerificationRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) def callback(): diff --git a/supertokens_python/recipe/emailverification/utils.py b/supertokens_python/recipe/emailverification/utils.py index 8811f073a..5ff3754db 100644 --- a/supertokens_python/recipe/emailverification/utils.py +++ b/supertokens_python/recipe/emailverification/utils.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union from typing_extensions import Literal @@ -26,53 +26,51 @@ from supertokens_python.recipe.emailverification.emaildelivery.services.backward_compatibility import ( BackwardCompatibilityService, ) +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) -if TYPE_CHECKING: - from typing import Callable, Union +from .interfaces import APIInterface, RecipeInterface, TypeGetEmailForUserIdFunction +from .types import EmailTemplateVars, VerificationEmailTemplateVars +if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo - from .interfaces import APIInterface, RecipeInterface, TypeGetEmailForUserIdFunction - from .types import EmailTemplateVars, VerificationEmailTemplateVars +MODE_TYPE = Literal["REQUIRED", "OPTIONAL"] -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +EmailVerificationOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedEmailVerificationOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = EmailVerificationOverrideConfig +"""Deprecated, use `EmailVerificationOverrideConfig` instead.""" -MODE_TYPE = Literal["REQUIRED", "OPTIONAL"] +class EmailVerificationConfig(BaseConfig[RecipeInterface, APIInterface]): + mode: MODE_TYPE + email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None -class EmailVerificationConfig: - def __init__( - self, - mode: MODE_TYPE, - get_email_delivery_config: Callable[ - [], EmailDeliveryConfigWithService[VerificationEmailTemplateVars] - ], - get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction], - override: OverrideConfig, - ): - self.mode = mode - self.override = override - self.get_email_delivery_config = get_email_delivery_config - self.get_email_for_recipe_user_id = get_email_for_recipe_user_id +class NormalisedEmailVerificationConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): + mode: MODE_TYPE + get_email_delivery_config: Callable[ + [], EmailDeliveryConfigWithService[VerificationEmailTemplateVars] + ] + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] def validate_and_normalise_user_input( app_info: AppInfo, - mode: MODE_TYPE, - email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, - override: Union[OverrideConfig, None] = None, -) -> EmailVerificationConfig: - if mode not in ["REQUIRED", "OPTIONAL"]: + config: EmailVerificationConfig, +) -> NormalisedEmailVerificationConfig: + if config.mode not in ["REQUIRED", "OPTIONAL"]: raise ValueError( "Email Verification recipe mode must be one of 'REQUIRED' or 'OPTIONAL'" ) @@ -80,27 +78,30 @@ def validate_and_normalise_user_input( def get_email_delivery_config() -> EmailDeliveryConfigWithService[ VerificationEmailTemplateVars ]: - email_service = email_delivery.service if email_delivery is not None else None + email_service = ( + config.email_delivery.service if config.email_delivery is not None else None + ) if email_service is None: email_service = BackwardCompatibilityService(app_info) - if email_delivery is not None and email_delivery.override is not None: - override = email_delivery.override + if ( + config.email_delivery is not None + and config.email_delivery.override is not None + ): + override = config.email_delivery.override else: override = None return EmailDeliveryConfigWithService(email_service, override=override) - if override is not None and not isinstance(override, OverrideConfig): # type: ignore - raise ValueError("override must be of type OverrideConfig or None") - - if override is None: - override = OverrideConfig() + override_config = NormalisedEmailVerificationOverrideConfig.from_input_config( + override_config=config.override + ) - return EmailVerificationConfig( - mode, - get_email_delivery_config, - get_email_for_recipe_user_id, - override, + return NormalisedEmailVerificationConfig( + mode=config.mode, + get_email_delivery_config=get_email_delivery_config, + get_email_for_recipe_user_id=config.get_email_for_recipe_user_id, + override=override_config, ) diff --git a/supertokens_python/recipe/jwt/__init__.py b/supertokens_python/recipe/jwt/__init__.py index e53bb46c1..0b12165e0 100644 --- a/supertokens_python/recipe/jwt/__init__.py +++ b/supertokens_python/recipe/jwt/__init__.py @@ -13,19 +13,25 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union from .recipe import JWTRecipe -from .utils import OverrideConfig +from .utils import JWTOverrideConfig, OverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( jwt_validity_seconds: Union[int, None] = None, - override: Union[OverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: + override: Union[JWTOverrideConfig, None] = None, +) -> RecipeInit: return JWTRecipe.init(jwt_validity_seconds, override) + + +__all__ = [ + "JWTOverrideConfig", + "JWTRecipe", + "OverrideConfig", # deprecated, use JWTOverrideConfig instead + "init", +] diff --git a/supertokens_python/recipe/jwt/interfaces.py b/supertokens_python/recipe/jwt/interfaces.py index 2c398f8da..fbedd8ce6 100644 --- a/supertokens_python/recipe/jwt/interfaces.py +++ b/supertokens_python/recipe/jwt/interfaces.py @@ -11,13 +11,15 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.framework import BaseRequest, BaseResponse +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse -from .utils import JWTConfig +if TYPE_CHECKING: + from .utils import NormalisedJWTConfig class JsonWebKey: @@ -45,7 +47,7 @@ def __init__(self, keys: List[JsonWebKey], validity_in_secs: Optional[int]): self.validity_in_secs = validity_in_secs -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -70,7 +72,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: JWTConfig, + config: "NormalisedJWTConfig", recipe_implementation: RecipeInterface, ): self.request = request @@ -101,7 +103,7 @@ def to_json(self) -> Dict[str, Any]: return {"keys": keys} -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_jwks_get = False diff --git a/supertokens_python/recipe/jwt/recipe.py b/supertokens_python/recipe/jwt/recipe.py index 4f28c5d88..5cfe20ccd 100644 --- a/supertokens_python/recipe/jwt/recipe.py +++ b/supertokens_python/recipe/jwt/recipe.py @@ -24,7 +24,8 @@ from supertokens_python.recipe.jwt.interfaces import APIOptions from supertokens_python.recipe.jwt.recipe_implementation import RecipeImplementation from supertokens_python.recipe.jwt.utils import ( - OverrideConfig, + JWTConfig, + JWTOverrideConfig, validate_and_normalise_user_input, ) @@ -46,26 +47,19 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - jwt_validity_seconds: Union[int, None] = None, - override: Union[OverrideConfig, None] = None, + config: JWTConfig, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input(jwt_validity_seconds, override) + self.config = validate_and_normalise_user_input(config=config) recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config, app_info ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) def get_apis_handled(self) -> List[APIHandled]: return [ @@ -120,12 +114,25 @@ def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: @staticmethod def init( jwt_validity_seconds: Union[int, None] = None, - override: Union[OverrideConfig, None] = None, + override: Union[JWTOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = JWTConfig( + jwt_validity_seconds=jwt_validity_seconds, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if JWTRecipe.__instance is None: JWTRecipe.__instance = JWTRecipe( - JWTRecipe.recipe_id, app_info, jwt_validity_seconds, override + JWTRecipe.recipe_id, + app_info, + config=apply_plugins( + recipe_id=JWTRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return JWTRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/jwt/recipe_implementation.py b/supertokens_python/recipe/jwt/recipe_implementation.py index 0e7a3f464..9c308a4d1 100644 --- a/supertokens_python/recipe/jwt/recipe_implementation.py +++ b/supertokens_python/recipe/jwt/recipe_implementation.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo - from .utils import JWTConfig + from .utils import NormalisedJWTConfig from supertokens_python.recipe.jwt.interfaces import ( CreateJwtOkResult, @@ -38,7 +38,9 @@ class RecipeImplementation(RecipeInterface): - def __init__(self, querier: Querier, config: JWTConfig, app_info: AppInfo): + def __init__( + self, querier: Querier, config: NormalisedJWTConfig, app_info: AppInfo + ): super().__init__() self.querier = querier self.config = config diff --git a/supertokens_python/recipe/jwt/utils.py b/supertokens_python/recipe/jwt/utils.py index 35dbef8e1..5fb2e007a 100644 --- a/supertokens_python/recipe/jwt/utils.py +++ b/supertokens_python/recipe/jwt/utils.py @@ -13,41 +13,46 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import Optional -if TYPE_CHECKING: - from .interfaces import APIInterface, RecipeInterface +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) +from .interfaces import APIInterface, RecipeInterface -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +JWTOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedJWTOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +OverrideConfig = JWTOverrideConfig +"""Deprecated, use `JWTOverrideConfig` instead.""" -class JWTConfig: - def __init__(self, override: OverrideConfig, jwt_validity_seconds: int): - self.override = override - self.jwt_validity_seconds = jwt_validity_seconds +class JWTConfig(BaseConfig[RecipeInterface, APIInterface]): + jwt_validity_seconds: Optional[int] = None -def validate_and_normalise_user_input( - jwt_validity_seconds: Union[int, None] = None, - override: Union[OverrideConfig, None] = None, -): - if jwt_validity_seconds is not None and not isinstance(jwt_validity_seconds, int): # type: ignore - raise ValueError("jwt_validity_seconds must be an integer or None") +class NormalisedJWTConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + jwt_validity_seconds: int + + +def validate_and_normalise_user_input(config: JWTConfig): + override_config = NormalisedJWTOverrideConfig.from_input_config( + override_config=config.override + ) - if override is not None and not isinstance(override, OverrideConfig): # type: ignore - raise ValueError("override must be an instance of OverrideConfig or None") + jwt_validity_seconds = config.jwt_validity_seconds - if override is None: - override = OverrideConfig() - if jwt_validity_seconds is None: + if config.jwt_validity_seconds is None: jwt_validity_seconds = 3153600000 - return JWTConfig(override, jwt_validity_seconds) + if not isinstance(jwt_validity_seconds, int): # type: ignore + raise ValueError("jwt_validity_seconds must be an integer or None") + + return NormalisedJWTConfig( + jwt_validity_seconds=jwt_validity_seconds, override=override_config + ) diff --git a/supertokens_python/recipe/multifactorauth/__init__.py b/supertokens_python/recipe/multifactorauth/__init__.py index fa3aaf6f8..614d6bf5e 100644 --- a/supertokens_python/recipe/multifactorauth/__init__.py +++ b/supertokens_python/recipe/multifactorauth/__init__.py @@ -13,23 +13,32 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union -from supertokens_python.recipe.multifactorauth.types import OverrideConfig +from supertokens_python.recipe.multifactorauth.types import ( + MultiFactorAuthOverrideConfig, + OverrideConfig, +) from .recipe import MultiFactorAuthRecipe if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( first_factors: Optional[List[str]] = None, - override: Union[OverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: + override: Union[MultiFactorAuthOverrideConfig, None] = None, +) -> RecipeInit: return MultiFactorAuthRecipe.init( first_factors, override, ) + + +__all__ = [ + "MultiFactorAuthOverrideConfig", + "MultiFactorAuthRecipe", + "OverrideConfig", # deprecated, use MultiFactorAuthOverrideConfig instead + "init", +] diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py index b5b0ec928..42dc1a4f8 100644 --- a/supertokens_python/recipe/multifactorauth/interfaces.py +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -14,9 +14,11 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Union +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface + from ...types.response import APIResponse, GeneralErrorResponse if TYPE_CHECKING: @@ -26,10 +28,10 @@ from supertokens_python.types import User from ...supertokens import AppInfo - from .types import MFARequirementList, MultiFactorAuthConfig + from .types import MFARequirementList, NormalisedMultiFactorAuthConfig -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( self, @@ -95,7 +97,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: MultiFactorAuthConfig, + config: NormalisedMultiFactorAuthConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, recipe_instance: MultiFactorAuthRecipe, @@ -109,7 +111,7 @@ def __init__( self.recipe_instance = recipe_instance -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_resync_session_and_fetch_mfa_info_put = False diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py index 995137b1a..be28497ef 100644 --- a/supertokens_python/recipe/multifactorauth/recipe.py +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -13,7 +13,6 @@ # under the License. from __future__ import annotations -import importlib from os import environ from typing import Any, Dict, List, Optional, Union @@ -37,7 +36,9 @@ from supertokens_python.supertokens import AppInfo from supertokens_python.types import RecipeUserId, User +from .api.implementation import APIImplementation from .interfaces import APIOptions +from .recipe_implementation import RecipeImplementation from .types import ( GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, GetEmailsForFactorFromOtherRecipesFunc, @@ -47,8 +48,10 @@ GetPhoneNumbersForFactorsFromOtherRecipesFunc, GetPhoneNumbersForFactorsOkResult, GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, - OverrideConfig, + MultiFactorAuthConfig, + MultiFactorAuthOverrideConfig, ) +from .utils import validate_and_normalise_user_input class MultiFactorAuthRecipe(RecipeModule): @@ -59,8 +62,7 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - first_factors: Optional[List[str]] = None, - override: Union[OverrideConfig, None] = None, + config: MultiFactorAuthConfig, ): super().__init__(recipe_id, app_info) self.get_factors_setup_for_user_from_other_recipes_funcs: List[ @@ -77,32 +79,19 @@ def __init__( ] = [] self.is_get_mfa_requirements_for_auth_overridden: bool = False - module = importlib.import_module( - "supertokens_python.recipe.multifactorauth.utils" - ) - - self.config = module.validate_and_normalise_user_input( - first_factors, - override, + self.config = validate_and_normalise_user_input( + config=config, ) - from .recipe_implementation import RecipeImplementation recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) - from .api.implementation import APIImplementation api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) def callback(): from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe @@ -169,15 +158,25 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( first_factors: Optional[List[str]] = None, - override: Union[OverrideConfig, None] = None, + override: Union[MultiFactorAuthOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = MultiFactorAuthConfig( + first_factors=first_factors, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if MultiFactorAuthRecipe.__instance is None: MultiFactorAuthRecipe.__instance = MultiFactorAuthRecipe( MultiFactorAuthRecipe.recipe_id, app_info, - first_factors, - override, + config=apply_plugins( + recipe_id=MultiFactorAuthRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return MultiFactorAuthRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index 1f5587189..e53e6362f 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -13,16 +13,20 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from typing_extensions import Literal from supertokens_python.recipe.multitenancy.interfaces import TenantConfig from supertokens_python.types import RecipeUserId, User +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) -if TYPE_CHECKING: - from .interfaces import APIInterface, RecipeInterface - +from .interfaces import APIInterface, RecipeInterface MFARequirementList = List[ Union[str, Dict[Union[Literal["oneOf"], Literal["allOfInAnyOrder"]], List[str]]] @@ -38,24 +42,22 @@ def __init__(self, c: Dict[str, Any], v: bool): self.v = v -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +MultiFactorAuthOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedMultiFactorAuthOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +OverrideConfig = MultiFactorAuthOverrideConfig +"""Deprecated, use `MultiFactorAuthOverrideConfig` instead.""" -class MultiFactorAuthConfig: - def __init__( - self, - first_factors: Optional[List[str]], - override: OverrideConfig, - ): - self.first_factors = first_factors - self.override = override +class MultiFactorAuthConfig(BaseConfig[RecipeInterface, APIInterface]): + first_factors: Optional[List[str]] = None + + +class NormalisedMultiFactorAuthConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): + first_factors: Optional[List[str]] class FactorIds: diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index 38bd635c5..a5990360d 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -15,7 +15,7 @@ import math import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from typing_extensions import Literal @@ -36,28 +36,28 @@ from supertokens_python.types import RecipeUserId from supertokens_python.utils import log_debug_message -if TYPE_CHECKING: - from .types import MultiFactorAuthConfig, OverrideConfig +from .types import ( + MultiFactorAuthConfig, + NormalisedMultiFactorAuthConfig, + NormalisedMultiFactorAuthOverrideConfig, +) # IMPORTANT: If this function signature is modified, please update all tha places where this function is called. # There will be no type errors cause we use importLib to dynamically import if to prevent cyclic import issues. def validate_and_normalise_user_input( - first_factors: Optional[List[str]], - override: Union[OverrideConfig, None] = None, -) -> MultiFactorAuthConfig: - if first_factors is not None and len(first_factors) == 0: + config: MultiFactorAuthConfig, +) -> NormalisedMultiFactorAuthConfig: + if config.first_factors is not None and len(config.first_factors) == 0: raise ValueError("'first_factors' can be either None or a non-empty list") - from .types import MultiFactorAuthConfig as MFAC - from .types import OverrideConfig as OC - - if override is None: - override = OC() + override_config = NormalisedMultiFactorAuthOverrideConfig.from_input_config( + override_config=config.override + ) - return MFAC( - first_factors=first_factors, - override=override, + return NormalisedMultiFactorAuthConfig( + first_factors=config.first_factors, + override=override_config, ) diff --git a/supertokens_python/recipe/multitenancy/__init__.py b/supertokens_python/recipe/multitenancy/__init__.py index 5a7180a45..a91f5fa83 100644 --- a/supertokens_python/recipe/multitenancy/__init__.py +++ b/supertokens_python/recipe/multitenancy/__init__.py @@ -13,29 +13,33 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union -from . import exceptions as ex -from . import recipe - -AllowedDomainsClaim = recipe.AllowedDomainsClaim -exceptions = ex +from .interfaces import TypeGetAllowedDomainsForTenantId +from .recipe import AllowedDomainsClaim, MultitenancyRecipe +from .utils import InputOverrideConfig, MultitenancyOverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule - from .interfaces import TypeGetAllowedDomainsForTenantId - from .utils import InputOverrideConfig + from supertokens_python.supertokens import RecipeInit def init( get_allowed_domains_for_tenant_id: Union[ TypeGetAllowedDomainsForTenantId, None ] = None, - override: Union[InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: - return recipe.MultitenancyRecipe.init( + override: Union[MultitenancyOverrideConfig, None] = None, +) -> RecipeInit: + return MultitenancyRecipe.init( get_allowed_domains_for_tenant_id, override, ) + + +__all__ = [ + "AllowedDomainsClaim", + "InputOverrideConfig", # deprecated, use MultitenancyOverrideConfig instead + "MultitenancyOverrideConfig", + "MultitenancyRecipe", + "TypeGetAllowedDomainsForTenantId", + "init", +] diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index 03bd8d012..0af997881 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -13,10 +13,11 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union from supertokens_python.types import RecipeUserId +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse if TYPE_CHECKING: @@ -26,7 +27,7 @@ ProviderInput, ) - from .utils import MultitenancyConfig + from .utils import NormalisedMultitenancyConfig class TenantConfig: @@ -192,7 +193,7 @@ def __init__(self, was_associated: bool): self.was_associated = was_associated -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -282,7 +283,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: MultitenancyConfig, + config: NormalisedMultitenancyConfig, recipe_implementation: RecipeInterface, static_third_party_providers: List[ProviderInput], all_available_first_factors: List[str], @@ -366,7 +367,7 @@ def to_json(self) -> Dict[str, Any]: } -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_login_methods_get = False diff --git a/supertokens_python/recipe/multitenancy/recipe.py b/supertokens_python/recipe/multitenancy/recipe.py index 38873a4e1..e73b32ad3 100644 --- a/supertokens_python/recipe/multitenancy/recipe.py +++ b/supertokens_python/recipe/multitenancy/recipe.py @@ -44,7 +44,8 @@ from .constants import LOGIN_METHODS from .exceptions import MultitenancyError from .utils import ( - InputOverrideConfig, + MultitenancyConfig, + MultitenancyOverrideConfig, validate_and_normalise_user_input, ) @@ -54,35 +55,20 @@ class MultitenancyRecipe(RecipeModule): __instance = None def __init__( - self, - recipe_id: str, - app_info: AppInfo, - get_allowed_domains_for_tenant_id: Optional[ - TypeGetAllowedDomainsForTenantId - ] = None, - override: Union[InputOverrideConfig, None] = None, + self, recipe_id: str, app_info: AppInfo, config: MultitenancyConfig ) -> None: super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input( - get_allowed_domains_for_tenant_id, - override, - ) + self.config = validate_and_normalise_user_input(config=config) recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) self.static_third_party_providers: List[ProviderInput] = [] self.get_allowed_domains_for_tenant_id = ( @@ -150,15 +136,25 @@ def init( get_allowed_domains_for_tenant_id: Union[ TypeGetAllowedDomainsForTenantId, None ] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[MultitenancyOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = MultitenancyConfig( + get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if MultitenancyRecipe.__instance is None: MultitenancyRecipe.__instance = MultitenancyRecipe( - MultitenancyRecipe.recipe_id, - app_info, - get_allowed_domains_for_tenant_id, - override, + recipe_id=MultitenancyRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=MultitenancyRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) def callback(): diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 13e7701e8..091f980f6 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -41,7 +41,7 @@ from supertokens_python.querier import Querier from supertokens_python.recipe.thirdparty.provider import ProviderConfig - from .utils import MultitenancyConfig + from .utils import NormalisedMultitenancyConfig from supertokens_python.querier import NormalisedURLPath @@ -119,7 +119,7 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfig: class RecipeImplementation(RecipeInterface): - def __init__(self, querier: Querier, config: MultitenancyConfig): + def __init__(self, querier: Querier, config: NormalisedMultitenancyConfig): super().__init__() self.querier = querier self.config = config diff --git a/supertokens_python/recipe/multitenancy/utils.py b/supertokens_python/recipe/multitenancy/utils.py index 43bab761b..6f83d5518 100644 --- a/supertokens_python/recipe/multitenancy/utils.py +++ b/supertokens_python/recipe/multitenancy/utils.py @@ -14,22 +14,25 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Awaitable, Callable, Optional +from typing import Awaitable, Callable, Optional, Union from supertokens_python.exceptions import SuperTokensError from supertokens_python.framework import BaseRequest, BaseResponse +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) from supertokens_python.utils import ( resolve, ) -if TYPE_CHECKING: - from typing import Union - - from .interfaces import ( - APIInterface, - RecipeInterface, - TypeGetAllowedDomainsForTenantId, - ) +from .interfaces import ( + APIInterface, + RecipeInterface, + TypeGetAllowedDomainsForTenantId, +) class ErrorHandlers: @@ -63,47 +66,30 @@ async def on_recipe_disabled_for_tenant( ) -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +MultitenancyOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedMultitenancyOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = MultitenancyOverrideConfig +"""Deprecated, use `MultitenancyOverrideConfig` instead.""" -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class MultitenancyConfig(BaseConfig[RecipeInterface, APIInterface]): + get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId] = None -class MultitenancyConfig: - def __init__( - self, - get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId], - override: OverrideConfig, - ): - self.get_allowed_domains_for_tenant_id = get_allowed_domains_for_tenant_id - self.override = override +class NormalisedMultitenancyConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId] def validate_and_normalise_user_input( - get_allowed_domains_for_tenant_id: Optional[TypeGetAllowedDomainsForTenantId], - override: Union[InputOverrideConfig, None] = None, -) -> MultitenancyConfig: - if override is not None and not isinstance(override, OverrideConfig): # type: ignore - raise ValueError("override must be of type OverrideConfig or None") - - if override is None: - override = InputOverrideConfig() - - return MultitenancyConfig( - get_allowed_domains_for_tenant_id, - OverrideConfig(override.functions, override.apis), + config: MultitenancyConfig, +) -> NormalisedMultitenancyConfig: + override_config = NormalisedMultitenancyOverrideConfig.from_input_config( + override_config=config.override + ) + + return NormalisedMultitenancyConfig( + get_allowed_domains_for_tenant_id=config.get_allowed_domains_for_tenant_id, + override=override_config, ) diff --git a/supertokens_python/recipe/oauth2provider/__init__.py b/supertokens_python/recipe/oauth2provider/__init__.py index 3397b3430..c436084b5 100644 --- a/supertokens_python/recipe/oauth2provider/__init__.py +++ b/supertokens_python/recipe/oauth2provider/__init__.py @@ -13,21 +13,24 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union -from . import exceptions as ex -from . import recipe, utils - -exceptions = ex -InputOverrideConfig = utils.InputOverrideConfig +from .recipe import OAuth2ProviderRecipe +from .utils import InputOverrideConfig, OAuth2ProviderOverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( - override: Union[InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: - return recipe.OAuth2ProviderRecipe.init(override) + override: Union[OAuth2ProviderOverrideConfig, None] = None, +) -> RecipeInit: + return OAuth2ProviderRecipe.init(override) + + +__all__ = [ + "InputOverrideConfig", # deprecated, use OAuth2ProviderOverrideConfig instead + "OAuth2ProviderOverrideConfig", + "OAuth2ProviderRecipe", + "init", +] diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 3329013a4..d2f970b18 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union from typing_extensions import Literal @@ -23,6 +23,7 @@ RecipeUserId, User, ) +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse from .oauth2_client import OAuth2Client @@ -31,7 +32,7 @@ from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.supertokens import AppInfo - from .utils import OAuth2ProviderConfig + from .utils import NormalisedOAuth2ProviderConfig class ErrorOAuth2Response(APIResponse): @@ -1016,7 +1017,7 @@ def from_json(json: Dict[str, Any]) -> "UpdateOAuth2ClientInput": ) -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def authorization( self, @@ -1273,18 +1274,18 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: OAuth2ProviderConfig, + config: NormalisedOAuth2ProviderConfig, recipe_implementation: RecipeInterface, ): self.app_info: AppInfo = app_info self.request: BaseRequest = request self.response: BaseResponse = response self.recipe_id: str = recipe_id - self.config: OAuth2ProviderConfig = config + self.config: NormalisedOAuth2ProviderConfig = config self.recipe_implementation: RecipeInterface = recipe_implementation -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_login_get = False self.disable_auth_get = False diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py index 0532091ce..a7dcb1ae4 100644 --- a/supertokens_python/recipe/oauth2provider/recipe.py +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -66,8 +66,9 @@ USER_INFO_PATH, ) from .utils import ( - InputOverrideConfig, + NormalisedOAuth2ProviderConfig, OAuth2ProviderConfig, + OAuth2ProviderOverrideConfig, validate_and_normalise_user_input, ) @@ -80,11 +81,11 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - override: Union[InputOverrideConfig, None] = None, + config: OAuth2ProviderConfig, ) -> None: super().__init__(recipe_id, app_info) - self.config: OAuth2ProviderConfig = validate_and_normalise_user_input( - override, + self.config: NormalisedOAuth2ProviderConfig = validate_and_normalise_user_input( + config=config, ) from .recipe_implementation import RecipeImplementation @@ -96,19 +97,13 @@ def __init__( self.get_default_id_token_payload, self.get_default_user_info_payload, ) - self.recipe_implementation: RecipeInterface = ( - self.config.override.functions(recipe_implementation) - if self.config.override is not None - and self.config.override.functions is not None - else recipe_implementation + self.recipe_implementation: RecipeInterface = self.config.override.functions( + recipe_implementation ) api_implementation = APIImplementation() - self.api_implementation: APIInterface = ( - self.config.override.apis(api_implementation) - if self.config.override is not None - and self.config.override.apis is not None - else api_implementation + self.api_implementation: APIInterface = self.config.override.apis( + api_implementation ) self._access_token_builders: List[PayloadBuilderFunction] = [] @@ -268,14 +263,22 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( - override: Union[InputOverrideConfig, None] = None, + override: Optional[OAuth2ProviderOverrideConfig] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = OAuth2ProviderConfig(override=override) + + def func(app_info: AppInfo, plugins: List[OverrideMap]) -> OAuth2ProviderRecipe: if OAuth2ProviderRecipe.__instance is None: OAuth2ProviderRecipe.__instance = OAuth2ProviderRecipe( - OAuth2ProviderRecipe.recipe_id, - app_info, - override, + recipe_id=OAuth2ProviderRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=OAuth2ProviderRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return OAuth2ProviderRecipe.__instance diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py index 7e6ff7e84..cbea697ca 100644 --- a/supertokens_python/recipe/oauth2provider/utils.py +++ b/supertokens_python/recipe/oauth2provider/utils.py @@ -13,42 +13,34 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) -if TYPE_CHECKING: - from typing import Union +from .interfaces import APIInterface, RecipeInterface - from .interfaces import APIInterface, RecipeInterface +OAuth2ProviderOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedOAuth2ProviderOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = OAuth2ProviderOverrideConfig +"""Deprecated, use `OAuth2ProviderOverrideConfig` instead.""" -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class OAuth2ProviderConfig(BaseConfig[RecipeInterface, APIInterface]): ... -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class NormalisedOAuth2ProviderConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): ... -class OAuth2ProviderConfig: - def __init__(self, override: Union[OverrideConfig, None] = None): - self.override = override +def validate_and_normalise_user_input(config: OAuth2ProviderConfig): + override_config = NormalisedOAuth2ProviderOverrideConfig.from_input_config( + override_config=config.override + ) - -def validate_and_normalise_user_input( - override: Union[InputOverrideConfig, None] = None, -): - if override is None: - return OAuth2ProviderConfig(OverrideConfig()) - return OAuth2ProviderConfig(OverrideConfig(override.functions, override.apis)) + return NormalisedOAuth2ProviderConfig(override=override_config) diff --git a/supertokens_python/recipe/openid/__init__.py b/supertokens_python/recipe/openid/__init__.py index 6ca57f2e2..948527a58 100644 --- a/supertokens_python/recipe/openid/__init__.py +++ b/supertokens_python/recipe/openid/__init__.py @@ -13,19 +13,25 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union from .recipe import OpenIdRecipe -from .utils import InputOverrideConfig +from .utils import InputOverrideConfig, OpenIdOverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( issuer: Union[str, None] = None, - override: Union[InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: + override: Union[OpenIdOverrideConfig, None] = None, +) -> RecipeInit: return OpenIdRecipe.init(issuer, override) + + +__all__ = [ + "InputOverrideConfig", # deprecated, use OpenIdOverrideConfig instead + "OpenIdOverrideConfig", + "OpenIdRecipe", + "init", +] diff --git a/supertokens_python/recipe/openid/interfaces.py b/supertokens_python/recipe/openid/interfaces.py index 0c2fd0ae5..b4dc84a37 100644 --- a/supertokens_python/recipe/openid/interfaces.py +++ b/supertokens_python/recipe/openid/interfaces.py @@ -11,8 +11,8 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.jwt.interfaces import ( @@ -20,9 +20,11 @@ CreateJwtResultUnsupportedAlgorithm, GetJWKSResult, ) +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse -from .utils import OpenIdConfig +if TYPE_CHECKING: + from .utils import NormalisedOpenIdConfig class GetOpenIdDiscoveryConfigurationResult: @@ -70,7 +72,7 @@ def to_json(self) -> Dict[str, Any]: } -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -101,7 +103,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: OpenIdConfig, + config: "NormalisedOpenIdConfig", recipe_implementation: RecipeInterface, ): self.request = request @@ -159,7 +161,7 @@ def to_json(self): } -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_open_id_discovery_configuration_get = False diff --git a/supertokens_python/recipe/openid/recipe.py b/supertokens_python/recipe/openid/recipe.py index 6f7abbe59..f1c4325ec 100644 --- a/supertokens_python/recipe/openid/recipe.py +++ b/supertokens_python/recipe/openid/recipe.py @@ -24,7 +24,11 @@ from .exceptions import SuperTokensOpenIdError from .interfaces import APIOptions from .recipe_implementation import RecipeImplementation -from .utils import InputOverrideConfig, validate_and_normalise_user_input +from .utils import ( + OpenIdConfig, + OpenIdOverrideConfig, + validate_and_normalise_user_input, +) if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -44,13 +48,14 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - issuer: Union[str, None] = None, - override: Union[InputOverrideConfig, None] = None, + config: OpenIdConfig, ): from supertokens_python.recipe.jwt import JWTRecipe super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input(app_info, issuer, override) + self.config = validate_and_normalise_user_input( + app_info=app_info, config=config + ) self.jwt_recipe = JWTRecipe.get_instance() recipe_implementation = RecipeImplementation( @@ -58,17 +63,11 @@ def __init__( self.config, app_info, ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) def get_apis_handled(self) -> List[APIHandled]: return [ @@ -129,15 +128,25 @@ def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: @staticmethod def init( issuer: Union[str, None] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[OpenIdOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = OpenIdConfig( + issuer=issuer, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if OpenIdRecipe.__instance is None: OpenIdRecipe.__instance = OpenIdRecipe( - OpenIdRecipe.recipe_id, - app_info, - issuer, - override, + recipe_id=OpenIdRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=OpenIdRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return OpenIdRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/openid/recipe_implementation.py b/supertokens_python/recipe/openid/recipe_implementation.py index 72bd9362b..ba97f264d 100644 --- a/supertokens_python/recipe/openid/recipe_implementation.py +++ b/supertokens_python/recipe/openid/recipe_implementation.py @@ -21,7 +21,7 @@ from supertokens_python.supertokens import AppInfo from .interfaces import CreateJwtOkResult, CreateJwtResultUnsupportedAlgorithm - from .utils import OpenIdConfig + from .utils import NormalisedOpenIdConfig from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe.jwt.constants import GET_JWKS_API @@ -81,7 +81,7 @@ async def get_open_id_discovery_configuration( def __init__( self, querier: Querier, - config: OpenIdConfig, + config: NormalisedOpenIdConfig, app_info: AppInfo, ): super().__init__() diff --git a/supertokens_python/recipe/openid/utils.py b/supertokens_python/recipe/openid/utils.py index cb30d0d32..43a346c5a 100644 --- a/supertokens_python/recipe/openid/utils.py +++ b/supertokens_python/recipe/openid/utils.py @@ -13,77 +13,62 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union - -if TYPE_CHECKING: - from supertokens_python import AppInfo - from supertokens_python.recipe.jwt import OverrideConfig as JWTOverrideConfig - - from .interfaces import APIInterface, RecipeInterface +from typing import TYPE_CHECKING, Optional from supertokens_python.normalised_url_domain import NormalisedURLDomain from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) + +from .interfaces import APIInterface, RecipeInterface + +if TYPE_CHECKING: + from supertokens_python import AppInfo -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - jwt_feature: Union[JWTOverrideConfig, None] = None, - ): - self.functions = functions - self.apis = apis - self.jwt_feature = jwt_feature +OpenIdOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedOpenIdOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = OpenIdOverrideConfig +"""Deprecated, use `OpenIdOverrideConfig` instead.""" -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class OpenIdConfig(BaseConfig[RecipeInterface, APIInterface]): + issuer: Optional[str] = None -class OpenIdConfig: - def __init__( - self, - override: OverrideConfig, - issuer_domain: NormalisedURLDomain, - issuer_path: NormalisedURLPath, - ): - self.override = override - self.issuer_domain = issuer_domain - self.issuer_path = issuer_path +class NormalisedOpenIdConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + issuer_domain: NormalisedURLDomain + issuer_path: NormalisedURLPath def validate_and_normalise_user_input( app_info: AppInfo, - issuer: Union[str, None] = None, - override: Union[InputOverrideConfig, None] = None, -): - if issuer is None: + config: OpenIdConfig, +) -> NormalisedOpenIdConfig: + if config.issuer is None: issuer_domain = app_info.api_domain issuer_path = app_info.api_base_path else: - issuer_domain = NormalisedURLDomain(issuer) - issuer_path = NormalisedURLPath(issuer) + issuer_domain = NormalisedURLDomain(config.issuer) + issuer_path = NormalisedURLPath(config.issuer) if not issuer_path.equals(app_info.api_base_path): raise Exception( "The path of the issuer URL must be equal to the apiBasePath. The default value is /auth" ) - if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - raise ValueError("override must be an instance of InputOverrideConfig or None") - - if override is None: - override = InputOverrideConfig() + override_config = NormalisedOpenIdOverrideConfig.from_input_config( + override_config=config.override + ) - return OpenIdConfig( - OverrideConfig(functions=override.functions, apis=override.apis), - issuer_domain, - issuer_path, + return NormalisedOpenIdConfig( + issuer_domain=issuer_domain, + issuer_path=issuer_path, + override=override_config, ) diff --git a/supertokens_python/recipe/passwordless/__init__.py b/supertokens_python/recipe/passwordless/__init__.py index 4d0648384..336a779c6 100644 --- a/supertokens_python/recipe/passwordless/__init__.py +++ b/supertokens_python/recipe/passwordless/__init__.py @@ -25,31 +25,27 @@ SMSTemplateVars, ) -from . import types, utils -from .emaildelivery import services as emaildelivery_services +from .emaildelivery.services import SMTPService from .recipe import PasswordlessRecipe -from .smsdelivery import services as smsdelivery_services +from .smsdelivery.services import SuperTokensSMSService, TwilioService +from .types import ( + CreateAndSendCustomEmailParameters, + CreateAndSendCustomTextMessageParameters, + EmailDeliveryInterface, + SMSDeliveryInterface, +) +from .utils import ( + ContactConfig, + ContactEmailOnlyConfig, + ContactEmailOrPhoneConfig, + ContactPhoneOnlyConfig, + InputOverrideConfig, + PasswordlessOverrideConfig, + PhoneOrEmailInput, +) if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule - -InputOverrideConfig = utils.OverrideConfig -ContactEmailOnlyConfig = utils.ContactEmailOnlyConfig -ContactConfig = utils.ContactConfig -PhoneOrEmailInput = utils.PhoneOrEmailInput -CreateAndSendCustomTextMessageParameters = ( - types.CreateAndSendCustomTextMessageParameters -) -CreateAndSendCustomEmailParameters = types.CreateAndSendCustomEmailParameters -ContactPhoneOnlyConfig = utils.ContactPhoneOnlyConfig -ContactEmailOrPhoneConfig = utils.ContactEmailOrPhoneConfig -SMTPService = emaildelivery_services.SMTPService -TwilioService = smsdelivery_services.TwilioService -SuperTokensSMSService = smsdelivery_services.SuperTokensSMSService -EmailDeliveryInterface = types.EmailDeliveryInterface -SMSDeliveryInterface = types.SMSDeliveryInterface + from supertokens_python.supertokens import RecipeInit def init( @@ -57,13 +53,13 @@ def init( flow_type: Literal[ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" ], - override: Union[InputOverrideConfig, None] = None, + override: Union[PasswordlessOverrideConfig, None] = None, get_custom_user_input_code: Union[ Callable[[str, Dict[str, Any]], Awaitable[str]], None ] = None, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, sms_delivery: Union[SMSDeliveryConfig[SMSTemplateVars], None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return PasswordlessRecipe.init( contact_config, flow_type, @@ -72,3 +68,22 @@ def init( email_delivery, sms_delivery, ) + + +__all__ = [ + "ContactConfig", + "ContactEmailOnlyConfig", + "ContactEmailOrPhoneConfig", + "ContactPhoneOnlyConfig", + "CreateAndSendCustomEmailParameters", + "CreateAndSendCustomTextMessageParameters", + "EmailDeliveryInterface", + "InputOverrideConfig", # deprecated, use PasswordlessOverrideConfig instead + "PasswordlessOverrideConfig", + "PhoneOrEmailInput", + "SMSDeliveryInterface", + "SMTPService", + "SuperTokensSMSService", + "TwilioService", + "init", +] diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index aae9347cc..165097275 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -13,8 +13,8 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing_extensions import Literal @@ -26,6 +26,7 @@ User, ) from supertokens_python.types.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse from ...supertokens import AppInfo @@ -37,7 +38,9 @@ PasswordlessLoginSMSTemplateVars, SMSDeliveryIngredient, ) -from .utils import PasswordlessConfig + +if TYPE_CHECKING: + from .utils import NormalisedPasswordlessConfig class CreateCodeOkResult: @@ -214,7 +217,7 @@ def __init__(self, reason: str): self.reason = reason -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -358,7 +361,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: PasswordlessConfig, + config: NormalisedPasswordlessConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, email_delivery: EmailDeliveryIngredient[PasswordlessLoginEmailTemplateVars], @@ -504,7 +507,7 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status, "reason": self.reason} -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_create_code_post = False self.disable_resend_code_post = False diff --git a/supertokens_python/recipe/passwordless/recipe.py b/supertokens_python/recipe/passwordless/recipe.py index a332cb019..043b9272d 100644 --- a/supertokens_python/recipe/passwordless/recipe.py +++ b/supertokens_python/recipe/passwordless/recipe.py @@ -72,7 +72,8 @@ from .recipe_implementation import RecipeImplementation from .utils import ( ContactConfig, - OverrideConfig, + PasswordlessConfig, + PasswordlessOverrideConfig, get_enabled_pwless_factors, validate_and_normalise_user_input, ) @@ -98,46 +99,22 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - contact_config: ContactConfig, - flow_type: Literal[ - "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" - ], ingredients: PasswordlessIngredients, - override: Union[OverrideConfig, None] = None, - get_custom_user_input_code: Union[ - Callable[[str, Dict[str, Any]], Awaitable[str]], None - ] = None, - email_delivery: Union[ - EmailDeliveryConfig[PasswordlessLoginEmailTemplateVars], None - ] = None, - sms_delivery: Union[ - SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None - ] = None, + config: PasswordlessConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - app_info, - contact_config, - flow_type, - override, - get_custom_user_input_code, - email_delivery, - sms_delivery, + app_info=app_info, + config=config, ) recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) - self.recipe_implementation: RecipeInterface = ( + self.recipe_implementation: RecipeInterface = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) email_delivery_ingredient = ingredients.email_delivery if email_delivery_ingredient is None: @@ -508,7 +485,7 @@ def init( flow_type: Literal[ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" ], - override: Union[OverrideConfig, None] = None, + override: Optional[PasswordlessOverrideConfig] = None, get_custom_user_input_code: Union[ Callable[[str, Dict[str, Any]], Awaitable[str]], None ] = None, @@ -519,19 +496,29 @@ def init( SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None ] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = PasswordlessConfig( + contact_config=contact_config, + get_custom_user_input_code=get_custom_user_input_code, + email_delivery=email_delivery, + sms_delivery=sms_delivery, + flow_type=flow_type, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if PasswordlessRecipe.__instance is None: ingredients = PasswordlessIngredients(None, None) PasswordlessRecipe.__instance = PasswordlessRecipe( - PasswordlessRecipe.recipe_id, - app_info, - contact_config, - flow_type, - ingredients, - override, - get_custom_user_input_code, - email_delivery, - sms_delivery, + recipe_id=PasswordlessRecipe.recipe_id, + app_info=app_info, + ingredients=ingredients, + config=apply_plugins( + recipe_id=PasswordlessRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return PasswordlessRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/passwordless/utils.py b/supertokens_python/recipe/passwordless/utils.py index b61615de5..6425b6739 100644 --- a/supertokens_python/recipe/passwordless/utils.py +++ b/supertokens_python/recipe/passwordless/utils.py @@ -15,8 +15,10 @@ from __future__ import annotations from abc import ABC +from re import fullmatch from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Union +from phonenumbers import is_valid_number, parse from typing_extensions import Literal from supertokens_python.ingredients.emaildelivery.types import ( @@ -28,29 +30,30 @@ SMSDeliveryConfigWithService, ) from supertokens_python.recipe.multifactorauth.types import FactorIds -from supertokens_python.recipe.passwordless.types import ( - PasswordlessLoginSMSTemplateVars, -) - -if TYPE_CHECKING: - from supertokens_python import AppInfo - - from .interfaces import ( - APIInterface, - PasswordlessLoginEmailTemplateVars, - RecipeInterface, - ) - -from re import fullmatch - -from phonenumbers import is_valid_number, parse # type: ignore - from supertokens_python.recipe.passwordless.emaildelivery.services.backward_compatibility import ( BackwardCompatibilityService, ) from supertokens_python.recipe.passwordless.smsdelivery.services.backward_compatibility import ( BackwardCompatibilityService as SMSBackwardCompatibilityService, ) +from supertokens_python.recipe.passwordless.types import ( + PasswordlessLoginSMSTemplateVars, +) +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) + +from .interfaces import ( + APIInterface, + PasswordlessLoginEmailTemplateVars, + RecipeInterface, +) + +if TYPE_CHECKING: + from supertokens_python import AppInfo async def default_validate_phone_number(value: str, _tenant_id: str): @@ -68,14 +71,12 @@ async def default_validate_email(value: str, _tenant_id: str): return "Email is invalid" -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +PasswordlessOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedPasswordlessOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = PasswordlessOverrideConfig +"""Deprecated, use `PasswordlessOverrideConfig` instead.""" class ContactConfig(ABC): @@ -142,64 +143,61 @@ def __init__(self, phone_number: Union[str, None], email: Union[str, None]): self.email = email -class PasswordlessConfig: - def __init__( - self, - contact_config: ContactConfig, - override: OverrideConfig, - flow_type: Literal[ - "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" - ], - get_email_delivery_config: Callable[ - [], EmailDeliveryConfigWithService[PasswordlessLoginEmailTemplateVars] - ], - get_sms_delivery_config: Callable[ - [], SMSDeliveryConfigWithService[PasswordlessLoginSMSTemplateVars] - ], - get_custom_user_input_code: Union[ - Callable[[str, Dict[str, Any]], Awaitable[str]], None - ] = None, - ): - self.contact_config = contact_config - self.override = override - self.flow_type: Literal[ - "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" - ] = flow_type - self.get_custom_user_input_code = get_custom_user_input_code - self.get_email_delivery_config = get_email_delivery_config - self.get_sms_delivery_config = get_sms_delivery_config - - -def validate_and_normalise_user_input( - app_info: AppInfo, - contact_config: ContactConfig, +class PasswordlessConfig(BaseConfig[RecipeInterface, APIInterface]): + contact_config: ContactConfig flow_type: Literal[ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" - ], - override: Union[OverrideConfig, None] = None, + ] get_custom_user_input_code: Union[ Callable[[str, Dict[str, Any]], Awaitable[str]], None - ] = None, + ] = None email_delivery: Union[ EmailDeliveryConfig[PasswordlessLoginEmailTemplateVars], None - ] = None, - sms_delivery: Union[ - SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None - ] = None, -) -> PasswordlessConfig: - if override is None: - override = OverrideConfig() + ] = None + sms_delivery: Union[SMSDeliveryConfig[PasswordlessLoginSMSTemplateVars], None] = ( + None + ) + + +class NormalisedPasswordlessConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + contact_config: ContactConfig + flow_type: Literal[ + "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK" + ] + get_email_delivery_config: Callable[ + [], EmailDeliveryConfigWithService[PasswordlessLoginEmailTemplateVars] + ] + get_sms_delivery_config: Callable[ + [], SMSDeliveryConfigWithService[PasswordlessLoginSMSTemplateVars] + ] + get_custom_user_input_code: Union[ + Callable[[str, Dict[str, Any]], Awaitable[str]], None + ] + + +def validate_and_normalise_user_input( + app_info: AppInfo, + config: PasswordlessConfig, +) -> NormalisedPasswordlessConfig: + override_config = NormalisedPasswordlessOverrideConfig.from_input_config( + override_config=config.override + ) def get_email_delivery_config() -> EmailDeliveryConfigWithService[ PasswordlessLoginEmailTemplateVars ]: - email_service = email_delivery.service if email_delivery is not None else None + email_service = ( + config.email_delivery.service if config.email_delivery is not None else None + ) if email_service is None: email_service = BackwardCompatibilityService(app_info) - if email_delivery is not None and email_delivery.override is not None: - override = email_delivery.override + if ( + config.email_delivery is not None + and config.email_delivery.override is not None + ): + override = config.email_delivery.override else: override = None @@ -208,22 +206,24 @@ def get_email_delivery_config() -> EmailDeliveryConfigWithService[ def get_sms_delivery_config() -> SMSDeliveryConfigWithService[ PasswordlessLoginSMSTemplateVars ]: - sms_service = sms_delivery.service if sms_delivery is not None else None + sms_service = ( + config.sms_delivery.service if config.sms_delivery is not None else None + ) if sms_service is None: sms_service = SMSBackwardCompatibilityService(app_info) - if sms_delivery is not None and sms_delivery.override is not None: - override = sms_delivery.override + if config.sms_delivery is not None and config.sms_delivery.override is not None: + override = config.sms_delivery.override else: override = None return SMSDeliveryConfigWithService(sms_service, override=override) - if not isinstance(contact_config, ContactConfig): # type: ignore user might not have linter enabled + if not isinstance(config.contact_config, ContactConfig): # type: ignore user might not have linter enabled raise ValueError("contact_config must be of type ContactConfig") - if flow_type not in [ + if config.flow_type not in [ "USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK", @@ -232,21 +232,18 @@ def get_sms_delivery_config() -> SMSDeliveryConfigWithService[ "flow_type must be one of USER_INPUT_CODE, MAGIC_LINK, USER_INPUT_CODE_AND_MAGIC_LINK" ) - if not isinstance(override, OverrideConfig): # type: ignore user might not have linter enabled - raise ValueError("override must be of type OverrideConfig") - - return PasswordlessConfig( - contact_config=contact_config, - override=OverrideConfig(functions=override.functions, apis=override.apis), - flow_type=flow_type, + return NormalisedPasswordlessConfig( + contact_config=config.contact_config, + override=override_config, + flow_type=config.flow_type, get_email_delivery_config=get_email_delivery_config, get_sms_delivery_config=get_sms_delivery_config, - get_custom_user_input_code=get_custom_user_input_code, + get_custom_user_input_code=config.get_custom_user_input_code, ) def get_enabled_pwless_factors( - config: PasswordlessConfig, + config: NormalisedPasswordlessConfig, ) -> List[str]: all_factors: List[str] = [] diff --git a/supertokens_python/recipe/session/__init__.py b/supertokens_python/recipe/session/__init__.py index 5e4808f3e..497348028 100644 --- a/supertokens_python/recipe/session/__init__.py +++ b/supertokens_python/recipe/session/__init__.py @@ -17,20 +17,19 @@ from typing_extensions import Literal -if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo, BaseRequest - - from ...recipe_module import RecipeModule - from .utils import TokenTransferMethod +from supertokens_python.framework import BaseRequest -from . import exceptions as ex -from . import interfaces, utils +from .interfaces import SessionContainer from .recipe import SessionRecipe +from .utils import ( + InputErrorHandlers, + InputOverrideConfig, + SessionOverrideConfig, + TokenTransferMethod, +) -InputErrorHandlers = utils.InputErrorHandlers -InputOverrideConfig = utils.InputOverrideConfig -SessionContainer = interfaces.SessionContainer -exceptions = ex +if TYPE_CHECKING: + from supertokens_python.supertokens import RecipeInit def init( @@ -48,12 +47,12 @@ def init( None, ] = None, error_handlers: Union[InputErrorHandlers, None] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[SessionOverrideConfig, None] = None, invalid_claim_status_code: Union[int, None] = None, use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return SessionRecipe.init( cookie_domain, older_cookie_domain, @@ -69,3 +68,14 @@ def init( expose_access_token_to_frontend_in_cookie_based_auth, jwks_refresh_interval_sec, ) + + +__all__ = [ + "InputErrorHandlers", + "InputOverrideConfig", # deprecated, use SessionOverrideConfig instead + "SessionContainer", + "SessionOverrideConfig", + "SessionRecipe", + "TokenTransferMethod", + "init", +] diff --git a/supertokens_python/recipe/session/access_token.py b/supertokens_python/recipe/session/access_token.py index 1205b6c7b..cdfd23abe 100644 --- a/supertokens_python/recipe/session/access_token.py +++ b/supertokens_python/recipe/session/access_token.py @@ -21,7 +21,7 @@ from supertokens_python.logger import log_debug_message from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID from supertokens_python.recipe.session.jwks import get_latest_keys -from supertokens_python.recipe.session.utils import SessionConfig +from supertokens_python.recipe.session.utils import NormalisedSessionConfig from supertokens_python.utils import get_timestamp_ms from .exceptions import raise_try_refresh_token_exception @@ -46,7 +46,7 @@ def sanitize_number(n: Any) -> Union[Union[int, float], None]: def get_info_from_access_token( - config: SessionConfig, + config: NormalisedSessionConfig, jwt_info: ParsedJWTInfo, do_anti_csrf_check: bool, ): diff --git a/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py b/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py index 77a2ac061..0845680c1 100644 --- a/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union, cast from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import get_timestamp_ms @@ -105,7 +105,7 @@ async def _validate( # Doing this to ensure same code in the upcoming steps irrespective of # whether self.val is Primitive or PrimitiveList - vals: List[_T] = val if isinstance(val, list) else [val] + vals: List[_T] = cast(List[_T], val if isinstance(val, list) else [val]) claim_val_set = set(claim_val) if is_include and not is_include_any: diff --git a/supertokens_python/recipe/session/cookie_and_header.py b/supertokens_python/recipe/session/cookie_and_header.py index 1a7aa2e50..18e4a40b1 100644 --- a/supertokens_python/recipe/session/cookie_and_header.py +++ b/supertokens_python/recipe/session/cookie_and_header.py @@ -45,7 +45,7 @@ from .recipe import SessionRecipe from .utils import ( - SessionConfig, + NormalisedSessionConfig, TokenTransferMethod, TokenType, ) @@ -111,7 +111,7 @@ def get_cookie(request: BaseRequest, key: str): def _set_cookie( response: BaseResponse, - config: SessionConfig, + config: NormalisedSessionConfig, key: str, value: str, expires: int, @@ -141,7 +141,7 @@ def _set_cookie( def set_cookie_response_mutator( - config: SessionConfig, + config: NormalisedSessionConfig, key: str, value: str, expires: int, @@ -207,7 +207,7 @@ def clear_session_from_all_token_transfer_methods( def clear_session_mutator( - config: SessionConfig, + config: NormalisedSessionConfig, transfer_method: TokenTransferMethod, request: BaseRequest, ): @@ -222,7 +222,7 @@ def mutator( def _clear_session( response: BaseResponse, - config: SessionConfig, + config: NormalisedSessionConfig, transfer_method: TokenTransferMethod, request: BaseRequest, user_context: Dict[str, Any], @@ -244,7 +244,7 @@ def _clear_session( def clear_session_response_mutator( - config: SessionConfig, + config: NormalisedSessionConfig, transfer_method: TokenTransferMethod, request: BaseRequest, ): @@ -293,7 +293,7 @@ def get_token( def _set_token( response: BaseResponse, - config: SessionConfig, + config: NormalisedSessionConfig, token_type: TokenType, value: str, expires: int, @@ -323,7 +323,7 @@ def _set_token( def token_response_mutator( - config: SessionConfig, + config: NormalisedSessionConfig, token_type: TokenType, value: str, expires: int, @@ -356,7 +356,7 @@ def set_token_in_header(response: BaseResponse, name: str, value: str): def access_token_mutator( access_token: str, front_token: str, - config: SessionConfig, + config: NormalisedSessionConfig, transfer_method: TokenTransferMethod, request: BaseRequest, ): @@ -381,7 +381,7 @@ def _set_access_token_in_response( res: BaseResponse, access_token: str, front_token: str, - config: SessionConfig, + config: NormalisedSessionConfig, transfer_method: TokenTransferMethod, request: BaseRequest, user_context: Dict[str, Any], @@ -430,7 +430,7 @@ def _set_access_token_in_response( # This function checks for multiple cookies with the same name and clears the cookies for the older domain. def clear_session_cookies_from_older_cookie_domain( - request: BaseRequest, config: SessionConfig, user_context: Dict[str, Any] + request: BaseRequest, config: NormalisedSessionConfig, user_context: Dict[str, Any] ): allowed_transfer_method = config.get_token_transfer_method( request, False, user_context diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 30cf85f54..fd8d79f10 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -29,20 +29,19 @@ from typing_extensions import TypedDict from supertokens_python.async_to_sync_wrapper import sync +from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.types import ( MaybeAwaitable, RecipeUserId, ) +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface from supertokens_python.types.response import APIResponse, GeneralErrorResponse from ...utils import resolve from .exceptions import ClaimValidationError -from .utils import SessionConfig, TokenTransferMethod if TYPE_CHECKING: - from supertokens_python.framework import BaseRequest - -from supertokens_python.framework import BaseResponse + from .utils import NormalisedSessionConfig, TokenTransferMethod class SessionObj: @@ -187,7 +186,7 @@ class GetSessionTokensDangerouslyDict(TypedDict): antiCsrfToken: Optional[str] -class RecipeInterface(ABC): # pylint: disable=too-many-public-methods +class RecipeInterface(BaseRecipeInterface): # pylint: disable=too-many-public-methods def __init__(self): pass @@ -373,7 +372,7 @@ def __init__( request: BaseRequest, response: Optional[BaseResponse], recipe_id: str, - config: SessionConfig, + config: NormalisedSessionConfig, recipe_implementation: RecipeInterface, ): self.request = request @@ -383,7 +382,7 @@ def __init__( self.recipe_implementation = recipe_implementation -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_refresh_post = False self.disable_signout_post = False @@ -446,7 +445,7 @@ class SessionContainer(ABC): # pylint: disable=too-many-public-methods def __init__( self, recipe_implementation: RecipeInterface, - config: SessionConfig, + config: NormalisedSessionConfig, access_token: str, front_token: str, refresh_token: Optional[TokenInfo], diff --git a/supertokens_python/recipe/session/jwks.py b/supertokens_python/recipe/session/jwks.py index e2eb39c7d..f4f8aa35d 100644 --- a/supertokens_python/recipe/session/jwks.py +++ b/supertokens_python/recipe/session/jwks.py @@ -21,7 +21,7 @@ from supertokens_python.logger import log_debug_message from supertokens_python.querier import Querier -from supertokens_python.recipe.session.utils import SessionConfig +from supertokens_python.recipe.session.utils import NormalisedSessionConfig from supertokens_python.utils import RWLockContext, RWMutex, get_timestamp_ms @@ -88,7 +88,9 @@ def find_matching_keys( return None -def get_latest_keys(config: SessionConfig, kid: Optional[str] = None) -> List[PyJWK]: +def get_latest_keys( + config: NormalisedSessionConfig, kid: Optional[str] = None +) -> List[PyJWK]: global cached_keys if environ.get("SUPERTOKENS_ENV") == "testing": diff --git a/supertokens_python/recipe/session/recipe.py b/supertokens_python/recipe/session/recipe.py index 5d025026a..29764de5a 100644 --- a/supertokens_python/recipe/session/recipe.py +++ b/supertokens_python/recipe/session/recipe.py @@ -18,10 +18,18 @@ from typing_extensions import Literal +from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework.response import BaseResponse +from supertokens_python.logger import log_debug_message +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.querier import Querier +from supertokens_python.recipe_module import APIHandled, RecipeModule from ...types import MaybeAwaitable +from .api import handle_refresh_api, handle_signout_api +from .constants import SESSION_REFRESH, SIGNOUT from .cookie_and_header import ( + clear_session_from_all_token_transfer_methods, get_cors_allowed_headers, ) from .exceptions import ( @@ -31,20 +39,6 @@ TokenTheftError, UnauthorisedError, ) - -if TYPE_CHECKING: - from supertokens_python.framework import BaseRequest - from supertokens_python.supertokens import AppInfo - -from supertokens_python.exceptions import SuperTokensError, raise_general_exception -from supertokens_python.logger import log_debug_message -from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.querier import Querier -from supertokens_python.recipe_module import APIHandled, RecipeModule - -from .api import handle_refresh_api, handle_signout_api -from .constants import SESSION_REFRESH, SIGNOUT -from .cookie_and_header import clear_session_from_all_token_transfer_methods from .interfaces import ( APIInterface, APIOptions, @@ -58,11 +52,16 @@ ) from .utils import ( InputErrorHandlers, - InputOverrideConfig, + SessionConfig, + SessionOverrideConfig, TokenTransferMethod, validate_and_normalise_user_input, ) +if TYPE_CHECKING: + from supertokens_python.framework import BaseRequest + from supertokens_python.supertokens import AppInfo + class SessionRecipe(RecipeModule): recipe_id = "session" @@ -72,44 +71,12 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - cookie_domain: Union[str, None] = None, - older_cookie_domain: Union[str, None] = None, - cookie_secure: Union[bool, None] = None, - cookie_same_site: Union[Literal["lax", "none", "strict"], None] = None, - session_expired_status_code: Union[int, None] = None, - anti_csrf: Union[ - Literal["VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE"], None - ] = None, - get_token_transfer_method: Union[ - Callable[ - [BaseRequest, bool, Dict[str, Any]], - Union[TokenTransferMethod, Literal["any"]], - ], - None, - ] = None, - error_handlers: Union[InputErrorHandlers, None] = None, - override: Union[InputOverrideConfig, None] = None, - invalid_claim_status_code: Union[int, None] = None, - use_dynamic_access_token_signing_key: Union[bool, None] = None, - expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, - jwks_refresh_interval_sec: Union[int, None] = None, + config: SessionConfig, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - app_info, - cookie_domain, - older_cookie_domain, - cookie_secure, - cookie_same_site, - session_expired_status_code, - anti_csrf, - get_token_transfer_method, - error_handlers, - override, - invalid_claim_status_code, - use_dynamic_access_token_signing_key, - expose_access_token_to_frontend_in_cookie_based_auth, - jwks_refresh_interval_sec, + app_info=app_info, + config=config, ) log_debug_message( "session init: anti_csrf: %s", self.config.anti_csrf_function_or_string @@ -123,8 +90,10 @@ def __init__( # we check the input cookie_same_site because the normalised version is # always a function. - if cookie_same_site is not None: - log_debug_message("session init: cookie_same_site: %s", cookie_same_site) + if config.cookie_same_site is not None: + log_debug_message( + "session init: cookie_same_site: %s", config.cookie_same_site + ) else: log_debug_message("session init: cookie_same_site: function") @@ -142,19 +111,15 @@ def __init__( recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config, self.app_info ) - self.recipe_implementation: RecipeInterface = ( + self.recipe_implementation: RecipeInterface = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) from .api.implementation import APIImplementation api_implementation = APIImplementation() - self.api_implementation: APIInterface = ( + self.api_implementation: APIInterface = self.config.override.apis( api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) ) self.claims_added_by_other_recipes: List[SessionClaim[Any]] = [] @@ -290,30 +255,40 @@ def init( None, ] = None, error_handlers: Union[InputErrorHandlers, None] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[SessionOverrideConfig, None] = None, invalid_claim_status_code: Union[int, None] = None, use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = SessionConfig( + cookie_domain=cookie_domain, + older_cookie_domain=older_cookie_domain, + cookie_secure=cookie_secure, + cookie_same_site=cookie_same_site, + session_expired_status_code=session_expired_status_code, + anti_csrf=anti_csrf, + get_token_transfer_method=get_token_transfer_method, + error_handlers=error_handlers, + override=override, + invalid_claim_status_code=invalid_claim_status_code, + use_dynamic_access_token_signing_key=use_dynamic_access_token_signing_key, + expose_access_token_to_frontend_in_cookie_based_auth=expose_access_token_to_frontend_in_cookie_based_auth, + jwks_refresh_interval_sec=jwks_refresh_interval_sec, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if SessionRecipe.__instance is None: SessionRecipe.__instance = SessionRecipe( - SessionRecipe.recipe_id, - app_info, - cookie_domain, - older_cookie_domain, - cookie_secure, - cookie_same_site, - session_expired_status_code, - anti_csrf, - get_token_transfer_method, - error_handlers, - override, - invalid_claim_status_code, - use_dynamic_access_token_signing_key, - expose_access_token_to_frontend_in_cookie_based_auth, - jwks_refresh_interval_sec, + recipe_id=SessionRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=SessionRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return SessionRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index aa073a872..750d721e4 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -39,7 +39,7 @@ ) from .jwt import ParsedJWTInfo, parse_jwt_without_signature_verification from .session_class import Session -from .utils import SessionConfig, validate_claims_in_payload +from .utils import NormalisedSessionConfig, validate_claims_in_payload if TYPE_CHECKING: from typing import List, Union @@ -54,7 +54,9 @@ class RecipeImplementation(RecipeInterface): # pylint: disable=too-many-public-methods - def __init__(self, querier: Querier, config: SessionConfig, app_info: AppInfo): + def __init__( + self, querier: Querier, config: NormalisedSessionConfig, app_info: AppInfo + ): super().__init__() self.querier = querier self.config = config diff --git a/supertokens_python/recipe/session/session_request_functions.py b/supertokens_python/recipe/session/session_request_functions.py index 40e6c6d58..1262ea1f6 100644 --- a/supertokens_python/recipe/session/session_request_functions.py +++ b/supertokens_python/recipe/session/session_request_functions.py @@ -47,12 +47,11 @@ parse_jwt_without_signature_verification, ) from supertokens_python.recipe.session.utils import ( - SessionConfig, + NormalisedSessionConfig, TokenTransferMethod, get_auth_mode_from_header, get_required_claim_validators, ) -from supertokens_python.supertokens import Supertokens from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import ( FRAMEWORKS, @@ -75,7 +74,7 @@ async def get_session_from_request( request: Any, - config: SessionConfig, + config: NormalisedSessionConfig, recipe_interface_impl: SessionRecipeInterface, session_required: Optional[bool] = None, anti_csrf_check: Optional[bool] = None, @@ -88,6 +87,8 @@ async def get_session_from_request( ] = None, user_context: Optional[Dict[str, Any]] = None, ) -> Optional[SessionContainer]: + from supertokens_python.supertokens import Supertokens + log_debug_message("getSession: Started") if not hasattr(request, "wrapper_used") or not request.wrapper_used: @@ -240,11 +241,13 @@ async def create_new_session_in_request( access_token_payload: Dict[str, Any], user_id: str, recipe_user_id: RecipeUserId, - config: SessionConfig, + config: NormalisedSessionConfig, app_info: AppInfo, session_data_in_database: Dict[str, Any], tenant_id: str, ) -> SessionContainer: + from supertokens_python.supertokens import Supertokens + log_debug_message("createNewSession: Started") # Handling framework specific request/response wrapping @@ -353,9 +356,11 @@ async def create_new_session_in_request( async def refresh_session_in_request( request: Any, user_context: Dict[str, Any], - config: SessionConfig, + config: NormalisedSessionConfig, recipe_interface_impl: SessionRecipeInterface, ) -> SessionContainer: + from supertokens_python.supertokens import Supertokens + log_debug_message("refreshSession: Started") response_mutators: List[ResponseMutator] = [] diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index bae607c22..30fe626e9 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -20,8 +20,14 @@ from typing_extensions import Literal from supertokens_python.exceptions import raise_general_exception -from supertokens_python.framework import BaseResponse +from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) from supertokens_python.utils import ( is_an_ip_address, resolve, @@ -33,20 +39,16 @@ from ...types import MaybeAwaitable, RecipeUserId from .constants import AUTH_MODE_HEADER_KEY, SESSION_REFRESH from .exceptions import ClaimValidationError +from .interfaces import ( + APIInterface, + RecipeInterface, + SessionClaimValidator, + SessionContainer, +) if TYPE_CHECKING: - from supertokens_python.framework import BaseRequest - from supertokens_python.recipe.openid import ( - InputOverrideConfig as OpenIdInputOverrideConfig, - ) from supertokens_python.supertokens import AppInfo - from .interfaces import ( - APIInterface, - RecipeInterface, - SessionClaimValidator, - SessionContainer, - ) from .recipe import SessionRecipe from supertokens_python.logger import log_debug_message @@ -334,141 +336,112 @@ def get_token_transfer_method_default( return "any" -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - openid_feature: Union[OpenIdInputOverrideConfig, None] = None, - ): - self.functions = functions - self.apis = apis - self.openid_feature = openid_feature - - -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +SessionOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedSessionOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = SessionOverrideConfig +"""Deprecated: Use `SessionOverrideConfig` instead.""" TokenType = Literal["access", "refresh"] TokenTransferMethod = Literal["cookie", "header"] -class SessionConfig: - def __init__( - self, - refresh_token_path: NormalisedURLPath, - cookie_domain: Union[None, str], - older_cookie_domain: Union[None, str], - get_cookie_same_site: Callable[ - [Optional[BaseRequest], Dict[str, Any]], - Literal["lax", "strict", "none"], - ], - cookie_secure: bool, - session_expired_status_code: int, - error_handlers: ErrorHandlers, - anti_csrf_function_or_string: Union[ - Callable[ - [Optional[BaseRequest], Dict[str, Any]], - Literal["VIA_CUSTOM_HEADER", "NONE"], - ], - Literal["VIA_CUSTOM_HEADER", "NONE", "VIA_TOKEN"], - ], - get_token_transfer_method: Callable[ +class SessionConfig(BaseConfig[RecipeInterface, APIInterface]): + cookie_domain: Union[str, None] = None + older_cookie_domain: Union[str, None] = None + cookie_secure: Union[bool, None] = None + cookie_same_site: Union[Literal["lax", "strict", "none"], None] = None + session_expired_status_code: Union[int, None] = None + anti_csrf: Union[Literal["VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE"], None] = None + get_token_transfer_method: Union[ + Callable[ [BaseRequest, bool, Dict[str, Any]], Union[TokenTransferMethod, Literal["any"]], ], - override: OverrideConfig, - framework: str, - mode: str, - invalid_claim_status_code: int, - use_dynamic_access_token_signing_key: bool, - expose_access_token_to_frontend_in_cookie_based_auth: bool, - jwks_refresh_interval_sec: int, - ): - self.session_expired_status_code = session_expired_status_code - self.invalid_claim_status_code = invalid_claim_status_code - self.use_dynamic_access_token_signing_key = use_dynamic_access_token_signing_key - self.expose_access_token_to_frontend_in_cookie_based_auth = ( - expose_access_token_to_frontend_in_cookie_based_auth - ) - - self.refresh_token_path = refresh_token_path - self.cookie_domain = cookie_domain - self.older_cookie_domain = older_cookie_domain - self.get_cookie_same_site = get_cookie_same_site - self.cookie_secure = cookie_secure - self.error_handlers = error_handlers - self.anti_csrf_function_or_string = anti_csrf_function_or_string - self.get_token_transfer_method = get_token_transfer_method - self.override = override - self.framework = framework - self.mode = mode - self.jwks_refresh_interval_sec = jwks_refresh_interval_sec + None, + ] = None + error_handlers: Union[ErrorHandlers, None] = None + invalid_claim_status_code: Union[int, None] = None + use_dynamic_access_token_signing_key: Union[bool, None] = None + expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None + jwks_refresh_interval_sec: Union[int, None] = None + + +class NormalisedSessionConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + refresh_token_path: NormalisedURLPath + cookie_domain: Union[None, str] + older_cookie_domain: Union[None, str] + get_cookie_same_site: Callable[ + [Optional[BaseRequest], Dict[str, Any]], + Literal["lax", "strict", "none"], + ] + cookie_secure: bool + session_expired_status_code: int + error_handlers: ErrorHandlers + anti_csrf_function_or_string: Union[ + Callable[ + [Optional[BaseRequest], Dict[str, Any]], + Literal["VIA_CUSTOM_HEADER", "NONE"], + ], + Literal["VIA_CUSTOM_HEADER", "NONE", "VIA_TOKEN"], + ] + get_token_transfer_method: Callable[ + [BaseRequest, bool, Dict[str, Any]], + Union[TokenTransferMethod, Literal["any"]], + ] + framework: str + mode: str + invalid_claim_status_code: int + use_dynamic_access_token_signing_key: bool + expose_access_token_to_frontend_in_cookie_based_auth: bool + jwks_refresh_interval_sec: int def validate_and_normalise_user_input( app_info: AppInfo, - cookie_domain: Union[str, None] = None, - older_cookie_domain: Union[str, None] = None, - cookie_secure: Union[bool, None] = None, - cookie_same_site: Union[Literal["lax", "strict", "none"], None] = None, - session_expired_status_code: Union[int, None] = None, - anti_csrf: Union[Literal["VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE"], None] = None, - get_token_transfer_method: Union[ - Callable[ - [BaseRequest, bool, Dict[str, Any]], - Union[TokenTransferMethod, Literal["any"]], - ], - None, - ] = None, - error_handlers: Union[ErrorHandlers, None] = None, - override: Union[InputOverrideConfig, None] = None, - invalid_claim_status_code: Union[int, None] = None, - use_dynamic_access_token_signing_key: Union[bool, None] = None, - expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, - jwks_refresh_interval_sec: Union[int, None] = None, + config: SessionConfig, ): - _ = cookie_same_site # we have this otherwise pylint complains that cookie_same_site is unused, but it is being used in the get_cookie_same_site function. - if anti_csrf not in {"VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE", None}: + # _ = cookie_same_site # we have this otherwise pylint complains that cookie_same_site is unused, but it is being used in the get_cookie_same_site function. + if config.anti_csrf not in {"VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE", None}: raise ValueError( "anti_csrf must be one of VIA_TOKEN, VIA_CUSTOM_HEADER, NONE or None" ) - if error_handlers is not None and not isinstance(error_handlers, ErrorHandlers): # type: ignore + if config.error_handlers is not None and not isinstance( + config.error_handlers, ErrorHandlers + ): # type: ignore raise ValueError("error_handlers must be an instance of ErrorHandlers or None") - if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - raise ValueError("override must be an instance of InputOverrideConfig or None") - cookie_domain = ( - normalise_session_scope(cookie_domain) if cookie_domain is not None else None + normalise_session_scope(config.cookie_domain) + if config.cookie_domain is not None + else None ) older_cookie_domain = ( - older_cookie_domain - if older_cookie_domain is None or older_cookie_domain == "" - else normalise_session_scope(older_cookie_domain) + config.older_cookie_domain + if config.older_cookie_domain is None or config.older_cookie_domain == "" + else normalise_session_scope(config.older_cookie_domain) ) cookie_secure = ( - cookie_secure - if cookie_secure is not None + config.cookie_secure + if config.cookie_secure is not None else app_info.api_domain.get_as_string_dangerous().startswith("https") ) session_expired_status_code = ( - session_expired_status_code if session_expired_status_code is not None else 401 + config.session_expired_status_code + if config.session_expired_status_code is not None + else 401 ) invalid_claim_status_code = ( - invalid_claim_status_code if invalid_claim_status_code is not None else 403 + config.invalid_claim_status_code + if config.invalid_claim_status_code is not None + else 403 ) if session_expired_status_code == invalid_claim_status_code: @@ -477,21 +450,25 @@ def validate_and_normalise_user_input( f"({invalid_claim_status_code})" ) + get_token_transfer_method = config.get_token_transfer_method if get_token_transfer_method is None: get_token_transfer_method = get_token_transfer_method_default + error_handlers = config.error_handlers if error_handlers is None: error_handlers = InputErrorHandlers() - if override is None: - override = InputOverrideConfig() - + use_dynamic_access_token_signing_key = config.use_dynamic_access_token_signing_key if use_dynamic_access_token_signing_key is None: use_dynamic_access_token_signing_key = True + expose_access_token_to_frontend_in_cookie_based_auth = ( + config.expose_access_token_to_frontend_in_cookie_based_auth + ) if expose_access_token_to_frontend_in_cookie_based_auth is None: expose_access_token_to_frontend_in_cookie_based_auth = False + cookie_same_site = config.cookie_same_site if cookie_same_site is not None: # this is just so that we check that the user has provided the right # values, since normalise_same_site throws an error if the user @@ -538,29 +515,38 @@ def anti_csrf_function( ], Literal["VIA_CUSTOM_HEADER", "NONE", "VIA_TOKEN"], ] = anti_csrf_function + + anti_csrf = config.anti_csrf if anti_csrf is not None: anti_csrf_function_or_string = anti_csrf + jwks_refresh_interval_sec = config.jwks_refresh_interval_sec if jwks_refresh_interval_sec is None: jwks_refresh_interval_sec = 4 * 3600 # 4 hours - return SessionConfig( - app_info.api_base_path.append(NormalisedURLPath(SESSION_REFRESH)), - cookie_domain, - older_cookie_domain, - get_cookie_same_site, - cookie_secure, - session_expired_status_code, - error_handlers, - anti_csrf_function_or_string, - get_token_transfer_method, - OverrideConfig(override.functions, override.apis), - app_info.framework, - app_info.mode, - invalid_claim_status_code, - use_dynamic_access_token_signing_key, - expose_access_token_to_frontend_in_cookie_based_auth, - jwks_refresh_interval_sec, + override_config = NormalisedSessionOverrideConfig.from_input_config( + override_config=config.override + ) + + return NormalisedSessionConfig( + refresh_token_path=app_info.api_base_path.append( + NormalisedURLPath(SESSION_REFRESH) + ), + cookie_domain=cookie_domain, + older_cookie_domain=older_cookie_domain, + get_cookie_same_site=get_cookie_same_site, + cookie_secure=cookie_secure, + session_expired_status_code=session_expired_status_code, + error_handlers=error_handlers, + anti_csrf_function_or_string=anti_csrf_function_or_string, + get_token_transfer_method=get_token_transfer_method, + override=override_config, + framework=app_info.framework, + mode=app_info.mode, + invalid_claim_status_code=invalid_claim_status_code, + use_dynamic_access_token_signing_key=use_dynamic_access_token_signing_key, + expose_access_token_to_frontend_in_cookie_based_auth=expose_access_token_to_frontend_in_cookie_based_auth, + jwks_refresh_interval_sec=jwks_refresh_interval_sec, ) diff --git a/supertokens_python/recipe/thirdparty/__init__.py b/supertokens_python/recipe/thirdparty/__init__.py index 9e20d78e1..0e411de29 100644 --- a/supertokens_python/recipe/thirdparty/__init__.py +++ b/supertokens_python/recipe/thirdparty/__init__.py @@ -14,29 +14,32 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional, Union -from . import exceptions as ex -from . import provider, utils +from .provider import ProviderClientConfig, ProviderConfig, ProviderInput from .recipe import ThirdPartyRecipe - -InputOverrideConfig = utils.InputOverrideConfig -SignInAndUpFeature = utils.SignInAndUpFeature -ProviderInput = provider.ProviderInput -ProviderConfig = provider.ProviderConfig -ProviderClientConfig = provider.ProviderClientConfig -exceptions = ex +from .utils import InputOverrideConfig, SignInAndUpFeature, ThirdPartyOverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( sign_in_and_up_feature: Optional[SignInAndUpFeature] = None, - override: Union[InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: + override: Union[ThirdPartyOverrideConfig, None] = None, +) -> RecipeInit: if sign_in_and_up_feature is None: sign_in_and_up_feature = SignInAndUpFeature() return ThirdPartyRecipe.init(sign_in_and_up_feature, override) + + +__all__ = [ + "InputOverrideConfig", # deprecated, use `ThirdPartyOverrideConfig` instead + "ProviderClientConfig", + "ProviderConfig", + "ProviderInput", + "SignInAndUpFeature", + "ThirdPartyOverrideConfig", + "ThirdPartyRecipe", + "init", +] diff --git a/supertokens_python/recipe/thirdparty/interfaces.py b/supertokens_python/recipe/thirdparty/interfaces.py index db8012b39..f7d01ae0c 100644 --- a/supertokens_python/recipe/thirdparty/interfaces.py +++ b/supertokens_python/recipe/thirdparty/interfaces.py @@ -13,9 +13,11 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface + from ...types import RecipeUserId, User from ...types.response import APIResponse, GeneralErrorResponse from .provider import Provider, ProviderInput, RedirectUriInfo @@ -27,7 +29,7 @@ from supertokens_python.types.auth_utils import LinkingToSessionUserFailedError from .types import RawUserInfoFromProvider - from .utils import ThirdPartyConfig + from .utils import NormalisedThirdPartyConfig class SignInUpOkResult: @@ -79,7 +81,7 @@ def __init__(self, reason: str): self.reason = reason -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): def __init__(self): pass @@ -135,7 +137,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: ThirdPartyConfig, + config: NormalisedThirdPartyConfig, recipe_implementation: RecipeInterface, providers: List[ProviderInput], app_info: AppInfo, @@ -143,7 +145,7 @@ def __init__( self.request: BaseRequest = request self.response: BaseResponse = response self.recipe_id: str = recipe_id - self.config: ThirdPartyConfig = config + self.config: NormalisedThirdPartyConfig = config self.providers: List[ProviderInput] = providers self.recipe_implementation: RecipeInterface = recipe_implementation self.app_info: AppInfo = app_info @@ -198,7 +200,7 @@ def to_json(self): } -class APIInterface: +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_sign_in_up_post = False self.disable_authorisation_url_get = False diff --git a/supertokens_python/recipe/thirdparty/providers/apple.py b/supertokens_python/recipe/thirdparty/providers/apple.py index 5c344b980..5466335cc 100644 --- a/supertokens_python/recipe/thirdparty/providers/apple.py +++ b/supertokens_python/recipe/thirdparty/providers/apple.py @@ -16,7 +16,7 @@ import json from re import sub from time import time -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast from jwt import encode # type: ignore @@ -106,7 +106,7 @@ async def get_user_info( if isinstance(user, str): user_dict = json.loads(user) elif isinstance(user, dict): - user_dict = user + user_dict = cast(Dict[str, Any], user) else: return response diff --git a/supertokens_python/recipe/thirdparty/providers/custom.py b/supertokens_python/recipe/thirdparty/providers/custom.py index af78ea1e2..c97de1005 100644 --- a/supertokens_python/recipe/thirdparty/providers/custom.py +++ b/supertokens_python/recipe/thirdparty/providers/custom.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, cast from urllib.parse import parse_qs, urlencode, urlparse import pkce @@ -67,7 +67,7 @@ def access_field(obj: Any, key: str) -> Any: key_parts = key.split(".") for part in key_parts: if isinstance(obj, dict): - obj = obj.get(part) # type: ignore + obj = cast(Dict[str, Any], obj).get(part) else: return None diff --git a/supertokens_python/recipe/thirdparty/recipe.py b/supertokens_python/recipe/thirdparty/recipe.py index b3ae12be6..05a26ed4f 100644 --- a/supertokens_python/recipe/thirdparty/recipe.py +++ b/supertokens_python/recipe/thirdparty/recipe.py @@ -30,7 +30,7 @@ from supertokens_python.framework.response import BaseResponse from supertokens_python.supertokens import AppInfo - from .utils import InputOverrideConfig, SignInAndUpFeature + from .utils import SignInAndUpFeature, ThirdPartyOverrideConfig from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe @@ -43,7 +43,7 @@ from .constants import APPLE_REDIRECT_HANDLER, AUTHORISATIONURL, SIGNINUP from .exceptions import SuperTokensThirdPartyError from .types import ThirdPartyIngredients -from .utils import validate_and_normalise_user_input +from .utils import ThirdPartyConfig, validate_and_normalise_user_input class ThirdPartyRecipe(RecipeModule): @@ -54,29 +54,21 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - sign_in_and_up_feature: SignInAndUpFeature, + config: ThirdPartyConfig, _ingredients: ThirdPartyIngredients, - override: Union[InputOverrideConfig, None] = None, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input( - sign_in_and_up_feature, - override, - ) + self.config = validate_and_normalise_user_input(config=config) self.providers = self.config.sign_in_and_up_feature.providers recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.providers ) - self.recipe_implementation: RecipeInterface = ( + self.recipe_implementation: RecipeInterface = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation: APIInterface = ( + self.api_implementation: APIInterface = self.config.override.apis( api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) ) def callback(): @@ -165,17 +157,27 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( sign_in_and_up_feature: SignInAndUpFeature, - override: Union[InputOverrideConfig, None] = None, + override: Union[ThirdPartyOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = ThirdPartyConfig( + sign_in_and_up_feature=sign_in_and_up_feature, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if ThirdPartyRecipe.__instance is None: ingredients = ThirdPartyIngredients() ThirdPartyRecipe.__instance = ThirdPartyRecipe( - ThirdPartyRecipe.recipe_id, - app_info, - sign_in_and_up_feature, - ingredients, - override, + recipe_id=ThirdPartyRecipe.recipe_id, + app_info=app_info, + _ingredients=ingredients, + config=apply_plugins( + recipe_id=ThirdPartyRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return ThirdPartyRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/thirdparty/utils.py b/supertokens_python/recipe/thirdparty/utils.py index d556d20ac..fc1617838 100644 --- a/supertokens_python/recipe/thirdparty/utils.py +++ b/supertokens_python/recipe/thirdparty/utils.py @@ -13,18 +13,24 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set + +from jwt import PyJWKClient, decode # type: ignore from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.thirdparty.provider import ProviderInput +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) from .interfaces import APIInterface, RecipeInterface if TYPE_CHECKING: from .provider import ProviderInput -from jwt import PyJWKClient, decode # type: ignore - class SignInAndUpFeature: def __init__(self, providers: Optional[List[ProviderInput]] = None): @@ -46,54 +52,37 @@ def __init__(self, providers: Optional[List[ProviderInput]] = None): self.providers = providers -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +ThirdPartyOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedThirdPartyOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = ThirdPartyOverrideConfig +"""Deprecated: Use `ThirdPartyOverrideConfig` instead.""" -class OverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +class ThirdPartyConfig(BaseConfig[RecipeInterface, APIInterface]): + sign_in_and_up_feature: SignInAndUpFeature -class ThirdPartyConfig: - def __init__( - self, - sign_in_and_up_feature: SignInAndUpFeature, - override: OverrideConfig, - ): - self.sign_in_and_up_feature = sign_in_and_up_feature - self.override = override +class NormalisedThirdPartyConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + sign_in_and_up_feature: SignInAndUpFeature def validate_and_normalise_user_input( - sign_in_and_up_feature: SignInAndUpFeature, - override: Union[InputOverrideConfig, None] = None, -) -> ThirdPartyConfig: - if not isinstance(sign_in_and_up_feature, SignInAndUpFeature): # type: ignore + config: ThirdPartyConfig, +) -> NormalisedThirdPartyConfig: + if not isinstance(config.sign_in_and_up_feature, SignInAndUpFeature): # type: ignore raise ValueError( "sign_in_and_up_feature must be an instance of SignInAndUpFeature" ) - if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - raise ValueError("override must be an instance of InputOverrideConfig or None") - - if override is None: - override = InputOverrideConfig() + override_config = NormalisedThirdPartyOverrideConfig.from_input_config( + override_config=config.override + ) - return ThirdPartyConfig( - sign_in_and_up_feature, - OverrideConfig(functions=override.functions, apis=override.apis), + return NormalisedThirdPartyConfig( + sign_in_and_up_feature=config.sign_in_and_up_feature, + override=override_config, ) diff --git a/supertokens_python/recipe/totp/__init__.py b/supertokens_python/recipe/totp/__init__.py index f89944688..21baa226e 100644 --- a/supertokens_python/recipe/totp/__init__.py +++ b/supertokens_python/recipe/totp/__init__.py @@ -13,21 +13,32 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union -from supertokens_python.recipe.totp.types import TOTPConfig +from supertokens_python.recipe.totp.types import ( + OverrideConfig, + TOTPConfig, + TOTPOverrideConfig, +) from .recipe import TOTPRecipe if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( config: Union[TOTPConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: +) -> RecipeInit: return TOTPRecipe.init( config=config, ) + + +__all__ = [ + "OverrideConfig", # deprecated, use `TOTPOverrideConfig` instead + "TOTPConfig", + "TOTPOverrideConfig", + "TOTPRecipe", + "init", +] diff --git a/supertokens_python/recipe/totp/interfaces.py b/supertokens_python/recipe/totp/interfaces.py index bd7f36249..0cafe03a6 100644 --- a/supertokens_python/recipe/totp/interfaces.py +++ b/supertokens_python/recipe/totp/interfaces.py @@ -14,9 +14,11 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface + if TYPE_CHECKING: from supertokens_python import AppInfo from supertokens_python.framework import BaseRequest, BaseResponse @@ -30,8 +32,8 @@ InvalidTOTPError, LimitReachedError, ListDevicesOkResult, + NormalisedTOTPConfig, RemoveDeviceOkResult, - TOTPNormalisedConfig, UnknownDeviceError, UnknownUserIdError, UpdateDeviceOkResult, @@ -42,7 +44,7 @@ ) -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def get_user_identifier_info_for_user_id( self, user_id: str, user_context: Dict[str, Any] @@ -129,7 +131,7 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: TOTPNormalisedConfig, + config: NormalisedTOTPConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, recipe_instance: TOTPRecipe, @@ -143,7 +145,7 @@ def __init__( self.recipe_instance = recipe_instance -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): def __init__(self): self.disable_create_device_post = False self.disable_list_devices_get = False diff --git a/supertokens_python/recipe/totp/recipe.py b/supertokens_python/recipe/totp/recipe.py index 638aac124..855e615b9 100644 --- a/supertokens_python/recipe/totp/recipe.py +++ b/supertokens_python/recipe/totp/recipe.py @@ -72,17 +72,13 @@ def __init__( recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self.config ) - self.recipe_implementation: RecipeInterface = ( + self.recipe_implementation: RecipeInterface = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation: APIInterface = ( + self.api_implementation: APIInterface = self.config.override.apis( api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) ) def callback(): @@ -205,12 +201,21 @@ def get_all_cors_headers(self) -> List[str]: def init( config: Union[TOTPConfig, None] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + if config is None: + config = TOTPConfig() + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if TOTPRecipe.__instance is None: TOTPRecipe.__instance = TOTPRecipe( - TOTPRecipe.recipe_id, - app_info, - config, + recipe_id=TOTPRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=TOTPRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return TOTPRecipe.__instance raise Exception( diff --git a/supertokens_python/recipe/totp/recipe_implementation.py b/supertokens_python/recipe/totp/recipe_implementation.py index 93ce02a2a..3a16f399b 100644 --- a/supertokens_python/recipe/totp/recipe_implementation.py +++ b/supertokens_python/recipe/totp/recipe_implementation.py @@ -29,8 +29,8 @@ InvalidTOTPError, LimitReachedError, ListDevicesOkResult, + NormalisedTOTPConfig, RemoveDeviceOkResult, - TOTPNormalisedConfig, UnknownDeviceError, UnknownUserIdError, UpdateDeviceOkResult, @@ -48,7 +48,7 @@ class RecipeImplementation(RecipeInterface): def __init__( self, querier: Querier, - config: TOTPNormalisedConfig, + config: NormalisedTOTPConfig, ): super().__init__() self.querier = querier diff --git a/supertokens_python/recipe/totp/types.py b/supertokens_python/recipe/totp/types.py index d863696ba..3ace4cb78 100644 --- a/supertokens_python/recipe/totp/types.py +++ b/supertokens_python/recipe/totp/types.py @@ -12,10 +12,16 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional from typing_extensions import Literal +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) from supertokens_python.types.response import APIResponse from .interfaces import APIInterface, RecipeInterface @@ -177,39 +183,21 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status} -class OverrideConfig: - def __init__( - self, - functions: Optional[Callable[[RecipeInterface], RecipeInterface]] = None, - apis: Optional[Callable[[APIInterface], APIInterface]] = None, - ): - self.functions = functions - self.apis = apis +TOTPOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedTOTPOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +OverrideConfig = TOTPOverrideConfig +"""Deprecated: Use `TOTPOverrideConfig` instead.""" -class TOTPConfig: - def __init__( - self, - issuer: Optional[str] = None, - default_skew: Optional[int] = None, - default_period: Optional[int] = None, - override: Optional[OverrideConfig] = None, - ): - self.issuer = issuer - self.default_skew = default_skew - self.default_period = default_period - self.override = override +class TOTPConfig(BaseConfig[RecipeInterface, APIInterface]): + issuer: Optional[str] = None + default_skew: Optional[int] = None + default_period: Optional[int] = None -class TOTPNormalisedConfig: - def __init__( - self, - issuer: str, - default_skew: int, - default_period: int, - override: OverrideConfig, - ): - self.issuer = issuer - self.default_skew = default_skew - self.default_period = default_period - self.override = override +class NormalisedTOTPConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + issuer: str + default_skew: int + default_period: int diff --git a/supertokens_python/recipe/totp/utils.py b/supertokens_python/recipe/totp/utils.py index 1e3781335..4c49f662c 100644 --- a/supertokens_python/recipe/totp/utils.py +++ b/supertokens_python/recipe/totp/utils.py @@ -16,12 +16,16 @@ from supertokens_python import AppInfo -from .types import OverrideConfig, TOTPConfig, TOTPNormalisedConfig +from .types import ( + NormalisedTOTPConfig, + NormalisedTOTPOverrideConfig, + TOTPConfig, +) def validate_and_normalise_user_input( app_info: AppInfo, config: Union[TOTPConfig, None] -) -> TOTPNormalisedConfig: +) -> NormalisedTOTPConfig: if config is None: config = TOTPConfig() @@ -29,17 +33,13 @@ def validate_and_normalise_user_input( default_skew = config.default_skew if config.default_skew is not None else 1 default_period = config.default_period if config.default_period is not None else 30 - if config.override is None: - override = OverrideConfig() - else: - override = OverrideConfig( - functions=config.override.functions, - apis=config.override.apis, - ) + override_config = NormalisedTOTPOverrideConfig.from_input_config( + override_config=config.override + ) - return TOTPNormalisedConfig( + return NormalisedTOTPConfig( issuer=issuer, default_skew=default_skew, default_period=default_period, - override=override, + override=override_config, ) diff --git a/supertokens_python/recipe/usermetadata/__init__.py b/supertokens_python/recipe/usermetadata/__init__.py index e5bdd43ed..c760e225b 100644 --- a/supertokens_python/recipe/usermetadata/__init__.py +++ b/supertokens_python/recipe/usermetadata/__init__.py @@ -13,18 +13,24 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Union -from . import utils from .recipe import UserMetadataRecipe +from .utils import InputOverrideConfig, UserMetadataOverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( - override: Union[utils.InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: + override: Union[UserMetadataOverrideConfig, None] = None, +) -> RecipeInit: return UserMetadataRecipe.init(override) + + +__all__ = [ + "InputOverrideConfig", # deprecated, use `UserMetadataOverrideConfig` instead + "UserMetadataOverrideConfig", + "UserMetadataRecipe", + "init", +] diff --git a/supertokens_python/recipe/usermetadata/interfaces.py b/supertokens_python/recipe/usermetadata/interfaces.py index 05c042650..601a65278 100644 --- a/supertokens_python/recipe/usermetadata/interfaces.py +++ b/supertokens_python/recipe/usermetadata/interfaces.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface + class MetadataResult(ABC): def __init__(self, metadata: Dict[str, Any]): @@ -11,7 +13,7 @@ class ClearUserMetadataResult: pass -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def get_user_metadata( self, user_id: str, user_context: Dict[str, Any] @@ -34,5 +36,5 @@ async def clear_user_metadata( pass -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): pass diff --git a/supertokens_python/recipe/usermetadata/recipe.py b/supertokens_python/recipe/usermetadata/recipe.py index 4eee4b066..fc5437b18 100644 --- a/supertokens_python/recipe/usermetadata/recipe.py +++ b/supertokens_python/recipe/usermetadata/recipe.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo -from .utils import InputOverrideConfig +from .utils import UserMetadataConfig, UserMetadataOverrideConfig class UserMetadataRecipe(RecipeModule): @@ -46,15 +46,15 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - override: Union[InputOverrideConfig, None] = None, + config: UserMetadataConfig, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input(self, app_info, override) + self.config = validate_and_normalise_user_input( + _recipe=self, _app_info=app_info, input_config=config + ) recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: @@ -90,11 +90,21 @@ def get_all_cors_headers(self) -> List[str]: return [] @staticmethod - def init(override: Union[InputOverrideConfig, None] = None): - def func(app_info: AppInfo): + def init(override: Union[UserMetadataOverrideConfig, None] = None): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = UserMetadataConfig(override=override) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if UserMetadataRecipe.__instance is None: UserMetadataRecipe.__instance = UserMetadataRecipe( - UserMetadataRecipe.recipe_id, app_info, override + recipe_id=UserMetadataRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=UserMetadataRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return UserMetadataRecipe.__instance raise Exception( diff --git a/supertokens_python/recipe/usermetadata/utils.py b/supertokens_python/recipe/usermetadata/utils.py index 7e059cf74..70d0a979a 100644 --- a/supertokens_python/recipe/usermetadata/utils.py +++ b/supertokens_python/recipe/usermetadata/utils.py @@ -14,42 +14,47 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING from supertokens_python.recipe.usermetadata.interfaces import ( APIInterface, RecipeInterface, ) +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) if TYPE_CHECKING: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.supertokens import AppInfo -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +UserMetadataOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedUserMetadataOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = UserMetadataOverrideConfig +"""Deprecated: Use `UserMetadataOverrideConfig` instead.""" + +class UserMetadataConfig(BaseConfig[RecipeInterface, APIInterface]): ... -class UserMetadataConfig: - def __init__(self, override: InputOverrideConfig) -> None: - self.override = override + +class NormalisedUserMetadataConfig( + BaseNormalisedConfig[RecipeInterface, APIInterface] +): ... def validate_and_normalise_user_input( _recipe: UserMetadataRecipe, _app_info: AppInfo, - override: Union[InputOverrideConfig, None] = None, -) -> UserMetadataConfig: - if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - raise ValueError("override must be an instance of InputOverrideConfig or None") - - if override is None: - override = InputOverrideConfig() + input_config: UserMetadataConfig, +) -> NormalisedUserMetadataConfig: + override_config = NormalisedUserMetadataOverrideConfig.from_input_config( + override_config=input_config.override + ) - return UserMetadataConfig(override=override) + return NormalisedUserMetadataConfig(override=override_config) diff --git a/supertokens_python/recipe/userroles/__init__.py b/supertokens_python/recipe/userroles/__init__.py index f0e082669..7356cc507 100644 --- a/supertokens_python/recipe/userroles/__init__.py +++ b/supertokens_python/recipe/userroles/__init__.py @@ -13,27 +13,32 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional, Union -from . import recipe, utils -from .recipe import UserRolesRecipe - -PermissionClaim = recipe.PermissionClaim -UserRoleClaim = recipe.UserRoleClaim +from .recipe import PermissionClaim, UserRoleClaim, UserRolesRecipe +from .utils import InputOverrideConfig, UserRolesOverrideConfig if TYPE_CHECKING: - from supertokens_python.supertokens import AppInfo - - from ...recipe_module import RecipeModule + from supertokens_python.supertokens import RecipeInit def init( skip_adding_roles_to_access_token: Optional[bool] = None, skip_adding_permissions_to_access_token: Optional[bool] = None, - override: Union[utils.InputOverrideConfig, None] = None, -) -> Callable[[AppInfo], RecipeModule]: + override: Union[UserRolesOverrideConfig, None] = None, +) -> RecipeInit: return UserRolesRecipe.init( skip_adding_roles_to_access_token, skip_adding_permissions_to_access_token, override, ) + + +__all__ = [ + "InputOverrideConfig", # deprecated, use `UserRolesOverrideConfig` instead + "PermissionClaim", + "UserRoleClaim", + "UserRolesOverrideConfig", + "UserRolesRecipe", + "init", +] diff --git a/supertokens_python/recipe/userroles/interfaces.py b/supertokens_python/recipe/userroles/interfaces.py index 1ffc97410..e6b7fff81 100644 --- a/supertokens_python/recipe/userroles/interfaces.py +++ b/supertokens_python/recipe/userroles/interfaces.py @@ -1,6 +1,8 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, Dict, List, Union +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface + class AddRoleToUserOkResult: def __init__(self, did_user_already_have_role: bool): @@ -55,7 +57,7 @@ def __init__(self, roles: List[str]): self.roles = roles -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def add_role_to_user( self, @@ -123,5 +125,5 @@ async def get_all_roles(self, user_context: Dict[str, Any]) -> GetAllRolesOkResu pass -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): pass diff --git a/supertokens_python/recipe/userroles/recipe.py b/supertokens_python/recipe/userroles/recipe.py index 7a2d830d5..a8971d747 100644 --- a/supertokens_python/recipe/userroles/recipe.py +++ b/supertokens_python/recipe/userroles/recipe.py @@ -35,7 +35,7 @@ from ..session.claim_base_classes.primitive_array_claim import PrimitiveArrayClaim from .exceptions import SuperTokensUserRolesError from .interfaces import GetPermissionsForRoleOkResult, UnknownRoleError -from .utils import InputOverrideConfig +from .utils import UserRolesConfig, UserRolesOverrideConfig class UserRolesRecipe(RecipeModule): @@ -46,25 +46,19 @@ def __init__( self, recipe_id: str, app_info: AppInfo, - skip_adding_roles_to_access_token: Optional[bool] = None, - skip_adding_permissions_to_access_token: Optional[bool] = None, - override: Union[InputOverrideConfig, None] = None, + config: UserRolesConfig, ): from ..oauth2provider.recipe import OAuth2ProviderRecipe super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( - self, - app_info, - skip_adding_roles_to_access_token, - skip_adding_permissions_to_access_token, - override, + _recipe=self, + _app_info=app_info, + config=config, ) recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) def callback(): @@ -216,16 +210,26 @@ def get_all_cors_headers(self) -> List[str]: def init( skip_adding_roles_to_access_token: Optional[bool] = None, skip_adding_permissions_to_access_token: Optional[bool] = None, - override: Union[InputOverrideConfig, None] = None, + override: Union[UserRolesOverrideConfig, None] = None, ): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + config = UserRolesConfig( + skip_adding_roles_to_access_token=skip_adding_roles_to_access_token, + skip_adding_permissions_to_access_token=skip_adding_permissions_to_access_token, + override=override, + ) + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if UserRolesRecipe.__instance is None: UserRolesRecipe.__instance = UserRolesRecipe( - UserRolesRecipe.recipe_id, - app_info, - skip_adding_roles_to_access_token, - skip_adding_permissions_to_access_token, - override, + recipe_id=UserRolesRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=UserRolesRecipe.recipe_id, + config=config, + plugins=plugins, + ), ) return UserRolesRecipe.__instance raise Exception( diff --git a/supertokens_python/recipe/userroles/utils.py b/supertokens_python/recipe/userroles/utils.py index c94ab5bed..81eb6bdbe 100644 --- a/supertokens_python/recipe/userroles/utils.py +++ b/supertokens_python/recipe/userroles/utils.py @@ -14,59 +14,60 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional from supertokens_python.recipe.userroles.interfaces import APIInterface, RecipeInterface from supertokens_python.supertokens import AppInfo +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) if TYPE_CHECKING: from supertokens_python.recipe.userroles.recipe import UserRolesRecipe -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis +UserRolesOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedUserRolesOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] +InputOverrideConfig = UserRolesOverrideConfig +"""Deprecated: Use `UserRolesOverrideConfig` instead.""" -class UserRolesConfig: - def __init__( - self, - skip_adding_roles_to_access_token: bool, - skip_adding_permissions_to_access_token: bool, - override: InputOverrideConfig, - ) -> None: - self.skip_adding_roles_to_access_token = skip_adding_roles_to_access_token - self.skip_adding_permissions_to_access_token = ( - skip_adding_permissions_to_access_token - ) - self.override = override +class UserRolesConfig(BaseConfig[RecipeInterface, APIInterface]): + skip_adding_roles_to_access_token: Optional[bool] = None + skip_adding_permissions_to_access_token: Optional[bool] = None + + +class NormalisedUserRolesConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + skip_adding_roles_to_access_token: bool + skip_adding_permissions_to_access_token: bool def validate_and_normalise_user_input( _recipe: UserRolesRecipe, _app_info: AppInfo, - skip_adding_roles_to_access_token: Optional[bool] = None, - skip_adding_permissions_to_access_token: Optional[bool] = None, - override: Union[InputOverrideConfig, None] = None, -) -> UserRolesConfig: - if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore - raise ValueError("override must be an instance of InputOverrideConfig or None") - - if override is None: - override = InputOverrideConfig() + config: UserRolesConfig, +) -> NormalisedUserRolesConfig: + override_config = NormalisedUserRolesOverrideConfig.from_input_config( + override_config=config.override + ) + skip_adding_roles_to_access_token = config.skip_adding_roles_to_access_token if skip_adding_roles_to_access_token is None: skip_adding_roles_to_access_token = False + + skip_adding_permissions_to_access_token = ( + config.skip_adding_permissions_to_access_token + ) if skip_adding_permissions_to_access_token is None: skip_adding_permissions_to_access_token = False - return UserRolesConfig( + return NormalisedUserRolesConfig( skip_adding_roles_to_access_token=skip_adding_roles_to_access_token, skip_adding_permissions_to_access_token=skip_adding_permissions_to_access_token, - override=override, + override=override_config, ) diff --git a/supertokens_python/recipe/webauthn/__init__.py b/supertokens_python/recipe/webauthn/__init__.py index 8f48fba47..f19c9098b 100644 --- a/supertokens_python/recipe/webauthn/__init__.py +++ b/supertokens_python/recipe/webauthn/__init__.py @@ -39,8 +39,8 @@ from supertokens_python.recipe.webauthn.recipe import WebauthnRecipe from supertokens_python.recipe.webauthn.types.config import ( NormalisedWebauthnConfig, - OverrideConfig, WebauthnConfig, + WebauthnOverrideConfig, ) # Some Pydantic models need a rebuild to resolve ForwardRefs @@ -60,11 +60,10 @@ def init(config: Optional[WebauthnConfig] = None): __all__ = [ - "init", "APIInterface", "RecipeInterface", - "OverrideConfig", "WebauthnConfig", + "WebauthnOverrideConfig", "WebauthnRecipe", "consume_recover_account_token", "create_recover_account_link", @@ -72,6 +71,7 @@ def init(config: Optional[WebauthnConfig] = None): "get_credential", "get_generated_options", "get_user_from_recover_account_token", + "init", "list_credentials", "recover_account", "register_credential", diff --git a/supertokens_python/recipe/webauthn/interfaces/api.py b/supertokens_python/recipe/webauthn/interfaces/api.py index 230a9086f..8e52c15f2 100644 --- a/supertokens_python/recipe/webauthn/interfaces/api.py +++ b/supertokens_python/recipe/webauthn/interfaces/api.py @@ -12,8 +12,8 @@ # License for the specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod -from typing import List, Literal, Optional, TypedDict, Union +from abc import abstractmethod +from typing import TYPE_CHECKING, List, Literal, Optional, TypedDict, Union from typing_extensions import NotRequired, Unpack @@ -36,10 +36,10 @@ SignInOptionsErrorResponse, UserVerification, ) -from supertokens_python.recipe.webauthn.types.config import NormalisedWebauthnConfig from supertokens_python.supertokens import AppInfo from supertokens_python.types import RecipeUserId, User from supertokens_python.types.base import UserContext +from supertokens_python.types.recipe import BaseAPIInterface from supertokens_python.types.response import ( CamelCaseBaseModel, GeneralErrorResponse, @@ -47,6 +47,9 @@ StatusReasonResponseBaseModel, ) +if TYPE_CHECKING: + from supertokens_python.recipe.webauthn.types.config import NormalisedWebauthnConfig + class SignUpNotAllowedErrorResponse( StatusReasonResponseBaseModel[Literal["SIGN_UP_NOT_ALLOWED"], str] @@ -93,7 +96,7 @@ class TypeWebauthnRecoverAccountEmailDeliveryInput(CamelCaseBaseModel): class APIOptions(CamelCaseBaseModel): recipe_implementation: RecipeInterface app_info: AppInfo - config: NormalisedWebauthnConfig + config: "NormalisedWebauthnConfig" recipe_id: str req: BaseRequest res: BaseResponse @@ -219,7 +222,7 @@ class RegisterOptionsPOSTKwargsInput(TypedDict): email: NotRequired[str] -class APIInterface(ABC): +class APIInterface(BaseAPIInterface): disable_register_options_post: bool = False disable_sign_in_options_post: bool = False disable_sign_up_post: bool = False diff --git a/supertokens_python/recipe/webauthn/interfaces/recipe.py b/supertokens_python/recipe/webauthn/interfaces/recipe.py index 67dcbe012..8f22da760 100644 --- a/supertokens_python/recipe/webauthn/interfaces/recipe.py +++ b/supertokens_python/recipe/webauthn/interfaces/recipe.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import ( Any, Dict, @@ -30,6 +30,7 @@ from supertokens_python.types import RecipeUserId, User from supertokens_python.types.auth_utils import LinkingToSessionUserFailedError from supertokens_python.types.base import UserContext +from supertokens_python.types.recipe import BaseRecipeInterface from supertokens_python.types.response import ( CamelCaseBaseModel, OkResponseBaseModel, @@ -440,7 +441,7 @@ class RegisterOptionsKwargsInput(TypedDict): email: NotRequired[str] -class RecipeInterface(ABC): +class RecipeInterface(BaseRecipeInterface): @abstractmethod async def register_options( self, diff --git a/supertokens_python/recipe/webauthn/recipe.py b/supertokens_python/recipe/webauthn/recipe.py index 02465285b..455319e96 100644 --- a/supertokens_python/recipe/webauthn/recipe.py +++ b/supertokens_python/recipe/webauthn/recipe.py @@ -105,18 +105,12 @@ def __init__( querier=querier, config=self.config, ) - self.recipe_implementation = ( + self.recipe_implementation = self.config.override.functions( recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) ) api_implementation = APIImplementation() - self.api_implementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) - ) + self.api_implementation = self.config.override.apis(api_implementation) if ingredients.email_delivery is None: self.email_delivery = EmailDeliveryIngredient( @@ -301,12 +295,21 @@ def get_instance_optional() -> Optional["WebauthnRecipe"]: @staticmethod def init(config: Optional[WebauthnConfig]): - def func(app_info: AppInfo): + from supertokens_python.plugins import OverrideMap, apply_plugins + + if config is None: + config = WebauthnConfig() + + def func(app_info: AppInfo, plugins: List[OverrideMap]): if WebauthnRecipe.__instance is None: WebauthnRecipe.__instance = WebauthnRecipe( recipe_id=WebauthnRecipe.recipe_id, app_info=app_info, - config=config, + config=apply_plugins( + recipe_id=WebauthnRecipe.recipe_id, + config=config, + plugins=plugins, + ), ingredients=WebauthnIngredients(email_delivery=None), ) return WebauthnRecipe.__instance diff --git a/supertokens_python/recipe/webauthn/types/config.py b/supertokens_python/recipe/webauthn/types/config.py index f0c36202e..d2132cf1d 100644 --- a/supertokens_python/recipe/webauthn/types/config.py +++ b/supertokens_python/recipe/webauthn/types/config.py @@ -14,8 +14,7 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Protocol, TypeVar, Union, runtime_checkable +from typing import Optional, Protocol, TypeVar, Union, runtime_checkable from supertokens_python.framework import BaseRequest from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient @@ -23,14 +22,19 @@ EmailDeliveryConfig, EmailDeliveryConfigWithService, ) +from supertokens_python.recipe.webauthn.interfaces.api import ( + APIInterface, + TypeWebauthnEmailDeliveryInput, +) +from supertokens_python.recipe.webauthn.interfaces.recipe import RecipeInterface from supertokens_python.types.base import UserContext - -if TYPE_CHECKING: - from supertokens_python.recipe.webauthn.interfaces.api import ( - APIInterface, - TypeWebauthnEmailDeliveryInput, - ) - from supertokens_python.recipe.webauthn.interfaces.recipe import RecipeInterface +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, +) +from supertokens_python.types.response import CamelCaseBaseModel InterfaceType = TypeVar("InterfaceType") """Generic Type for use in `InterfaceOverride`""" @@ -179,39 +183,29 @@ def __call__( ) -> InterfaceType: ... -# NOTE: Using dataclasses for these classes since validation is not required -@dataclass -class OverrideConfig: - """ - `WebauthnConfig.override` - """ - - functions: Optional[InterfaceOverride[RecipeInterface]] = None - apis: Optional[InterfaceOverride[APIInterface]] = None +WebauthnOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedWebauthnOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] -@dataclass -class WebauthnConfig: +class WebauthnConfig(BaseConfig[RecipeInterface, APIInterface]): get_relying_party_id: Optional[Union[str, GetRelyingPartyId]] = None get_relying_party_name: Optional[Union[str, GetRelyingPartyName]] = None get_origin: Optional[GetOrigin] = None email_delivery: Optional[EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]] = None validate_email_address: Optional[ValidateEmailAddress] = None - override: Optional[OverrideConfig] = None -@dataclass -class NormalisedWebauthnConfig: +class NormalisedWebauthnConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): get_relying_party_id: NormalisedGetRelyingPartyId get_relying_party_name: NormalisedGetRelyingPartyName get_origin: NormalisedGetOrigin get_email_delivery_config: NormalisedGetEmailDeliveryConfig validate_email_address: NormalisedValidateEmailAddress - override: OverrideConfig -@dataclass -class WebauthnIngredients: +class WebauthnIngredients(CamelCaseBaseModel): email_delivery: Optional[ EmailDeliveryIngredient[TypeWebauthnEmailDeliveryInput] ] = None diff --git a/supertokens_python/recipe/webauthn/utils.py b/supertokens_python/recipe/webauthn/utils.py index c3a35abfb..4e6b543b6 100644 --- a/supertokens_python/recipe/webauthn/utils.py +++ b/supertokens_python/recipe/webauthn/utils.py @@ -35,7 +35,7 @@ NormalisedGetRelyingPartyName, NormalisedValidateEmailAddress, NormalisedWebauthnConfig, - OverrideConfig, + NormalisedWebauthnOverrideConfig, ValidateEmailAddress, WebauthnConfig, ) @@ -60,13 +60,9 @@ def validate_and_normalise_user_input( config.validate_email_address ) - if config.override is None: - override = OverrideConfig() - else: - override = OverrideConfig( - functions=config.override.functions, - apis=config.override.apis, - ) + override_config = NormalisedWebauthnOverrideConfig.from_input_config( + override_config=config.override + ) def get_email_delivery_config() -> EmailDeliveryConfigWithService[ TypeWebauthnEmailDeliveryInput @@ -93,7 +89,7 @@ def get_email_delivery_config() -> EmailDeliveryConfigWithService[ get_origin=get_origin, get_email_delivery_config=get_email_delivery_config, validate_email_address=validate_email_address, - override=override, + override=override_config, ) diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index abcf3aed0..024998576 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -15,7 +15,18 @@ from __future__ import annotations from os import environ -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Protocol, + Set, + Tuple, + Union, +) from typing_extensions import Literal @@ -24,9 +35,17 @@ get_maybe_none_as_str, log_debug_message, ) +from supertokens_python.plugins import ( + OverrideMap, + PluginRouteHandler, + PluginRouteHandlerWithPluginId, + SuperTokensPlugin, + SuperTokensPublicPlugin, +) +from supertokens_python.types.response import CamelCaseBaseModel from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT -from .exceptions import SuperTokensError +from .exceptions import PluginError, SuperTokensError from .interfaces import ( CreateUserIdMappingOkResult, DeleteUserIdMappingOkResult, @@ -199,18 +218,119 @@ def manage_session_post_response( mutator(response, user_context) +class RecipeInit(Protocol): + def __call__( + self, app_info: AppInfo, plugins: List[OverrideMap] + ) -> RecipeModule: ... + + +class SupertokensExperimentalConfig(CamelCaseBaseModel): + plugins: Optional[List["SuperTokensPlugin"]] = None + + +class _BaseSupertokensPublicConfig(CamelCaseBaseModel): + """ + Public properties received as input to the `Supertokens.init` function. + """ + + supertokens_config: SupertokensConfig + + +class SupertokensPublicConfig(_BaseSupertokensPublicConfig): + """ + Public properties received as input to the `Supertokens.init` function. + """ + + pass + + +class _BaseSupertokensInputConfig(_BaseSupertokensPublicConfig): + framework: Literal["fastapi", "flask", "django"] + mode: Optional[Literal["asgi", "wsgi"]] + telemetry: Optional[bool] + debug: Optional[bool] + recipe_list: List[Callable[[AppInfo, List["OverrideMap"]], "RecipeModule"]] + experimental: Optional[SupertokensExperimentalConfig] = None + + +class SupertokensInputConfigWithNormalisedAppInfo(_BaseSupertokensInputConfig): + app_info: AppInfo + + +class SupertokensInputConfig(_BaseSupertokensInputConfig): + """ + Various properties received as input to the `Supertokens.init` function. + """ + + app_info: InputAppInfo + + def to_public_config(self) -> SupertokensPublicConfig: + return SupertokensPublicConfig(supertokens_config=self.supertokens_config) + + @classmethod + def from_public_and_input_config( + cls, + input_config: "SupertokensInputConfig", + public_config: SupertokensPublicConfig, + ) -> "SupertokensInputConfig": + return cls( + **{ + **input_config.model_dump(), + **public_config.model_dump(), + } + ) + + @classmethod + def from_public_config( + cls, + config: SupertokensPublicConfig, + app_info: InputAppInfo, + framework: Literal["fastapi", "flask", "django"], + mode: Optional[Literal["asgi", "wsgi"]], + telemetry: Optional[bool], + debug: Optional[bool], + recipe_list: List[Callable[[AppInfo, List["OverrideMap"]], "RecipeModule"]], + experimental: Optional[SupertokensExperimentalConfig], + ) -> "SupertokensInputConfig": + return cls( + app_info=app_info, + framework=framework, + supertokens_config=config.supertokens_config, + mode=mode, + telemetry=telemetry, + debug=debug, + recipe_list=recipe_list, + experimental=experimental, + ) + + class Supertokens: - __instance = None + __instance: Optional[Supertokens] = None + + recipe_modules: List[RecipeModule] + + app_info: AppInfo + + supertokens_config: SupertokensConfig + + _telemetry_status: str + + telemetry: bool + + plugin_route_handlers: List[PluginRouteHandlerWithPluginId] + + plugin_list: List[SuperTokensPublicPlugin] def __init__( self, app_info: InputAppInfo, framework: Literal["fastapi", "flask", "django"], supertokens_config: SupertokensConfig, - recipe_list: List[Callable[[AppInfo], RecipeModule]], + recipe_list: List[RecipeInit], mode: Optional[Literal["asgi", "wsgi"]], telemetry: Optional[bool], debug: Optional[bool], + experimental: Optional[SupertokensExperimentalConfig] = None, ): if not isinstance(app_info, InputAppInfo): # type: ignore raise ValueError("app_info must be an instance of InputAppInfo") @@ -226,7 +346,54 @@ def __init__( mode, app_info.origin, ) - self.supertokens_config = supertokens_config + + input_config = SupertokensInputConfig( + app_info=app_info, + framework=framework, + supertokens_config=supertokens_config, + recipe_list=recipe_list, + mode=mode, + telemetry=telemetry, + debug=debug, + experimental=experimental, + ) + input_public_config = input_config.to_public_config() + # Use the input public config by default if no plugins provided + processed_public_config: SupertokensPublicConfig = input_public_config + + self.plugin_route_handlers = [] + override_maps: List[OverrideMap] = [] + + if experimental is not None and experimental.plugins is not None: + from supertokens_python.plugins import load_plugins + + load_plugins_result = load_plugins( + plugins=experimental.plugins, + public_config=input_public_config, + ) + + override_maps = load_plugins_result.override_maps + processed_public_config = load_plugins_result.public_config + self.plugin_list = load_plugins_result.processed_plugins + self.plugin_route_handlers = load_plugins_result.plugin_route_handlers + + config = SupertokensInputConfig.from_public_and_input_config( + input_config=input_config, + public_config=processed_public_config, + ) + + self.app_info = AppInfo( + config.app_info.app_name, + config.app_info.api_domain, + config.app_info.website_domain, + config.framework, + config.app_info.api_gateway_path, + config.app_info.api_base_path, + config.app_info.website_base_path, + config.mode, + config.app_info.origin, + ) + self.supertokens_config = config.supertokens_config if debug is True: enable_debug_logging() self._telemetry_status: str = "NONE" @@ -234,7 +401,7 @@ def __init__( "Started SuperTokens with debug logging (supertokens.init called)" ) log_debug_message("app_info: %s", self.app_info.toJSON()) - log_debug_message("framework: %s", framework) + log_debug_message("framework: %s", config.framework) hosts = list( map( lambda h: Host( @@ -262,8 +429,11 @@ def __init__( oauth2_found = False openid_found = False jwt_found = False + account_linking_found = False - def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: + def make_recipe( + recipe: Callable[[AppInfo, List[OverrideMap]], RecipeModule], + ) -> RecipeModule: nonlocal \ multitenancy_found, \ totp_found, \ @@ -271,8 +441,10 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: multi_factor_auth_found, \ oauth2_found, \ openid_found, \ - jwt_found - recipe_module = recipe(self.app_info) + jwt_found, \ + account_linking_found + + recipe_module = recipe(self.app_info, override_maps) if recipe_module.get_recipe_id() == "multitenancy": multitenancy_found = True elif recipe_module.get_recipe_id() == "usermetadata": @@ -287,24 +459,39 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: openid_found = True elif recipe_module.get_recipe_id() == "jwt": jwt_found = True + elif recipe_module.get_recipe_id() == "accountlinking": + account_linking_found = True return recipe_module self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) + if not account_linking_found: + from supertokens_python.recipe.accountlinking.recipe import ( + AccountLinkingRecipe, + ) + + self.recipe_modules.append( + AccountLinkingRecipe.init()(self.app_info, override_maps) + ) + if not jwt_found: from supertokens_python.recipe.jwt.recipe import JWTRecipe - self.recipe_modules.append(JWTRecipe.init()(self.app_info)) + self.recipe_modules.append(JWTRecipe.init()(self.app_info, override_maps)) if not openid_found: from supertokens_python.recipe.openid.recipe import OpenIdRecipe - self.recipe_modules.append(OpenIdRecipe.init()(self.app_info)) + self.recipe_modules.append( + OpenIdRecipe.init()(self.app_info, override_maps) + ) if not multitenancy_found: from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe - self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info)) + self.recipe_modules.append( + MultitenancyRecipe.init()(self.app_info, override_maps) + ) if totp_found and not multi_factor_auth_found: raise Exception("Please initialize the MultiFactorAuth recipe to use TOTP.") @@ -312,18 +499,22 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: if not user_metadata_found: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe - self.recipe_modules.append(UserMetadataRecipe.init()(self.app_info)) + self.recipe_modules.append( + UserMetadataRecipe.init()(self.app_info, override_maps) + ) if not oauth2_found: from supertokens_python.recipe.oauth2provider.recipe import ( OAuth2ProviderRecipe, ) - self.recipe_modules.append(OAuth2ProviderRecipe.init()(self.app_info)) + self.recipe_modules.append( + OAuth2ProviderRecipe.init()(self.app_info, override_maps) + ) self.telemetry = ( - telemetry - if telemetry is not None + config.telemetry + if config.telemetry is not None else (environ.get("TEST_MODE") != "testing") ) @@ -332,10 +523,11 @@ def init( app_info: InputAppInfo, framework: Literal["fastapi", "flask", "django"], supertokens_config: SupertokensConfig, - recipe_list: List[Callable[[AppInfo], RecipeModule]], + recipe_list: List[RecipeInit], mode: Optional[Literal["asgi", "wsgi"]], telemetry: Optional[bool], debug: Optional[bool], + experimental: Optional[SupertokensExperimentalConfig] = None, ): if Supertokens.__instance is None: Supertokens.__instance = Supertokens( @@ -346,6 +538,7 @@ def init( mode, telemetry, debug, + experimental=experimental, ) PostSTInitCallbacks.run_post_init_callbacks() @@ -543,12 +736,58 @@ async def update_or_delete_user_id_mapping_info( async def middleware( self, request: BaseRequest, response: BaseResponse, user_context: Dict[str, Any] ) -> Union[BaseResponse, None]: + from supertokens_python.recipe.session.recipe import SessionRecipe + log_debug_message("middleware: Started") path = Supertokens.get_instance().app_info.api_gateway_path.append( NormalisedURLPath(request.get_path()) ) method = normalise_http_method(request.method()) + handler_from_apis: Optional[PluginRouteHandler] = None + for handler in self.plugin_route_handlers: + if ( + handler.path == path.get_as_string_dangerous() + and handler.method == method + ): + log_debug_message( + "middleware: Found matching plugin route handler for path: %s and method: %s", + path.get_as_string_dangerous(), + method, + ) + handler_from_apis = handler + break + + if handler_from_apis is not None: + session: Optional[SessionContainer] = None + if handler_from_apis.verify_session_options is not None: + verify_session_options = handler_from_apis.verify_session_options + session = await SessionRecipe.get_instance().verify_session( + request=request, + user_context=user_context, + anti_csrf_check=verify_session_options.anti_csrf_check, + session_required=verify_session_options.session_required, + check_database=verify_session_options.check_database, + override_global_claim_validators=verify_session_options.override_global_claim_validators, + ) + + log_debug_message( + f"middleware: Request being handled by plugin `{handler_from_apis.plugin_id}`" + ) + try: + return await handler_from_apis.handler( + request=request, + response=response, + session=session, + user_context=user_context, + ) + except PluginError as err: + log_debug_message( + f"middleware: Error from plugin `{handler_from_apis.plugin_id}`: {str(err)}. " + "Transforming to SuperTokensError." + ) + raise err + if not path.startswith(Supertokens.get_instance().app_info.api_base_path): log_debug_message( "middleware: Not handling because request path did not start with api base path. Request path: %s", @@ -677,7 +916,7 @@ async def handle_supertokens_error( if isinstance(err, GeneralError): raise err - if isinstance(err, BadInputError): + if isinstance(err, (BadInputError, PluginError)): log_debug_message("errorHandler: Sending 400 status code response") return send_non_200_response_with_message(str(err), 400, response) @@ -709,6 +948,18 @@ def get_request_from_user_context( return user_context.get("_default", {}).get("request") + @staticmethod + def is_recipe_initialized(recipe_id: str) -> bool: + """ + Check if a recipe is initialized. + :param recipe_id: The ID of the recipe to check. + :return: Whether the recipe is initialized. + """ + return any( + recipe.get_recipe_id() == recipe_id + for recipe in Supertokens.get_instance().recipe_modules + ) + def get_request_from_user_context( user_context: Optional[Dict[str, Any]], diff --git a/supertokens_python/test.py b/supertokens_python/test.py new file mode 100644 index 000000000..d8b5e8c71 --- /dev/null +++ b/supertokens_python/test.py @@ -0,0 +1,14 @@ +from typing import cast + +from django.http import HttpRequest + +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.framework.django.asyncio import verify_session + + +# highlight-start +@verify_session() +async def some_api(request: HttpRequest): + session: SessionContainer = cast(SessionContainer, request.supertokens) # type: ignore This will delete the session from the db and from the frontend (cookies) + # highlight-end + await session.revoke_session() diff --git a/supertokens_python/types/__init__.py b/supertokens_python/types/__init__.py index 72f7cda5d..7359f2069 100644 --- a/supertokens_python/types/__init__.py +++ b/supertokens_python/types/__init__.py @@ -27,8 +27,8 @@ __all__ = ( "APIResponse", - "GeneralErrorResponse", "AccountInfo", + "GeneralErrorResponse", "LoginMethod", "MaybeAwaitable", "RecipeUserId", diff --git a/supertokens_python/types/config.py b/supertokens_python/types/config.py new file mode 100644 index 000000000..d6d12808f --- /dev/null +++ b/supertokens_python/types/config.py @@ -0,0 +1,123 @@ +from typing import Callable, Generic, Optional, TypeVar + +from supertokens_python.types.recipe import BaseAPIInterface, BaseRecipeInterface +from supertokens_python.types.response import CamelCaseBaseModel +from supertokens_python.types.utils import UseDefaultIfNone + +T = TypeVar("T") + +"""Generic Type for use in `InterfaceOverride`""" +FunctionInterfaceType = TypeVar("FunctionInterfaceType", bound=BaseRecipeInterface) +"""Generic Type for use in `FunctionOverrideConfig`""" +APIInterfaceType = TypeVar("APIInterfaceType", bound=BaseAPIInterface) +"""Generic Type for use in `APIOverrideConfig`""" + + +InterfaceOverride = Callable[[T], T] + + +class BaseOverrideConfigWithoutAPI(CamelCaseBaseModel, Generic[FunctionInterfaceType]): + """Base class for input override config without API overrides.""" + + functions: UseDefaultIfNone[Optional[InterfaceOverride[FunctionInterfaceType]]] = ( + lambda original_implementation: original_implementation + ) + + +class BaseNormalisedOverrideConfigWithoutAPI( + CamelCaseBaseModel, Generic[FunctionInterfaceType] +): + """Base class for normalized override config without API overrides.""" + + functions: InterfaceOverride[FunctionInterfaceType] = ( + lambda original_implementation: original_implementation + ) + + @classmethod + def from_input_config( + cls, + override_config: Optional[BaseOverrideConfigWithoutAPI[FunctionInterfaceType]], + ) -> "BaseNormalisedOverrideConfigWithoutAPI[FunctionInterfaceType]": + """Create a normalized config from the input config.""" + normalised_config = cls() + + if override_config is None: + return normalised_config + + if override_config.functions is not None: + normalised_config.functions = override_config.functions + + return normalised_config + + +class BaseOverrideConfig( + BaseOverrideConfigWithoutAPI[FunctionInterfaceType], + Generic[FunctionInterfaceType, APIInterfaceType], +): + """Base class for input override config with API overrides.""" + + apis: UseDefaultIfNone[Optional[InterfaceOverride[APIInterfaceType]]] = ( + lambda original_implementation: original_implementation + ) + + +class BaseNormalisedOverrideConfig( + BaseNormalisedOverrideConfigWithoutAPI[FunctionInterfaceType], + Generic[FunctionInterfaceType, APIInterfaceType], +): + """Base class for normalized override config with API overrides.""" + + apis: InterfaceOverride[APIInterfaceType] = ( + lambda original_implementation: original_implementation + ) + + @classmethod + def from_input_config( # type: ignore - invalid override due to subclassing + cls, + override_config: Optional[ + BaseOverrideConfig[FunctionInterfaceType, APIInterfaceType] + ], + ) -> "BaseNormalisedOverrideConfig[FunctionInterfaceType, APIInterfaceType]": # type: ignore + """Create a normalized config from the input config.""" + normalised_config = cls() + + if override_config is None: + return normalised_config + + if override_config.functions is not None: + normalised_config.functions = override_config.functions + + if override_config.apis is not None: + normalised_config.apis = override_config.apis + + return normalised_config + + +class BaseConfigWithoutAPIOverride(CamelCaseBaseModel, Generic[FunctionInterfaceType]): + """Base class for input config of a Recipe without API overrides.""" + + override: Optional[BaseOverrideConfigWithoutAPI[FunctionInterfaceType]] = None + + +class BaseNormalisedConfigWithoutAPIOverride( + CamelCaseBaseModel, Generic[FunctionInterfaceType] +): + """Base class for normalized config of a Recipe without API overrides.""" + + override: BaseNormalisedOverrideConfigWithoutAPI[FunctionInterfaceType] + + +class BaseConfig(CamelCaseBaseModel, Generic[FunctionInterfaceType, APIInterfaceType]): + """Base class for input config of a Recipe with API overrides.""" + + override: Optional[BaseOverrideConfig[FunctionInterfaceType, APIInterfaceType]] = ( + None + ) + + +class BaseNormalisedConfig( + CamelCaseBaseModel, Generic[FunctionInterfaceType, APIInterfaceType] +): + """Base class for normalized config of a Recipe with API overrides.""" + + override: BaseNormalisedOverrideConfig[FunctionInterfaceType, APIInterfaceType] diff --git a/supertokens_python/types/recipe.py b/supertokens_python/types/recipe.py new file mode 100644 index 000000000..566695534 --- /dev/null +++ b/supertokens_python/types/recipe.py @@ -0,0 +1,7 @@ +from abc import ABC + + +class BaseRecipeInterface(ABC): ... + + +class BaseAPIInterface(ABC): ... diff --git a/supertokens_python/types/utils.py b/supertokens_python/types/utils.py new file mode 100644 index 000000000..4e50039cb --- /dev/null +++ b/supertokens_python/types/utils.py @@ -0,0 +1,16 @@ +from typing import Any, TypeVar + +from pydantic import BeforeValidator +from pydantic_core import PydanticUseDefault +from typing_extensions import Annotated + +T = TypeVar("T") + + +def default_if_none(value: Any) -> Any: + if value is None: + return PydanticUseDefault() + return value + + +UseDefaultIfNone = Annotated[T, BeforeValidator(default_if_none)] diff --git a/tests/Django/test_django.py b/tests/Django/test_django.py index f294af2b5..c0036c607 100644 --- a/tests/Django/test_django.py +++ b/tests/Django/test_django.py @@ -32,9 +32,9 @@ ) from supertokens_python.querier import Querier from supertokens_python.recipe import emailpassword, session, thirdparty -from supertokens_python.recipe.dashboard import DashboardRecipe, InputOverrideConfig +from supertokens_python.recipe.dashboard import DashboardOverrideConfig, DashboardRecipe from supertokens_python.recipe.dashboard.interfaces import RecipeInterface -from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.dashboard.utils import NormalisedDashboardConfig from supertokens_python.recipe.emailpassword.interfaces import APIInterface, APIOptions from supertokens_python.recipe.passwordless import ContactConfig, PasswordlessRecipe from supertokens_python.recipe.session import SessionContainer @@ -53,7 +53,7 @@ def override_dashboard_functions(original_implementation: RecipeInterface): async def should_allow_access( - request: BaseRequest, __: DashboardConfig, ___: Dict[str, Any] + request: BaseRequest, __: NormalisedDashboardConfig, ___: Dict[str, Any] ): auth_header = request.get_header("authorization") return auth_header == "Bearer testapikey" @@ -370,7 +370,7 @@ async def email_exists_get( mode="asgi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis ) ) @@ -497,7 +497,7 @@ async def test_search_with_multiple_emails(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -548,7 +548,7 @@ async def test_search_with_email_t(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -597,7 +597,7 @@ async def test_search_with_email_iresh(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -648,7 +648,7 @@ async def test_search_with_phone_plus_one(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -702,7 +702,7 @@ async def test_search_with_phone_one_bracket(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -799,7 +799,7 @@ async def test_search_with_provider_google(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), @@ -892,7 +892,7 @@ async def test_search_with_provider_google_and_phone_one(self): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions ), ), diff --git a/tests/Fastapi/test_fastapi.py b/tests/Fastapi/test_fastapi.py index 4741713d5..b9c35c01a 100644 --- a/tests/Fastapi/test_fastapi.py +++ b/tests/Fastapi/test_fastapi.py @@ -22,9 +22,9 @@ from supertokens_python.framework.fastapi import get_middleware from supertokens_python.querier import Querier from supertokens_python.recipe import emailpassword, session, thirdparty -from supertokens_python.recipe.dashboard import DashboardRecipe, InputOverrideConfig +from supertokens_python.recipe.dashboard import DashboardOverrideConfig, DashboardRecipe from supertokens_python.recipe.dashboard.interfaces import RecipeInterface -from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.dashboard.utils import NormalisedDashboardConfig from supertokens_python.recipe.emailpassword.interfaces import ( APIInterface as EPAPIInterface, ) @@ -60,7 +60,7 @@ def override_dashboard_functions(original_implementation: RecipeInterface): async def should_allow_access( - request: BaseRequest, __: DashboardConfig, ___: Dict[str, Any] + request: BaseRequest, __: NormalisedDashboardConfig, ___: Dict[str, Any] ): auth_header = request.get_header("authorization") return auth_header == "Bearer testapikey" @@ -157,7 +157,7 @@ async def test_login_refresh(driver_config_client: TestClient): anti_csrf="VIA_TOKEN", cookie_domain="supertokens.io", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], ) @@ -462,7 +462,7 @@ async def email_exists_get( framework="fastapi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis ) ) @@ -570,7 +570,7 @@ async def refresh_post( recipe_list=[ session.init( anti_csrf="VIA_TOKEN", - override=session.InputOverrideConfig(apis=override_session_apis), + override=session.SessionOverrideConfig(apis=override_session_apis), ) ], ) @@ -686,7 +686,9 @@ async def test_search_with_email_t(driver_config_client: TestClient): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), emailpassword.init(), ], @@ -731,7 +733,9 @@ async def test_search_with_email_multiple_email_entry(driver_config_client: Test ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), emailpassword.init(), ], @@ -776,7 +780,9 @@ async def test_search_with_email_iresh(driver_config_client: TestClient): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), emailpassword.init(), ], @@ -821,7 +827,9 @@ async def test_search_with_phone_plus_one(driver_config_client: TestClient): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), PasswordlessRecipe.init( contact_config=ContactConfig(contact_method="EMAIL"), @@ -869,7 +877,9 @@ async def test_search_with_phone_one_bracket(driver_config_client: TestClient): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), PasswordlessRecipe.init( contact_config=ContactConfig(contact_method="EMAIL"), @@ -917,7 +927,9 @@ async def test_search_with_provider_google(driver_config_client: TestClient): ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), thirdparty.init( sign_in_and_up_feature=thirdparty.SignInAndUpFeature( @@ -1006,7 +1018,9 @@ async def test_search_with_provider_google_and_phone_1( ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), PasswordlessRecipe.init( contact_config=ContactConfig(contact_method="EMAIL"), diff --git a/tests/Flask/test_flask.py b/tests/Flask/test_flask.py index bdcc50658..40671df58 100644 --- a/tests/Flask/test_flask.py +++ b/tests/Flask/test_flask.py @@ -28,9 +28,9 @@ ) from supertokens_python.querier import Querier from supertokens_python.recipe import emailpassword, session, thirdparty -from supertokens_python.recipe.dashboard import DashboardRecipe, InputOverrideConfig +from supertokens_python.recipe.dashboard import DashboardOverrideConfig, DashboardRecipe from supertokens_python.recipe.dashboard.interfaces import RecipeInterface -from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.dashboard.utils import NormalisedDashboardConfig from supertokens_python.recipe.emailpassword.interfaces import APIInterface, APIOptions from supertokens_python.recipe.passwordless import ContactConfig, PasswordlessRecipe from supertokens_python.recipe.session import SessionContainer @@ -81,7 +81,7 @@ async def email_exists_get( def override_dashboard_functions(original_implementation: RecipeInterface): async def should_allow_access( - request: BaseRequest, __: DashboardConfig, ___: Dict[str, Any] + request: BaseRequest, __: NormalisedDashboardConfig, ___: Dict[str, Any] ): auth_header = request.get_header("authorization") return auth_header == "Bearer testapikey" @@ -110,7 +110,7 @@ async def should_allow_access( get_token_transfer_method=lambda _, __, ___: "cookie", ), emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis ) ), @@ -159,7 +159,9 @@ async def should_allow_access( ), DashboardRecipe.init( api_key="testapikey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), PasswordlessRecipe.init( contact_config=ContactConfig(contact_method="EMAIL"), diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index 97335f18e..ebb40e7b5 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -8,7 +8,7 @@ from supertokens_python import InputAppInfo, Supertokens, SupertokensConfig, init from supertokens_python.framework.request import BaseRequest from supertokens_python.ingredients.emaildelivery.types import ( - EmailDeliveryConfigWithService, + EmailDeliveryConfig, EmailDeliveryInterface, ) from supertokens_python.recipe import ( @@ -806,7 +806,7 @@ async def resend_code_post( contact_config=ContactPhoneOnlyConfig(), flow_type=passwordlessFlowType, # type: ignore - type expects only certain literals sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -817,7 +817,7 @@ async def resend_code_post( email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -829,7 +829,7 @@ async def resend_code_post( CustomPlessEmailService() ), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -839,7 +839,9 @@ async def resend_code_post( flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), + override=passwordless.PasswordlessOverrideConfig( + apis=override_passwordless_apis + ), ) async def get_allowed_domains_for_tenant_id( @@ -968,7 +970,7 @@ async def resync_session_and_fetch_mfa_info_put( { "id": "session", "init": session.init( - override=session.InputOverrideConfig(apis=override_session_apis) + override=session.SessionOverrideConfig(apis=override_session_apis) ), }, { @@ -988,7 +990,7 @@ async def resync_session_and_fetch_mfa_info_put( email_delivery=emailpassword.EmailDeliveryConfig( CustomEPEmailService() ), - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis, ), ), @@ -997,9 +999,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "webauthn", "init": webauthn.init( config=WebauthnConfig( - email_delivery=EmailDeliveryConfigWithService[ - TypeWebauthnEmailDeliveryInput - ](service=CustomWebwuthnEmailService()) # type: ignore + email_delivery=EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]( + service=CustomWebwuthnEmailService() + ) ) ), }, @@ -1007,7 +1009,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "thirdparty", "init": thirdparty.init( sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + override=thirdparty.ThirdPartyOverrideConfig( + apis=override_thirdparty_apis + ), ), }, { @@ -1024,7 +1028,7 @@ async def resync_session_and_fetch_mfa_info_put( "id": "multifactorauth", "init": multifactorauth.init( first_factors=mfaInfo.get("firstFactors", None), - override=multifactorauth.OverrideConfig( + override=multifactorauth.MultiFactorAuthOverrideConfig( functions=override_mfa_functions, apis=override_mfa_apis, ), diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index c09a67f23..84a004f46 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -41,7 +41,7 @@ from supertokens_python.framework.fastapi import get_middleware from supertokens_python.framework.request import BaseRequest from supertokens_python.ingredients.emaildelivery.types import ( - EmailDeliveryConfigWithService, + EmailDeliveryConfig, EmailDeliveryInterface, ) from supertokens_python.recipe import ( @@ -907,7 +907,7 @@ async def resend_code_post( contact_config=ContactPhoneOnlyConfig(), flow_type=passwordlessFlowType, # type: ignore - type expects only certain literals sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -918,7 +918,7 @@ async def resend_code_post( email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -930,7 +930,7 @@ async def resend_code_post( CustomPlessEmailService() ), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -940,7 +940,9 @@ async def resend_code_post( flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), + override=passwordless.PasswordlessOverrideConfig( + apis=override_passwordless_apis + ), ) async def get_allowed_domains_for_tenant_id( @@ -1069,7 +1071,7 @@ async def resync_session_and_fetch_mfa_info_put( { "id": "session", "init": session.init( - override=session.InputOverrideConfig(apis=override_session_apis) + override=session.SessionOverrideConfig(apis=override_session_apis) ), }, { @@ -1089,7 +1091,7 @@ async def resync_session_and_fetch_mfa_info_put( email_delivery=emailpassword.EmailDeliveryConfig( CustomEPEmailService() ), - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis, ), ), @@ -1098,9 +1100,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "webauthn", "init": webauthn.init( config=WebauthnConfig( - email_delivery=EmailDeliveryConfigWithService[ - TypeWebauthnEmailDeliveryInput - ](service=CustomWebwuthnEmailService()) # type: ignore + email_delivery=EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]( + service=CustomWebwuthnEmailService() + ) ) ), }, @@ -1108,7 +1110,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "thirdparty", "init": thirdparty.init( sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + override=thirdparty.ThirdPartyOverrideConfig( + apis=override_thirdparty_apis + ), ), }, { @@ -1125,7 +1129,7 @@ async def resync_session_and_fetch_mfa_info_put( "id": "multifactorauth", "init": multifactorauth.init( first_factors=mfaInfo.get("firstFactors", None), - override=multifactorauth.OverrideConfig( + override=multifactorauth.MultiFactorAuthOverrideConfig( functions=override_mfa_functions, apis=override_mfa_apis, ), diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 11c8698c5..14e9ae5db 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -32,7 +32,7 @@ from supertokens_python.framework.flask.flask_middleware import Middleware from supertokens_python.framework.request import BaseRequest from supertokens_python.ingredients.emaildelivery.types import ( - EmailDeliveryConfigWithService, + EmailDeliveryConfig, EmailDeliveryInterface, ) from supertokens_python.recipe import ( @@ -486,13 +486,11 @@ def custom_init( WebauthnRecipe.reset() def override_email_verification_apis( - original_implementation_email_verification: EmailVerificationAPIInterface, + original_implementation: EmailVerificationAPIInterface, ): - original_email_verify_post = ( - original_implementation_email_verification.email_verify_post - ) + original_email_verify_post = original_implementation.email_verify_post original_generate_email_verify_token_post = ( - original_implementation_email_verification.generate_email_verify_token_post + original_implementation.generate_email_verify_token_post ) async def email_verify_post( @@ -531,11 +529,11 @@ async def generate_email_verify_token_post( session, api_options, user_context ) - original_implementation_email_verification.email_verify_post = email_verify_post - original_implementation_email_verification.generate_email_verify_token_post = ( + original_implementation.email_verify_post = email_verify_post + original_implementation.generate_email_verify_token_post = ( generate_email_verify_token_post ) - return original_implementation_email_verification + return original_implementation def override_email_password_apis( original_implementation: EmailPasswordAPIInterface, @@ -888,7 +886,7 @@ async def resend_code_post( contact_config=ContactPhoneOnlyConfig(), flow_type=passwordlessFlowType, # type: ignore - type expects only certain literals sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -899,7 +897,7 @@ async def resend_code_post( email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -911,7 +909,7 @@ async def resend_code_post( CustomPlessEmailService() ), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_passwordless_apis ), ) @@ -921,7 +919,9 @@ async def resend_code_post( flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), - override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), + override=passwordless.PasswordlessOverrideConfig( + apis=override_passwordless_apis + ), ) async def get_allowed_domains_for_tenant_id( @@ -1050,7 +1050,7 @@ async def resync_session_and_fetch_mfa_info_put( { "id": "session", "init": session.init( - override=session.InputOverrideConfig(apis=override_session_apis) + override=session.SessionOverrideConfig(apis=override_session_apis) ), }, { @@ -1070,7 +1070,7 @@ async def resync_session_and_fetch_mfa_info_put( email_delivery=emailpassword.EmailDeliveryConfig( CustomEPEmailService() ), - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_email_password_apis, ), ), @@ -1079,9 +1079,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "webauthn", "init": webauthn.init( config=WebauthnConfig( - email_delivery=EmailDeliveryConfigWithService[ - TypeWebauthnEmailDeliveryInput - ](service=CustomWebwuthnEmailService()) # type: ignore + email_delivery=EmailDeliveryConfig[TypeWebauthnEmailDeliveryInput]( + service=CustomWebwuthnEmailService() + ) ) ), }, @@ -1089,7 +1089,9 @@ async def resync_session_and_fetch_mfa_info_put( "id": "thirdparty", "init": thirdparty.init( sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + override=thirdparty.ThirdPartyOverrideConfig( + apis=override_thirdparty_apis + ), ), }, { @@ -1106,7 +1108,7 @@ async def resync_session_and_fetch_mfa_info_put( "id": "multifactorauth", "init": multifactorauth.init( first_factors=mfaInfo.get("firstFactors", None), - override=multifactorauth.OverrideConfig( + override=multifactorauth.MultiFactorAuthOverrideConfig( functions=override_mfa_functions, apis=override_mfa_apis, ), diff --git a/tests/dashboard/test_dashboard.py b/tests/dashboard/test_dashboard.py index ac348ed98..4adfd5e48 100644 --- a/tests/dashboard/test_dashboard.py +++ b/tests/dashboard/test_dashboard.py @@ -14,11 +14,11 @@ thirdparty, usermetadata, ) -from supertokens_python.recipe.dashboard import InputOverrideConfig +from supertokens_python.recipe.dashboard import DashboardOverrideConfig from supertokens_python.recipe.dashboard.interfaces import ( RecipeInterface as DashboardRI, ) -from supertokens_python.recipe.dashboard.utils import DashboardConfig +from supertokens_python.recipe.dashboard.utils import NormalisedDashboardConfig from supertokens_python.recipe.passwordless import ContactEmailOrPhoneConfig from supertokens_python.recipe.thirdparty.asyncio import manually_create_or_update_user from supertokens_python.recipe.thirdparty.interfaces import ( @@ -49,7 +49,7 @@ async def test_dashboard_recipe(app: TestClient): def override_dashboard_functions(oi: DashboardRI) -> DashboardRI: async def should_allow_access( _request: BaseRequest, - _config: DashboardConfig, + _config: NormalisedDashboardConfig, _user_context: Dict[str, Any], ) -> bool: return True @@ -63,7 +63,9 @@ async def should_allow_access( session.init(get_token_transfer_method=lambda _, __, ___: "cookie"), dashboard.init( api_key="someKey", - override=InputOverrideConfig(functions=override_dashboard_functions), + override=DashboardOverrideConfig( + functions=override_dashboard_functions + ), ), ], ) @@ -83,7 +85,7 @@ async def test_dashboard_users_get(app: TestClient): def override_dashboard_functions(oi: DashboardRI) -> DashboardRI: async def should_allow_access( _request: BaseRequest, - _config: DashboardConfig, + _config: NormalisedDashboardConfig, _user_context: Dict[str, Any], ) -> bool: return True @@ -99,7 +101,7 @@ async def should_allow_access( usermetadata.init(), dashboard.init( api_key="someKey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions, ), ), @@ -211,7 +213,7 @@ async def test_that_get_user_works_with_combination_recipes(app: TestClient): def override_dashboard_functions(oi: DashboardRI) -> DashboardRI: async def should_allow_access( _request: BaseRequest, - _config: DashboardConfig, + _config: NormalisedDashboardConfig, _user_context: Dict[str, Any], ) -> bool: return True @@ -231,7 +233,7 @@ async def should_allow_access( usermetadata.init(), dashboard.init( api_key="someKey", - override=InputOverrideConfig( + override=DashboardOverrideConfig( functions=override_dashboard_functions, ), ), diff --git a/tests/emailpassword/test_emailexists.py b/tests/emailpassword/test_emailexists.py index 8536b6d0a..1dfa97f9d 100644 --- a/tests/emailpassword/test_emailexists.py +++ b/tests/emailpassword/test_emailexists.py @@ -188,7 +188,7 @@ async def test_that_if_disabling_api_the_default_email_exists_api_does_not_work( recipe_list=[ session.init(anti_csrf="VIA_TOKEN", cookie_domain="supertokens.io"), emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=apis_override_email_password ) ), diff --git a/tests/emailpassword/test_emailverify.py b/tests/emailpassword/test_emailverify.py index 1b3bcb12d..1ac3a3b97 100644 --- a/tests/emailpassword/test_emailverify.py +++ b/tests/emailpassword/test_emailverify.py @@ -44,7 +44,9 @@ from supertokens_python.recipe.emailverification.types import ( EmailVerificationUser as EVUser, ) -from supertokens_python.recipe.emailverification.utils import OverrideConfig +from supertokens_python.recipe.emailverification.utils import ( + EmailVerificationOverrideConfig, +) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.asyncio import ( create_new_session, @@ -685,7 +687,9 @@ async def email_verify_post( email_delivery=emailverification.EmailDeliveryConfig( CustomEmailService() ), - override=OverrideConfig(apis=apis_override_email_password), + override=EmailVerificationOverrideConfig( + apis=apis_override_email_password + ), ), emailpassword.init(), ], @@ -917,7 +921,9 @@ async def email_verify_post( email_delivery=emailverification.EmailDeliveryConfig( CustomEmailService() ), - override=OverrideConfig(apis=apis_override_email_password), + override=EmailVerificationOverrideConfig( + apis=apis_override_email_password + ), ), emailpassword.init(), ], @@ -1025,7 +1031,7 @@ async def email_verify_post( email_delivery=emailverification.EmailDeliveryConfig( CustomEmailService() ), - override=emailverification.InputOverrideConfig( + override=emailverification.EmailVerificationOverrideConfig( apis=apis_override_email_password ), ), diff --git a/tests/emailpassword/test_signin.py b/tests/emailpassword/test_signin.py index 6d2ce23c4..5a69bcca9 100644 --- a/tests/emailpassword/test_signin.py +++ b/tests/emailpassword/test_signin.py @@ -102,7 +102,7 @@ def apis_override_email_password(param: APIInterface): framework="fastapi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=apis_override_email_password ) ) diff --git a/tests/frontendIntegration/django2x/polls/views.py b/tests/frontendIntegration/django2x/polls/views.py index 5d2a587e4..a0128eb48 100644 --- a/tests/frontendIntegration/django2x/polls/views.py +++ b/tests/frontendIntegration/django2x/polls/views.py @@ -36,6 +36,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -279,13 +280,13 @@ def unauthorised_f(req: BaseRequest, message: str, res: BaseResponse): return res -def apis_override_session(param: APIInterface): - param.disable_refresh_post = True - return param +def apis_override_session(original_implementation: APIInterface): + original_implementation.disable_refresh_post = True + return original_implementation -def functions_override_session(param: RecipeInterface): - original_create_new_session = param.create_new_session +def functions_override_session(original_implementation: RecipeInterface): + original_create_new_session = original_implementation.create_new_session async def create_new_session_custom( user_id: str, @@ -309,9 +310,9 @@ async def create_new_session_custom( user_context, ) - param.create_new_session = create_new_session_custom + original_implementation.create_new_session = create_new_session_custom - return param + return original_implementation def get_app_port(): @@ -347,7 +348,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -371,7 +372,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -392,7 +393,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, @@ -607,6 +608,7 @@ def reinitialize(request: HttpRequest): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( data["coreUrl"], last_set_enable_anti_csrf, @@ -627,6 +629,7 @@ def setup_st(request: HttpRequest): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=data["coreUrl"], enable_anti_csrf=data.get("enableAntiCsrf"), diff --git a/tests/frontendIntegration/django3x/polls/views.py b/tests/frontendIntegration/django3x/polls/views.py index 0567a0c2f..1f74a5ef9 100644 --- a/tests/frontendIntegration/django3x/polls/views.py +++ b/tests/frontendIntegration/django3x/polls/views.py @@ -34,6 +34,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -349,7 +350,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -373,7 +374,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -394,7 +395,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, @@ -608,6 +609,7 @@ async def reinitialize(request: HttpRequest): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( data["coreUrl"], last_set_enable_anti_csrf, @@ -628,6 +630,7 @@ async def setup_st(request: HttpRequest): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=data["coreUrl"], enable_anti_csrf=data.get("enableAntiCsrf"), diff --git a/tests/frontendIntegration/drf_async/mysite/settings.py b/tests/frontendIntegration/drf_async/mysite/settings.py index a7f9aba4f..ca43b8b8f 100644 --- a/tests/frontendIntegration/drf_async/mysite/settings.py +++ b/tests/frontendIntegration/drf_async/mysite/settings.py @@ -162,13 +162,13 @@ def get_app_port(): return "8080" -def apis_override_session(param: APIInterface): - param.disable_refresh_post = True - return param +def apis_override_session(original_implementation: APIInterface): + original_implementation.disable_refresh_post = True + return original_implementation -def functions_override_session(param: RecipeInterface): - original_create_new_session = param.create_new_session +def functions_override_session(original_implementation: RecipeInterface): + original_create_new_session = original_implementation.create_new_session async def create_new_session_custom( user_id: str, @@ -192,9 +192,9 @@ async def create_new_session_custom( user_context, ) - param.create_new_session = create_new_session_custom + original_implementation.create_new_session = create_new_session_custom - return param + return original_implementation init( @@ -207,7 +207,7 @@ async def create_new_session_custom( framework="django", recipe_list=[ session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), diff --git a/tests/frontendIntegration/drf_async/polls/views.py b/tests/frontendIntegration/drf_async/polls/views.py index f5e53db57..28513fbfa 100644 --- a/tests/frontendIntegration/drf_async/polls/views.py +++ b/tests/frontendIntegration/drf_async/polls/views.py @@ -40,6 +40,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -306,13 +307,13 @@ async def unauthorised_f(req: BaseRequest, message: str, res: BaseResponse): return res -def apis_override_session(param: APIInterface): - param.disable_refresh_post = True - return param +def apis_override_session(original_implementation: APIInterface): + original_implementation.disable_refresh_post = True + return original_implementation -def functions_override_session(param: RecipeInterface): - original_create_new_session = param.create_new_session +def functions_override_session(original_implementation: RecipeInterface): + original_create_new_session = original_implementation.create_new_session async def create_new_session_custom( user_id: str, @@ -336,9 +337,9 @@ async def create_new_session_custom( user_context, ) - param.create_new_session = create_new_session_custom + original_implementation.create_new_session = create_new_session_custom - return param + return original_implementation def get_app_port(): @@ -374,7 +375,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -398,7 +399,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -419,7 +420,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, @@ -671,6 +672,7 @@ async def reinitialize(request: Request): # type: ignore OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( data["coreUrl"], # type: ignore last_set_enable_anti_csrf, @@ -693,6 +695,7 @@ async def setup_st(request: HttpRequest): # type: ignore OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=data["coreUrl"], enable_anti_csrf=data.get("enableAntiCsrf"), diff --git a/tests/frontendIntegration/drf_sync/mysite/settings.py b/tests/frontendIntegration/drf_sync/mysite/settings.py index 37c3dc651..70e011e9c 100644 --- a/tests/frontendIntegration/drf_sync/mysite/settings.py +++ b/tests/frontendIntegration/drf_sync/mysite/settings.py @@ -207,7 +207,7 @@ async def create_new_session_custom( framework="django", recipe_list=[ session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), diff --git a/tests/frontendIntegration/drf_sync/polls/views.py b/tests/frontendIntegration/drf_sync/polls/views.py index 0f869b3e2..4fb1a1001 100644 --- a/tests/frontendIntegration/drf_sync/polls/views.py +++ b/tests/frontendIntegration/drf_sync/polls/views.py @@ -39,6 +39,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -373,7 +374,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -397,7 +398,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -418,7 +419,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, @@ -673,6 +674,7 @@ def reinitialize(request: Request): # type: ignore OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( data["coreUrl"], # type: ignore last_set_enable_anti_csrf, @@ -695,6 +697,7 @@ def setup_st(request: HttpRequest): # type: ignore OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=data["coreUrl"], enable_anti_csrf=data.get("enableAntiCsrf"), diff --git a/tests/frontendIntegration/fastapi-server/app.py b/tests/frontendIntegration/fastapi-server/app.py index d6766074f..ad8a46194 100644 --- a/tests/frontendIntegration/fastapi-server/app.py +++ b/tests/frontendIntegration/fastapi-server/app.py @@ -38,6 +38,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -131,13 +132,13 @@ async def unauthorised_f(_: BaseRequest, __: str, res: BaseResponse): return res -def apis_override_session(param: APIInterface): - param.disable_refresh_post = True - return param +def apis_override_session(original_implementation: APIInterface): + original_implementation.disable_refresh_post = True + return original_implementation -def functions_override_session(param: RecipeInterface): - original_create_new_session = param.create_new_session +def functions_override_session(original_implementation: RecipeInterface): + original_create_new_session = original_implementation.create_new_session async def create_new_session_custom( user_id: str, @@ -161,9 +162,9 @@ async def create_new_session_custom( user_context, ) - param.create_new_session = create_new_session_custom + original_implementation.create_new_session = create_new_session_custom - return param + return original_implementation def get_app_port(): @@ -199,7 +200,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -223,7 +224,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -244,7 +245,7 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, @@ -657,6 +658,7 @@ async def reinitialize(request: Request): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( json["coreUrl"], last_set_enable_anti_csrf, @@ -678,6 +680,7 @@ async def setup_st(request: Request): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=json["coreUrl"], enable_anti_csrf=json.get("enableAntiCsrf"), diff --git a/tests/frontendIntegration/flask-server/app.py b/tests/frontendIntegration/flask-server/app.py index 27065f66f..c3134208e 100644 --- a/tests/frontendIntegration/flask-server/app.py +++ b/tests/frontendIntegration/flask-server/app.py @@ -38,6 +38,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe import session +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.jwt.recipe import JWTRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe @@ -162,13 +163,13 @@ async def unauthorised_f(_: BaseRequest, __: str, res: BaseResponse): return res -def apis_override_session(param: APIInterface): - param.disable_refresh_post = True - return param +def apis_override_session(original_implementation: APIInterface): + original_implementation.disable_refresh_post = True + return original_implementation -def functions_override_session(param: RecipeInterface): - original_create_new_session = param.create_new_session +def functions_override_session(original_implementation: RecipeInterface): + original_create_new_session = original_implementation.create_new_session async def create_new_session_custom( user_id: str, @@ -192,9 +193,9 @@ async def create_new_session_custom( user_context, ) - param.create_new_session = create_new_session_custom + original_implementation.create_new_session = create_new_session_custom - return param + return original_implementation def get_app_port(): @@ -230,7 +231,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -254,7 +255,7 @@ def config( on_unauthorised=unauthorised_f ), anti_csrf=anti_csrf, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=apis_override_session, functions=functions_override_session, ), @@ -275,23 +276,13 @@ def config( session.init( error_handlers=InputErrorHandlers(on_unauthorised=unauthorised_f), anti_csrf=anti_csrf, - override=session.InputOverrideConfig(apis=apis_override_session), + override=session.SessionOverrideConfig(apis=apis_override_session), ) ], telemetry=False, ) -core_host = os.environ.get("SUPERTOKENS_CORE_HOST", "localhost") -core_port = os.environ.get("SUPERTOKENS_CORE_PORT", "3567") -config( - core_url=f"http://{core_host}:{core_port}", - enable_anti_csrf=True, - enable_jwt=False, - jwt_property_name=None, -) - - @app.route("/index.html", methods=["GET"]) # type: ignore def send_file(): return render_template("index.html") @@ -674,6 +665,7 @@ def reinitialize(): OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( json["coreUrl"], last_set_enable_anti_csrf, # type: ignore @@ -695,6 +687,7 @@ async def setup_st(): # type: ignore OpenIdRecipe.reset() OAuth2ProviderRecipe.reset() JWTRecipe.reset() + AccountLinkingRecipe.reset() config( core_url=json["coreUrl"], enable_anti_csrf=json.get("enableAntiCsrf"), # type: ignore @@ -733,5 +726,14 @@ def handle_exception(e): # type: ignore return Response(str(e), status=500) # type: ignore +core_host = os.environ.get("SUPERTOKENS_CORE_HOST", "localhost") +core_port = os.environ.get("SUPERTOKENS_CORE_PORT", "3567") +config( + core_url=f"http://{core_host}:{core_port}", + enable_anti_csrf=True, + enable_jwt=False, + jwt_property_name=None, +) + if __name__ == "__main__": app.run(host="0.0.0.0", port=int(get_app_port()), threaded=True) diff --git a/tests/input_validation/test_input_validation.py b/tests/input_validation/test_input_validation.py index b9ce47f3c..ae6111fa9 100644 --- a/tests/input_validation/test_input_validation.py +++ b/tests/input_validation/test_input_validation.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List import pytest +from pydantic import ValidationError from supertokens_python import InputAppInfo, SupertokensConfig, init from supertokens_python.recipe import ( emailpassword, @@ -23,7 +24,9 @@ @pytest.mark.asyncio async def test_init_validation_emailpassword(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValueError, match="app_info must be an instance of InputAppInfo" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info="AppInfo", # type: ignore @@ -32,9 +35,10 @@ async def test_init_validation_emailpassword(): emailpassword.init(), ], ) - assert "app_info must be an instance of InputAppInfo" == str(ex.value) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, match="Input should be 'REQUIRED' or 'OPTIONAL'" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -49,10 +53,6 @@ async def test_init_validation_emailpassword(): emailpassword.init(), ], ) - assert ( - "Email Verification recipe mode must be one of 'REQUIRED' or 'OPTIONAL'" - == str(ex.value) - ) async def get_email_for_user_id(_: RecipeUserId, __: Dict[str, Any]): @@ -61,7 +61,9 @@ async def get_email_for_user_id(_: RecipeUserId, __: Dict[str, Any]): @pytest.mark.asyncio async def test_init_validation_emailverification(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, match="Input should be 'REQUIRED' or 'OPTIONAL'" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -73,12 +75,11 @@ async def test_init_validation_emailverification(): framework="fastapi", recipe_list=[emailverification.init("config")], # type: ignore ) - assert ( - "Email Verification recipe mode must be one of 'REQUIRED' or 'OPTIONAL'" - == str(ex.value) - ) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -96,26 +97,29 @@ async def test_init_validation_emailverification(): ) ], ) - assert "override must be of type OverrideConfig or None" == str(ex.value) @pytest.mark.asyncio async def test_init_validation_jwt(): - with pytest.raises(ValueError) as ex: - init( - supertokens_config=SupertokensConfig(get_new_core_app_url()), - app_info=InputAppInfo( - app_name="SuperTokens Demo", - api_domain="http://api.supertokens.io", - website_domain="http://supertokens.io", - api_base_path="/auth", - ), - framework="fastapi", - recipe_list=[jwt.init(jwt_validity_seconds="100")], # type: ignore - ) - assert "jwt_validity_seconds must be an integer or None" == str(ex.value) - - with pytest.raises(ValueError) as ex: + # NOTE: `pydantic` auto-converts strings to integers + # with pytest.raises(ValueError) as ex: + # init( + # supertokens_config=SupertokensConfig(get_new_core_app_url()), + # app_info=InputAppInfo( + # app_name="SuperTokens Demo", + # api_domain="http://api.supertokens.io", + # website_domain="http://supertokens.io", + # api_base_path="/auth", + # ), + # framework="fastapi", + # recipe_list=[jwt.init(jwt_validity_seconds="100")], # type: ignore + # ) + # assert "jwt_validity_seconds must be an integer or None" == str(ex.value) + + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -127,12 +131,14 @@ async def test_init_validation_jwt(): framework="fastapi", recipe_list=[jwt.init(override="override")], # type: ignore ) - assert "override must be an instance of OverrideConfig or None" == str(ex.value) @pytest.mark.asyncio async def test_init_validation_openid(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -144,9 +150,6 @@ async def test_init_validation_openid(): framework="fastapi", recipe_list=[openid.init(override="override")], # type: ignore ) - assert "override must be an instance of InputOverrideConfig or None" == str( - ex.value - ) async def send_text_message( @@ -177,7 +180,9 @@ async def send_email( ) -> None: pass - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValueError, match="app_info must be an instance of InputAppInfo" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info="AppInfo", # type: ignore @@ -195,9 +200,11 @@ async def send_email( ) ], ) - assert "app_info must be an instance of InputAppInfo" == str(ex.value) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be 'USER_INPUT_CODE', 'MAGIC_LINK' or 'USER_INPUT_CODE_AND_MAGIC_LINK'", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -220,12 +227,10 @@ async def send_email( ) ], ) - assert ( - "flow_type must be one of USER_INPUT_CODE, MAGIC_LINK, USER_INPUT_CODE_AND_MAGIC_LINK" - == str(ex.value) - ) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, match="Input should be an instance of ContactConfig" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -242,9 +247,11 @@ async def send_email( ) ], ) - assert "contact_config must be of type ContactConfig" == str(ex.value) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -268,7 +275,6 @@ async def send_email( ) ], ) - assert "override must be of type OverrideConfig" == str(ex.value) providers_list: List[thirdparty.ProviderInput] = [ @@ -310,7 +316,10 @@ async def send_email( @pytest.mark.asyncio async def test_init_validation_session(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be 'VIA_TOKEN', 'VIA_CUSTOM_HEADER' or 'NONE'", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -322,11 +331,10 @@ async def test_init_validation_session(): framework="fastapi", recipe_list=[session.init(anti_csrf="ABCDE")], # type: ignore ) - assert "anti_csrf must be one of VIA_TOKEN, VIA_CUSTOM_HEADER, NONE or None" == str( - ex.value - ) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, match="Input should be an instance of ErrorHandlers" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -341,11 +349,11 @@ async def test_init_validation_session(): # on invalid type. recipe_list=[session.init(error_handlers="error handlers")], # type: ignore ) - assert "error_handlers must be an instance of ErrorHandlers or None" == str( - ex.value - ) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -357,14 +365,13 @@ async def test_init_validation_session(): framework="fastapi", recipe_list=[session.init(override="override")], # type: ignore ) - assert "override must be an instance of InputOverrideConfig or None" == str( - ex.value - ) @pytest.mark.asyncio async def test_init_validation_thirdparty(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, match="Input should be an instance of SignInAndUpFeature" + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -381,11 +388,11 @@ async def test_init_validation_thirdparty(): thirdparty.init(sign_in_and_up_feature="sign in up") # type: ignore ], ) - assert "sign_in_and_up_feature must be an instance of SignInAndUpFeature" == str( - ex.value - ) - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -404,14 +411,14 @@ async def test_init_validation_thirdparty(): ) ], ) - assert "override must be an instance of InputOverrideConfig or None" == str( - ex.value - ) @pytest.mark.asyncio async def test_init_validation_usermetadata(): - with pytest.raises(ValueError) as ex: + with pytest.raises( + ValidationError, + match="Input should be a valid dictionary or instance of BaseOverrideConfig\\[RecipeInterface, APIInterface\\]", + ): init( supertokens_config=SupertokensConfig(get_new_core_app_url()), app_info=InputAppInfo( @@ -423,6 +430,3 @@ async def test_init_validation_usermetadata(): framework="fastapi", recipe_list=[usermetadata.init(override="override")], # type: ignore ) - assert "override must be an instance of InputOverrideConfig or None" == str( - ex.value - ) diff --git a/tests/jwt/test_get_JWKS.py b/tests/jwt/test_get_JWKS.py index dd0919d84..f091ab938 100644 --- a/tests/jwt/test_get_JWKS.py +++ b/tests/jwt/test_get_JWKS.py @@ -64,7 +64,7 @@ async def test_that_default_getJWKS_api_does_not_work_when_disabled( ), framework="fastapi", recipe_list=[ - jwt.init(override=jwt.OverrideConfig(apis=apis_override_get_JWKS)) + jwt.init(override=jwt.JWTOverrideConfig(apis=apis_override_get_JWKS)) ], ) @@ -96,7 +96,7 @@ async def get_jwks(user_context: Dict[str, Any]): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[jwt.init(override=jwt.OverrideConfig(functions=func_override))], + recipe_list=[jwt.init(override=jwt.JWTOverrideConfig(functions=func_override))], ) response = driver_config_client.get(url="/auth/jwt/jwks.json") diff --git a/tests/jwt/test_override.py b/tests/jwt/test_override.py index 856477f09..19f64c367 100644 --- a/tests/jwt/test_override.py +++ b/tests/jwt/test_override.py @@ -106,7 +106,9 @@ async def create_jwt_( website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[jwt.init(override=jwt.OverrideConfig(functions=custom_functions))], + recipe_list=[ + jwt.init(override=jwt.JWTOverrideConfig(functions=custom_functions)) + ], ) response = driver_config_client.post( @@ -147,7 +149,7 @@ async def get_jwks_get(api_options: APIOptions, user_context: Dict[str, Any]): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[jwt.init(override=jwt.OverrideConfig(apis=custom_api))], + recipe_list=[jwt.init(override=jwt.JWTOverrideConfig(apis=custom_api))], ) response = driver_config_client.get(url="/auth/jwt/jwks.json") diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/plugins/api_implementation.py b/tests/plugins/api_implementation.py new file mode 100644 index 000000000..062f56591 --- /dev/null +++ b/tests/plugins/api_implementation.py @@ -0,0 +1,24 @@ +from abc import abstractmethod +from typing import ( + List, +) + +from supertokens_python.types.recipe import BaseAPIInterface + +from .types import RecipeReturnType + + +class APIInterface(BaseAPIInterface): + @abstractmethod + def sign_in_post(self, message: str, stack: List[str]) -> RecipeReturnType: ... + + +class APIImplementation(APIInterface): + def sign_in_post(self, message: str, stack: List[str]) -> RecipeReturnType: + stack.append("original") + return RecipeReturnType( + type="API", + function="sign_in_post", + stack=stack, + message=message, + ) diff --git a/tests/plugins/config.py b/tests/plugins/config.py new file mode 100644 index 000000000..59387085d --- /dev/null +++ b/tests/plugins/config.py @@ -0,0 +1,49 @@ +from typing import Any, List, Optional + +from pydantic import Field +from supertokens_python.supertokens import ( + AppInfo, +) +from supertokens_python.types.config import ( + BaseConfig, + BaseNormalisedConfig, + BaseNormalisedOverrideConfig, + BaseOverrideConfig, + InterfaceOverride, +) +from supertokens_python.types.utils import UseDefaultIfNone + +from .api_implementation import APIInterface +from .recipe_implementation import RecipeInterface + +PluginTestOverrideConfig = BaseOverrideConfig[RecipeInterface, APIInterface] +NormalisedPluginTestOverrideConfig = BaseNormalisedOverrideConfig[ + RecipeInterface, APIInterface +] + + +class NormalizedPluginTestConfig(BaseNormalisedConfig[RecipeInterface, APIInterface]): + test_property: List[str] + + +class PluginTestConfig(BaseConfig[RecipeInterface, APIInterface]): + test_property: List[str] = Field(default_factory=lambda: ["original"]) + + +class PluginOverrideConfig(BaseOverrideConfig[RecipeInterface, APIInterface]): + config: UseDefaultIfNone[Optional[InterfaceOverride[Any]]] = lambda config: config + + +def validate_and_normalise_user_input( + config: Optional[PluginTestConfig], app_info: AppInfo +) -> NormalizedPluginTestConfig: + if config is None: + config = PluginTestConfig() + + override_config = NormalisedPluginTestOverrideConfig.from_input_config( + config.override + ) + + return NormalizedPluginTestConfig( + override=override_config, test_property=config.test_property + ) diff --git a/tests/plugins/misc.py b/tests/plugins/misc.py new file mode 100644 index 000000000..963327472 --- /dev/null +++ b/tests/plugins/misc.py @@ -0,0 +1,86 @@ +from typing import Any, Dict, Optional + +from supertokens_python.framework.request import BaseRequest +from supertokens_python.framework.response import BaseResponse +from supertokens_python.recipe.session import SessionContainer + + +class DummyRequest(BaseRequest): + def get_path(self) -> str: + return "/auth/plugin1/hello" + + def get_method(self) -> str: + return "get" + + def get_original_url(self) -> Any: + raise NotImplementedError + + def get_query_param(self, key: str, default: Optional[str] = None) -> Any: + raise NotImplementedError + + def get_query_params(self) -> Any: + raise NotImplementedError + + async def json(self) -> Any: + raise NotImplementedError + + async def form_data(self) -> Any: + raise NotImplementedError + + def method(self) -> Any: + return "get" + + def get_cookie(self, key: str) -> Any: + raise NotImplementedError + + def get_header(self, key: str) -> Any: + return None + + def get_session(self) -> Any: + raise NotImplementedError + + def set_session(self, session: SessionContainer) -> Any: + raise NotImplementedError + + def set_session_as_none(self) -> Any: + raise NotImplementedError + + +class DummyResponse(BaseResponse): + def __init__(self, content: Dict[str, Any], status_code: int = 200): + self.content = content + self.status_code = status_code + + def set_cookie( + self, + key: str, + value: str, + expires: int, + path: str = "/", + domain: Optional[str] = None, + secure: bool = False, + httponly: bool = False, + samesite: str = "lax", + ) -> Any: + raise NotImplementedError + + def set_header(self, key: str, value: str) -> None: + raise NotImplementedError + + def get_header(self, key: str) -> Optional[str]: + raise NotImplementedError + + def remove_header(self, key: str) -> None: + raise NotImplementedError + + def set_status_code(self, status_code: int) -> None: + raise NotImplementedError + + def set_json_content(self, content: Dict[str, Any]) -> Any: + raise NotImplementedError + + def set_html_content(self, content: str) -> Any: + raise NotImplementedError + + def redirect(self, url: str) -> Any: + raise NotImplementedError diff --git a/tests/plugins/plugins.py b/tests/plugins/plugins.py new file mode 100644 index 000000000..eaee4c828 --- /dev/null +++ b/tests/plugins/plugins.py @@ -0,0 +1,175 @@ +from typing import Any, List, Optional, Union + +from supertokens_python.constants import VERSION +from supertokens_python.plugins import ( + OverrideMap, + PluginDependenciesOkResponse, + RecipePluginOverride, + SuperTokensPlugin, + SuperTokensPluginDependencies, + SuperTokensPublicPlugin, +) +from supertokens_python.supertokens import SupertokensPublicConfig + +from .api_implementation import APIInterface +from .config import PluginTestConfig +from .recipe import PluginTestRecipe +from .recipe_implementation import RecipeInterface + + +def function_override_factory(identifier: str): + def function_override(original_implementation: RecipeInterface) -> RecipeInterface: + og_sign_in = original_implementation.sign_in + + def new_sign_in(message: str, stack: List[str]): + stack.append(identifier) + return og_sign_in(message, stack) + + original_implementation.sign_in = new_sign_in + return original_implementation + + return function_override + + +def api_override_factory(identifier: str): + def function_override(original_implementation: APIInterface) -> APIInterface: + sign_in_post = original_implementation.sign_in_post + + def new_sign_in_post(message: str, stack: List[str]): + stack.append(identifier) + return sign_in_post(message, stack) + + original_implementation.sign_in_post = new_sign_in_post + return original_implementation + + return function_override + + +def config_override_factory(identifier: str): + def config_override(original_config: PluginTestConfig) -> PluginTestConfig: + original_config.test_property.append(identifier) + return original_config + + return config_override + + +def init_factory(identifier: str): + def init( + config: SupertokensPublicConfig, + all_plugins: List[SuperTokensPublicPlugin], + sdk_version: str, + ): + PluginTestRecipe.init_calls.append(identifier) + + return init + + +def dependency_factory(dependencies: Optional[List[SuperTokensPlugin]]): + if dependencies is None: + dependencies = [] + + def dependency( + config: SupertokensPublicConfig, + plugins_above: List[SuperTokensPublicPlugin], + sdk_version: str, + ): + added_plugin_ids = [plugin.id for plugin in plugins_above] + plugins_to_add = [ + plugin for plugin in dependencies if plugin.id not in added_plugin_ids + ] + return PluginDependenciesOkResponse(plugins_to_add=plugins_to_add) + + return dependency + + +def plugin_factory( + identifier: str, + override_functions: bool = False, + override_apis: bool = False, + override_config: bool = False, + deps: Optional[List[SuperTokensPlugin]] = None, + add_init: bool = False, + compatible_sdk_versions: Optional[Union[str, List[str]]] = None, +): + override_map_obj: OverrideMap = {PluginTestRecipe.recipe_id: RecipePluginOverride()} + + if override_functions: + override_map_obj[ + PluginTestRecipe.recipe_id + ].functions = function_override_factory(identifier) # type: ignore + if override_apis: + override_map_obj[PluginTestRecipe.recipe_id].apis = api_override_factory( + identifier + ) # type: ignore + + if override_config: + override_map_obj[PluginTestRecipe.recipe_id].config = config_override_factory( + identifier + ) + + init_fn = None + if add_init: + init_fn = init_factory(identifier) + + if compatible_sdk_versions is None: + sdk_versions = f"=={VERSION}" + else: + sdk_versions = compatible_sdk_versions + + class Plugin(SuperTokensPlugin): + id: str = identifier + compatible_sdk_versions: Union[str, List[str]] = sdk_versions + override_map: Optional[OverrideMap] = override_map_obj + init: Any = init_fn + dependencies: Optional[SuperTokensPluginDependencies] = dependency_factory(deps) + + return Plugin() + + +Plugin1 = plugin_factory( + "plugin1", + override_functions=True, + override_config=True, + add_init=True, +) +Plugin2 = plugin_factory( + "plugin2", + override_functions=True, + override_config=True, + add_init=True, +) +Plugin3Dep1 = plugin_factory( + "plugin3dep1", + override_functions=True, + override_config=True, + deps=[Plugin1], + add_init=True, +) +Plugin3Dep2_1 = plugin_factory( + "plugin3dep2_1", + override_functions=True, + override_config=True, + deps=[Plugin2, Plugin1], + add_init=True, +) +Plugin4Dep1 = plugin_factory( + "plugin4dep1", + override_functions=True, + override_config=True, + deps=[Plugin1], + add_init=True, +) +Plugin4Dep2 = plugin_factory( + "plugin4dep2", + override_functions=True, + override_config=True, + deps=[Plugin2], + add_init=True, +) +Plugin4Dep3__2_1 = plugin_factory( + "plugin4dep3__2_1", + override_functions=True, + override_config=True, + deps=[Plugin3Dep2_1], + add_init=True, +) diff --git a/tests/plugins/recipe.py b/tests/plugins/recipe.py new file mode 100644 index 000000000..56f23d287 --- /dev/null +++ b/tests/plugins/recipe.py @@ -0,0 +1,129 @@ +from typing import ( + List, + Optional, +) + +from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from supertokens_python.framework.request import BaseRequest +from supertokens_python.framework.response import BaseResponse +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.plugins import OverrideMap, apply_plugins +from supertokens_python.querier import Querier +from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.supertokens import ( + AppInfo, +) +from supertokens_python.types.base import UserContext + +from .api_implementation import APIImplementation +from .config import ( + NormalizedPluginTestConfig, + PluginTestConfig, + validate_and_normalise_user_input, +) +from .recipe_implementation import RecipeImplementation + + +class PluginTestRecipe(RecipeModule): + __instance: Optional["PluginTestRecipe"] = None + init_calls: List[str] = [] + recipe_id = "plugin_test" + + config: NormalizedPluginTestConfig + recipe_implementation: RecipeImplementation + api_implementation: APIImplementation + + def __init__( + self, recipe_id: str, app_info: AppInfo, config: Optional[PluginTestConfig] + ): + super().__init__(recipe_id=recipe_id, app_info=app_info) + self.config = validate_and_normalise_user_input( + app_info=app_info, config=config + ) + + querier = Querier.get_instance(rid_to_core=recipe_id) + recipe_implementation = RecipeImplementation( + querier=querier, + config=self.config, + ) + self.recipe_implementation = self.config.override.functions( + recipe_implementation + ) # type: ignore + + api_implementation = APIImplementation() + self.api_implementation = self.config.override.apis(api_implementation) # type: ignore + + @staticmethod + def get_instance() -> "PluginTestRecipe": + if PluginTestRecipe.__instance is not None: + return PluginTestRecipe.__instance + raise_general_exception( + "Initialisation not done. Did you forget to call the SuperTokens.init function?" + ) + + @staticmethod + def get_instance_optional() -> Optional["PluginTestRecipe"]: + return PluginTestRecipe.__instance + + @staticmethod + def init(config: Optional[PluginTestConfig]): + if config is None: + config = PluginTestConfig() + + def func(app_info: AppInfo, plugins: List[OverrideMap]): + if PluginTestRecipe.__instance is None: + PluginTestRecipe.__instance = PluginTestRecipe( + recipe_id=PluginTestRecipe.recipe_id, + app_info=app_info, + config=apply_plugins( + recipe_id=PluginTestRecipe.recipe_id, + config=config, # type: ignore + plugins=plugins, + ), + ) + return PluginTestRecipe.__instance + else: + raise_general_exception( + "PluginTestRecipe has already been initialised. Please check your code for bugs." + ) + + return func + + @staticmethod + def reset(): + PluginTestRecipe.__instance = None + PluginTestRecipe.init_calls = [] + + def get_all_cors_headers(self) -> List[str]: + return [] + + async def handle_error( + self, + request: BaseRequest, + err: SuperTokensError, + response: BaseResponse, + user_context: UserContext, + ): + raise err + + async def handle_api_request( + self, + request_id: str, + tenant_id: str, + request: BaseRequest, + path: NormalisedURLPath, + method: str, + response: BaseResponse, + user_context: UserContext, + ): + return None + + def get_apis_handled(self) -> List[APIHandled]: + return [] + + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: + return False + + +def plugin_test_init(config: Optional[PluginTestConfig] = None): + return PluginTestRecipe.init(config=config) diff --git a/tests/plugins/recipe_implementation.py b/tests/plugins/recipe_implementation.py new file mode 100644 index 000000000..f43ed1a0f --- /dev/null +++ b/tests/plugins/recipe_implementation.py @@ -0,0 +1,38 @@ +from abc import abstractmethod +from typing import ( + TYPE_CHECKING, + List, +) + +from supertokens_python.querier import Querier +from supertokens_python.types.recipe import BaseRecipeInterface + +from .types import RecipeReturnType + +if TYPE_CHECKING: + from .config import NormalizedPluginTestConfig + + +class RecipeInterface(BaseRecipeInterface): + @abstractmethod + def sign_in(self, message: str, stack: List[str]) -> RecipeReturnType: ... + + +class RecipeImplementation(RecipeInterface): + def __init__( + self, + querier: Querier, + config: "NormalizedPluginTestConfig", + ): + super().__init__() + self.querier = querier + self.config = config + + def sign_in(self, message: str, stack: List[str]) -> RecipeReturnType: + stack.append("original") + return RecipeReturnType( + type="Recipe", + function="sign_in", + stack=stack, + message=message, + ) diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py new file mode 100644 index 000000000..b169fb9af --- /dev/null +++ b/tests/plugins/test_plugins.py @@ -0,0 +1,510 @@ +from functools import partial +from typing import Any, Dict, List, Union +from unittest.mock import patch + +from pytest import fixture, mark, param, raises +from supertokens_python import ( + InputAppInfo, + Supertokens, + SupertokensConfig, + SupertokensExperimentalConfig, + init, +) +from supertokens_python.plugins import ( + PluginRouteHandler, + PluginRouteHandlerFunctionErrorResponse, + PluginRouteHandlerFunctionOkResponse, + SuperTokensPlugin, +) +from supertokens_python.post_init_callbacks import PostSTInitCallbacks +from supertokens_python.supertokens import SupertokensPublicConfig + +from tests.utils import outputs, reset + +from .config import PluginTestConfig, PluginTestOverrideConfig +from .misc import DummyRequest, DummyResponse +from .plugins import ( + Plugin1, + Plugin2, + Plugin3Dep1, + Plugin3Dep2_1, + Plugin4Dep1, + Plugin4Dep2, + Plugin4Dep3__2_1, + api_override_factory, + function_override_factory, + plugin_factory, +) +from .recipe import PluginTestRecipe, plugin_test_init +from .types import RecipeReturnType + + +@fixture(autouse=True) +def setup_and_teardown(): + reset() + PluginTestRecipe.reset() + PostSTInitCallbacks.reset() + yield + reset() + PluginTestRecipe.reset() + PostSTInitCallbacks.reset() + + +def recipe_factory(override_functions: bool = False, override_apis: bool = False): + override = PluginTestOverrideConfig() + + if override_functions: + override.functions = function_override_factory("override") + if override_apis: + override.apis = api_override_factory("override") + + return plugin_test_init(config=PluginTestConfig(override=override)) + + +partial_init = partial( + init, + app_info=InputAppInfo( + app_name="plugin_test", + api_domain="api.supertokens.io", + origin="http://localhost:3001", + ), + framework="django", + supertokens_config=SupertokensConfig( + connection_uri="http://localhost:3567", + ), +) + + +@mark.parametrize( + ( + "recipe_fn_override", + "recipe_api_override", + "plugins", + "recipe_expectation", + "api_expectation", + ), + [ + param( + False, + False, + [], + outputs(["original"]), + outputs(["original"]), + id="fn_ovr=False, api_ovr=False, plugins=[]", + ), + param( + True, + False, + [], + outputs(["override", "original"]), + outputs(["original"]), + id="fn_ovr=True, api_ovr=False, plugins=[]", + ), + param( + False, + True, + [], + outputs(["original"]), + outputs(["override", "original"]), + id="fn_ovr=False, api_ovr=True, plugins=[]", + ), + param( + True, + False, + [plugin_factory("plugin1", override_functions=True)], + outputs(["override", "plugin1", "original"]), + outputs(["original"]), + id="fn_ovr=True, api_ovr=False, plugins=[Plugin1], plugin1=[fn]", + ), + param( + True, + False, + [plugin_factory("plugin1", override_apis=True)], + outputs(["override", "original"]), + outputs(["plugin1", "original"]), + id="fn_ovr=True, api_ovr=False, plugins=[Plugin1], plugin1=[api]", + ), + param( + False, + True, + [plugin_factory("plugin1", override_functions=True)], + outputs(["plugin1", "original"]), + outputs(["override", "original"]), + id="fn_ovr=False, api_ovr=True, plugins=[Plugin1], plugin1=[fn]", + ), + param( + False, + True, + [plugin_factory("plugin1", override_apis=True)], + outputs(["original"]), + outputs(["override", "plugin1", "original"]), + id="fn_ovr=False, api_ovr=True, plugins=[Plugin1], plugin1=[api]", + ), + param( + True, + False, + [ + plugin_factory("plugin1", override_functions=True), + plugin_factory("plugin2", override_functions=True), + ], + outputs(["override", "plugin2", "plugin1", "original"]), + outputs(["original"]), + id="fn_ovr=True, api_ovr=False, plugins=[Plugin1, Plugin2], plugin1=[fn], plugin2=[fn]", + ), + param( + False, + True, + [ + plugin_factory("plugin1", override_apis=True), + plugin_factory("plugin2", override_apis=True), + ], + outputs(["original"]), + outputs(["override", "plugin2", "plugin1", "original"]), + id="fn_ovr=True, api_ovr=False, plugins=[Plugin1, Plugin2], plugin1=[api], plugin2=[api]", + ), + param( + True, + True, + [ + plugin_factory("plugin1", override_functions=True, override_apis=True), + plugin_factory("plugin2", override_functions=True, override_apis=True), + ], + outputs(["override", "plugin2", "plugin1", "original"]), + outputs(["override", "plugin2", "plugin1", "original"]), + id="fn_ovr=True, api_ovr=True, plugins=[Plugin1, Plugin2], plugin1=[fn,api], plugin2=[fn,api]", + ), + ], +) +def test_overrides( + recipe_fn_override: bool, + recipe_api_override: bool, + plugins: List[SuperTokensPlugin], + recipe_expectation: Any, + api_expectation: Any, +): + partial_init( + recipe_list=[ + recipe_factory( + override_functions=recipe_fn_override, override_apis=recipe_api_override + ), + ], + experimental=SupertokensExperimentalConfig( + plugins=plugins, + ), + ) + + with recipe_expectation as expected_stack: + output = PluginTestRecipe.get_instance().recipe_implementation.sign_in( + "msg", [] + ) + assert output == RecipeReturnType( + type="Recipe", + function="sign_in", + stack=expected_stack, + message="msg", + ) + + with api_expectation as expected_stack: + output = PluginTestRecipe.get_instance().api_implementation.sign_in_post( + "msg", [] + ) + assert output == RecipeReturnType( + type="API", + function="sign_in_post", + stack=expected_stack, + message="msg", + ) + + +# TODO: Figure out a way to add circular dependencies and test them +@mark.parametrize( + ( + "plugins", + "recipe_expectation", + "api_expectation", + "config_expectation", + "init_expectation", + ), + [ + param( + [Plugin1, Plugin1], + outputs(["plugin1", "original"]), + outputs(["original"]), + outputs(["original", "plugin1"]), + outputs(["plugin1"]), + id="1,1 => 1", + ), + param( + [Plugin1, Plugin2], + outputs(["plugin2", "plugin1", "original"]), + outputs(["original"]), + outputs(["original", "plugin1", "plugin2"]), + outputs(["plugin1", "plugin2"]), + id="1,2 => 2,1", + ), + param( + [Plugin3Dep1], + outputs(["plugin3dep1", "plugin1", "original"]), + outputs(["original"]), + outputs(["original", "plugin1", "plugin3dep1"]), + outputs(["plugin1", "plugin3dep1"]), + id="3->1 => 3,1", + ), + param( + [Plugin3Dep2_1], + outputs(["plugin3dep2_1", "plugin1", "plugin2", "original"]), + outputs(["original"]), + outputs(["original", "plugin2", "plugin1", "plugin3dep2_1"]), + outputs(["plugin2", "plugin1", "plugin3dep2_1"]), + id="3->(2,1) => 3,2,1", + ), + param( + [Plugin3Dep1, Plugin4Dep2], + outputs(["plugin4dep2", "plugin2", "plugin3dep1", "plugin1", "original"]), + outputs(["original"]), + outputs(["original", "plugin1", "plugin3dep1", "plugin2", "plugin4dep2"]), + outputs(["plugin1", "plugin3dep1", "plugin2", "plugin4dep2"]), + id="3->1,4->2 => 4,2,3,1", + ), + param( + [Plugin4Dep3__2_1], + outputs( + ["plugin4dep3__2_1", "plugin3dep2_1", "plugin1", "plugin2", "original"] + ), + outputs(["original"]), + outputs( + ["original", "plugin2", "plugin1", "plugin3dep2_1", "plugin4dep3__2_1"] + ), + outputs(["plugin2", "plugin1", "plugin3dep2_1", "plugin4dep3__2_1"]), + id="4->3->(2,1) => 4,3,1,2", + ), + param( + [Plugin3Dep1, Plugin4Dep1], + outputs(["plugin4dep1", "plugin3dep1", "plugin1", "original"]), + outputs(["original"]), + outputs(["original", "plugin1", "plugin3dep1", "plugin4dep1"]), + outputs(["plugin1", "plugin3dep1", "plugin4dep1"]), + id="3->1,4->1 => 4,3,1", + ), + ], +) +def test_depdendencies_and_init( + plugins: List[SuperTokensPlugin], + recipe_expectation: Any, + api_expectation: Any, + config_expectation: Any, + init_expectation: Any, +): + partial_init( + recipe_list=[ + recipe_factory(), + ], + experimental=SupertokensExperimentalConfig( + plugins=plugins, + ), + ) + + with recipe_expectation as expected_stack: + output = PluginTestRecipe.get_instance().recipe_implementation.sign_in( + "msg", [] + ) + assert output == RecipeReturnType( + type="Recipe", + function="sign_in", + stack=expected_stack, + message="msg", + ) + + with api_expectation as expected_stack: + output = PluginTestRecipe.get_instance().api_implementation.sign_in_post( + "msg", [] + ) + assert output == RecipeReturnType( + type="API", + function="sign_in_post", + stack=expected_stack, + message="msg", + ) + + with config_expectation as expected_stack: + output = PluginTestRecipe.get_instance().config.test_property + assert output == expected_stack + + with init_expectation as expected_stack: + assert PluginTestRecipe.init_calls == expected_stack + + +def test_st_config_override(): + plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) + + def config_override(config: SupertokensPublicConfig) -> SupertokensPublicConfig: + config.supertokens_config = SupertokensConfig( + connection_uri="http://localhost:3567" + ) + return config + + plugin.config = config_override + + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) + + assert ( + Supertokens.get_instance().supertokens_config.connection_uri + == "http://localhost:3567" + ) + + +def test_st_config_override_non_public_property(): + plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) + + def config_override(config: SupertokensPublicConfig) -> SupertokensPublicConfig: + config.recipe_list = [] # type: ignore + return config + + plugin.config = config_override + + with raises( + ValueError, match='"SupertokensPublicConfig" object has no field "recipe_list"' + ): + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) + + +# NOTE: Returning a string here to make it easier to write/test the handler +async def handler_fn(*_, **__: Dict[str, Any]) -> Any: + return "plugin1" + + +plugin_route_handler = PluginRouteHandler( + method="get", + path="/auth/plugin1/hello", + handler=handler_fn, # type: ignore - returns string for simplicity + verify_session_options=None, +) + + +async def test_route_handlers_list(): + plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) + + plugin.route_handlers = [plugin_route_handler] + + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) + + st_instance = Supertokens.get_instance() + + res = await st_instance.middleware( + request=DummyRequest(), + response=DummyResponse(content={}), + user_context={}, + ) + + assert res == "plugin1" + + +@mark.parametrize( + ("handler_response", "expectation"), + [ + param( + PluginRouteHandlerFunctionOkResponse(route_handlers=[plugin_route_handler]), + outputs("plugin1"), + id="OK response with route handler", + ), + param( + PluginRouteHandlerFunctionErrorResponse( + message="error", + ), + raises(Exception, match="error"), + id="Error response", + ), + ], +) +async def test_route_handlers_callable(handler_response: Any, expectation: Any): + plugin = plugin_factory("plugin1", override_functions=False, override_apis=False) + + plugin.route_handlers = lambda *_, **__: handler_response # type: ignore + + with expectation as expected_output: + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) + + st_instance = Supertokens.get_instance() + + res = await st_instance.middleware( + request=DummyRequest(), + response=DummyResponse(content={}), + user_context={}, + ) + + assert res == expected_output + + +@mark.parametrize( + ("sdk_version", "compatible_versions", "expectation"), + [ + param( + "1.5.0", + ">=1.0.0,<2.0.0", + outputs(None), + id="[Valid][1.5.0][>=1.0.0,<2.0.0] as string", + ), + param( + "1.5.0", + [">=1.0.0", "<2.0.0"], + outputs(None), + id="[Valid][1.5.0][>=1.0.0,<2.0.0] as list of strings", + ), + param( + "2.0.0", + [">=1.0.0,<2.0.0"], + raises(Exception, match="Incompatible SDK version for plugin plugin1."), + id="[Invalid][2.0.0][>=1.0.0,<2.0.0]", + ), + ], +) +def test_versions( + sdk_version: str, + compatible_versions: Union[str, List[str]], + expectation: Any, +): + plugin = plugin_factory( + "plugin1", + override_functions=False, + override_apis=False, + compatible_sdk_versions=compatible_versions, + ) + + with patch("supertokens_python.plugins.VERSION", sdk_version): + with expectation as _: + partial_init( + recipe_list=[ + recipe_factory(override_functions=False, override_apis=False), + ], + experimental=SupertokensExperimentalConfig( + plugins=[plugin], + ), + ) diff --git a/tests/plugins/types.py b/tests/plugins/types.py new file mode 100644 index 000000000..f14bcfeb8 --- /dev/null +++ b/tests/plugins/types.py @@ -0,0 +1,10 @@ +from typing import List, Literal + +from supertokens_python.types.response import CamelCaseBaseModel + + +class RecipeReturnType(CamelCaseBaseModel): + type: Literal["Recipe", "API"] + function: str + stack: List[str] + message: str diff --git a/tests/sessions/claims/test_create_new_session.py b/tests/sessions/claims/test_create_new_session.py index 62101fa96..c3efa537b 100644 --- a/tests/sessions/claims/test_create_new_session.py +++ b/tests/sessions/claims/test_create_new_session.py @@ -52,7 +52,7 @@ async def test_should_merge_claims_and_passed_access_token_payload_obj(timestamp url=get_new_core_app_url(), recipe_list=[ session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=session_functions_override_with_claim( TrueClaim, {"user-custom-claim": "foo"} ), diff --git a/tests/sessions/claims/test_verify_session.py b/tests/sessions/claims/test_verify_session.py index cbae9e39f..5bad87a72 100644 --- a/tests/sessions/claims/test_verify_session.py +++ b/tests/sessions/claims/test_verify_session.py @@ -57,7 +57,7 @@ async def new_get_global_claim_validators( session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=session_function_override ), ) @@ -85,7 +85,7 @@ async def new_get_global_claim_validators( session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=session_function_override ), ) diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index 36c63b9ac..1fba680c5 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -21,8 +21,10 @@ def session_functions_override_with_claim( if params is None: params = {} - def session_function_override(oi: RecipeInterface) -> RecipeInterface: - oi_create_new_session = oi.create_new_session + def session_function_override( + original_implementation: RecipeInterface, + ) -> RecipeInterface: + oi_create_new_session = original_implementation.create_new_session async def new_create_new_session( user_id: str, @@ -58,8 +60,8 @@ async def new_create_new_session( user_context, ) - oi.create_new_session = new_create_new_session - return oi + original_implementation.create_new_session = new_create_new_session + return original_implementation return session_function_override @@ -69,7 +71,7 @@ def get_st_init_args(claim: SessionClaim[Any] = TrueClaim): url=get_new_core_app_url(), recipe_list=[ session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=session_functions_override_with_claim(claim), ), ), diff --git a/tests/telemetry/test_telemetry.py b/tests/telemetry/test_telemetry.py index 05d2812f6..ff7472f79 100644 --- a/tests/telemetry/test_telemetry.py +++ b/tests/telemetry/test_telemetry.py @@ -40,7 +40,7 @@ async def test_telemetry(): session.init( anti_csrf="VIA_TOKEN", cookie_domain="supertokens.io", - override=session.InputOverrideConfig(), + override=session.SessionOverrideConfig(), ) ], telemetry=True, @@ -73,7 +73,7 @@ async def test_read_from_env(): session.init( anti_csrf="VIA_TOKEN", cookie_domain="supertokens.io", - override=session.InputOverrideConfig(), + override=session.SessionOverrideConfig(), ) ], ) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index d9f221db4..3d0dcb746 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -17,7 +17,6 @@ from passwordless import add_passwordless_routes # pylint: disable=import-error from session import add_session_routes # pylint: disable=import-error from supertokens_python import ( - AppInfo, InputAppInfo, Supertokens, SupertokensConfig, @@ -67,7 +66,7 @@ ) from supertokens_python.recipe.webauthn.recipe import WebauthnRecipe from supertokens_python.recipe.webauthn.types.config import WebauthnConfig -from supertokens_python.recipe_module import RecipeModule +from supertokens_python.supertokens import RecipeInit from supertokens_python.types import RecipeUserId from test_functions_mapper import ( # pylint: disable=import-error get_func, @@ -252,9 +251,7 @@ def init_st(config: Dict[str, Any]): st_reset() override_logging.reset_override_logs() - recipe_list: List[Callable[[AppInfo], RecipeModule]] = [ - dashboard.init(api_key="test") - ] + recipe_list: List[RecipeInit] = [dashboard.init(api_key="test")] for recipe_config in config.get("recipeList", []): recipe_id = recipe_config.get("recipeId") if recipe_id == "emailpassword": @@ -281,7 +278,7 @@ def init_st(config: Dict[str, Any]): ), ) ), - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=override_builder_with_logging( "EmailPassword.override.apis", recipe_config_json.get("override", {}).get("apis", None), @@ -334,7 +331,7 @@ def get_token_transfer_method( "useDynamicAccessTokenSigningKey" ), get_token_transfer_method=get_token_transfer_method, - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( apis=override_builder_with_logging( "Session.override.apis", recipe_config_json.get("override", {}).get("apis", None), @@ -364,7 +361,7 @@ def get_token_transfer_method( "AccountLinking.onAccountLinked", recipe_config_json.get("onAccountLinked"), ), - override=accountlinking.InputOverrideConfig( + override=accountlinking.AccountLinkingOverrideConfig( functions=override_builder_with_logging( "AccountLinking.override.functions", recipe_config_json.get("override", {}).get( @@ -467,7 +464,7 @@ def get_token_transfer_method( sign_in_and_up_feature=thirdparty.SignInAndUpFeature( providers=providers ), - override=thirdparty.InputOverrideConfig( + override=thirdparty.ThirdPartyOverrideConfig( functions=override_builder_with_logging( "ThirdParty.override.functions", recipe_config_json.get("override", {}).get( @@ -488,7 +485,7 @@ def get_token_transfer_method( UnknownUserIdError, ) from supertokens_python.recipe.emailverification.utils import ( - OverrideConfig as EmailVerificationOverrideConfig, + EmailVerificationOverrideConfig as EmailVerificationOverrideConfig, ) recipe_list.append( @@ -530,7 +527,7 @@ def get_token_transfer_method( recipe_list.append( multifactorauth.init( first_factors=recipe_config_json.get("firstFactors", None), - override=multifactorauth.OverrideConfig( + override=multifactorauth.MultiFactorAuthOverrideConfig( functions=override_builder_with_logging( "MultifactorAuth.override.functions", recipe_config_json.get("override", {}).get( @@ -586,7 +583,7 @@ async def send_sms( ), contact_config=contact_config, flow_type=recipe_config_json.get("flowType"), - override=passwordless.InputOverrideConfig( + override=passwordless.PasswordlessOverrideConfig( apis=override_builder_with_logging( "Passwordless.override.apis", recipe_config_json.get("override", {}).get("apis"), @@ -599,9 +596,7 @@ async def send_sms( ) ) elif recipe_id == "totp": - from supertokens_python.recipe.totp.types import ( - OverrideConfig as TOTPOverrideConfig, - ) + from supertokens_python.recipe.totp.types import TOTPOverrideConfig recipe_config_json = json.loads(recipe_config.get("config", "{}")) recipe_list.append( @@ -627,7 +622,7 @@ async def send_sms( recipe_config_json = json.loads(recipe_config.get("config", "{}")) recipe_list.append( oauth2provider.init( - override=oauth2provider.InputOverrideConfig( + override=oauth2provider.OAuth2ProviderOverrideConfig( apis=override_builder_with_logging( "OAuth2Provider.override.apis", recipe_config_json.get("override", {}).get("apis"), @@ -641,7 +636,7 @@ async def send_sms( ) elif recipe_id == "webauthn": from supertokens_python.recipe.webauthn.types.config import ( - OverrideConfig as WebauthnOverrideConfig, + WebauthnOverrideConfig, ) class WebauthnEmailDeliveryConfig( diff --git a/tests/test-server/utils.py b/tests/test-server/utils.py index f5c959527..7008fef37 100644 --- a/tests/test-server/utils.py +++ b/tests/test-server/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, cast from override_logging import log_override_event # pylint: disable=import-error from supertokens_python.recipe.emailverification import EmailVerificationClaim @@ -35,10 +35,14 @@ def fetch_value( }, ) - ret_val: Any = user_context.get("st-stub-arr-value") or ( - values[0] - if isinstance(values, list) and isinstance(values[0], list) - else values + ret_val: Any = cast( + Any, + user_context.get("st-stub-arr-value") + or ( + values[0] + if isinstance(values, list) and isinstance(values[0], list) + else values + ), ) log_override_event(f"claim-{key}.fetchValue", "RES", ret_val) diff --git a/tests/test_config.py b/tests/test_config.py index 48486fdcf..daf1bbe45 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,6 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import re from unittest.mock import MagicMock import pytest @@ -280,7 +281,7 @@ async def test_same_site_values(st_config: SupertokensConfig): ) test_passed = False except Exception as e: - assert str(e) == 'cookie same site must be one of "strict", "lax", or "none"' + assert re.search("Input should be 'lax', 'strict' or 'none'", str(e)) assert test_passed reset() @@ -299,7 +300,7 @@ async def test_same_site_values(st_config: SupertokensConfig): ) test_passed = False except Exception as e: - assert str(e) == 'cookie same site must be one of "strict", "lax", or "none"' + assert re.search("Input should be 'lax', 'strict' or 'none'", str(e)) assert test_passed reset() diff --git a/tests/test_logger.py b/tests/test_logger.py index 8bf9993fc..78bc3ab32 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -42,11 +42,11 @@ def test_1_json_msg_format(self, datetime_mock: MagicMock): enable_debug_logging() datetime_mock.now.return_value = real_datetime(2000, 1, 1, tzinfo=timezone.utc) - with self.assertLogs(level="DEBUG") as captured: + with self.assertLogs(level="DEBUG") as captured: # type: ignore log_debug_message("API replied with status 200") - record = captured.records[0] - out = json.loads(record.msg) + record = captured.records[0] # type: ignore + out = json.loads(record.msg) # type: ignore assert out == { "t": "2000-01-01T00:00:00+00Z", diff --git a/tests/test_session.py b/tests/test_session.py index 4b5c3e1c7..9a8e91bd7 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -26,7 +26,7 @@ from supertokens_python.framework.fastapi.fastapi_middleware import get_middleware from supertokens_python.process_state import PROCESS_STATE, ProcessState from supertokens_python.recipe import session -from supertokens_python.recipe.session import InputOverrideConfig, SessionRecipe +from supertokens_python.recipe.session import SessionOverrideConfig, SessionRecipe from supertokens_python.recipe.session.asyncio import ( create_new_session as async_create_new_session, ) @@ -381,7 +381,7 @@ async def get_session_information( recipe_list=[ session.init( anti_csrf="VIA_TOKEN", - override=InputOverrideConfig( + override=SessionOverrideConfig( functions=override_session_functions, ), ) @@ -422,7 +422,7 @@ async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig(apis=session_api_override), + override=session.SessionOverrideConfig(apis=session_api_override), ) ], ) @@ -472,7 +472,7 @@ async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig(apis=session_api_override), + override=session.SessionOverrideConfig(apis=session_api_override), ) ], ) @@ -505,15 +505,15 @@ async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): async def test_revoking_session_during_refresh_and_throw_unauthorized( driver_config_client: TestClient, ): - def session_api_override(oi: APIInterface) -> APIInterface: - oi_refresh_post = oi.refresh_post + def session_api_override(original_implementation: APIInterface) -> APIInterface: + oi_refresh_post = original_implementation.refresh_post async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): await oi_refresh_post(api_options, user_context) return raise_unauthorised_exception("unauthorized", clear_tokens=True) - oi.refresh_post = refresh_post - return oi + original_implementation.refresh_post = refresh_post + return original_implementation init_args = get_st_init_args( url=get_new_core_app_url(), @@ -521,7 +521,7 @@ async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig(apis=session_api_override), + override=session.SessionOverrideConfig(apis=session_api_override), ) ], ) @@ -584,7 +584,7 @@ async def refresh_post(api_options: APIOptions, user_context: Dict[str, Any]): session.init( anti_csrf="VIA_TOKEN", get_token_transfer_method=lambda _, __, ___: "cookie", - override=session.InputOverrideConfig(apis=session_api_override), + override=session.SessionOverrideConfig(apis=session_api_override), ) ], ) diff --git a/tests/test_user_context.py b/tests/test_user_context.py index 4ef289756..4c2e87c55 100644 --- a/tests/test_user_context.py +++ b/tests/test_user_context.py @@ -191,13 +191,13 @@ async def create_new_session( framework="fastapi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=apis_override_email_password, functions=functions_override_email_password, ) ), session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=functions_override_session ) ), @@ -319,13 +319,13 @@ async def create_new_session( framework="fastapi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=apis_override_email_password, functions=functions_override_email_password, ) ), session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=functions_override_session ) ), @@ -466,13 +466,13 @@ async def create_new_session( framework="fastapi", recipe_list=[ emailpassword.init( - override=emailpassword.InputOverrideConfig( + override=emailpassword.EmailPasswordOverrideConfig( apis=apis_override_email_password, functions=functions_override_email_password, ) ), session.init( - override=session.InputOverrideConfig( + override=session.SessionOverrideConfig( functions=functions_override_session ) ), diff --git a/tests/usermetadata/test_metadata.py b/tests/usermetadata/test_metadata.py index ded211d6c..a859ee442 100644 --- a/tests/usermetadata/test_metadata.py +++ b/tests/usermetadata/test_metadata.py @@ -27,7 +27,7 @@ ClearUserMetadataResult, RecipeInterface, ) -from supertokens_python.recipe.usermetadata.utils import InputOverrideConfig +from supertokens_python.recipe.usermetadata.utils import UserMetadataOverrideConfig from supertokens_python.utils import is_version_gte from tests.utils import get_new_core_app_url @@ -169,7 +169,9 @@ async def new_get_user_metadata(user_id: str, user_context: Dict[str, Any]): ), framework="fastapi", recipe_list=[ - usermetadata.init(override=InputOverrideConfig(functions=override_func)) + usermetadata.init( + override=UserMetadataOverrideConfig(functions=override_func) + ) ], )