diff --git a/.speakeasy/gen.lock b/.speakeasy/gen.lock index e08b777..d76999a 100755 --- a/.speakeasy/gen.lock +++ b/.speakeasy/gen.lock @@ -3,27 +3,30 @@ id: 463c131b-53b7-4e8c-b4de-cf4f7ff95e4d management: docChecksum: f22d50ed7aa4956c14ae35514eb114c1 docVersion: 1.0.0 - speakeasyVersion: 1.241.0 - generationVersion: 2.300.0 - releaseVersion: 4.3.1 - configChecksum: 8cc801290c7c0e0dc5bc4c0e07273959 + speakeasyVersion: 1.313.1 + generationVersion: 2.347.8 + releaseVersion: 4.4.0 + configChecksum: b52b914035b5c023c16c1f1e5b3c00bd repoURL: https://github.com/speakeasy-sdks/template-sdk.git repoSubDirectory: . installationURL: https://github.com/speakeasy-sdks/template-sdk.git features: python: + additionalDependencies: 0.1.0 callbacks: 1.0.0 - core: 4.6.3 - errors: 2.81.8 + core: 4.6.12 + errors: 2.81.10 examples: 2.81.3 flattening: 2.81.1 globalSecurity: 2.83.5 + globalSecurityCallbacks: 0.1.0 globalServerURLs: 2.82.2 inputOutputModels: 2.83.1 methodSecurity: 2.82.1 methodServerURLs: 2.82.1 responseFormat: 0.1.0 - retries: 2.82.1 + retries: 2.82.2 + sdkHooks: 0.1.0 serverIDs: 2.81.1 webhooks: 1.0.0 generatedFiles: @@ -36,13 +39,13 @@ generatedFiles: - src/speakeasybar/sdk.py - py.typed - pylintrc + - scripts/publish.sh - setup.py - src/speakeasybar/__init__.py - src/speakeasybar/utils/__init__.py - src/speakeasybar/utils/retries.py - src/speakeasybar/utils/utils.py - src/speakeasybar/models/errors/sdkerror.py - - tests/helpers.py - src/speakeasybar/models/operations/login.py - src/speakeasybar/models/operations/getdrink.py - src/speakeasybar/models/operations/listdrinks.py diff --git a/.speakeasy/workflow.lock b/.speakeasy/workflow.lock new file mode 100644 index 0000000..b165fa5 --- /dev/null +++ b/.speakeasy/workflow.lock @@ -0,0 +1,29 @@ +speakeasyVersion: 1.313.1 +sources: + my-source: + sourceNamespace: my-source + sourceRevisionDigest: sha256:b26c23d8cbc993f5ded4f61e07626a7ed8ae3a6ed6b0b1138281306a1ed85a08 + sourceBlobDigest: sha256:639db1978db22dbfd275139e5511de46c975b0ddf833ed650d4559b2d1f220a5 + tags: + - latest + - main +targets: + python-template: + source: my-source + sourceNamespace: my-source + sourceRevisionDigest: sha256:b26c23d8cbc993f5ded4f61e07626a7ed8ae3a6ed6b0b1138281306a1ed85a08 + sourceBlobDigest: sha256:639db1978db22dbfd275139e5511de46c975b0ddf833ed650d4559b2d1f220a5 + outLocation: /github/workspace/repo +workflow: + workflowVersion: 1.0.0 + speakeasyVersion: latest + sources: + my-source: + inputs: + - location: ./openapi.yaml + registry: + location: registry.speakeasyapi.dev/openapi-example/openapi-example/my-source + targets: + python-template: + target: python + source: my-source diff --git a/.speakeasy/workflow.yaml b/.speakeasy/workflow.yaml index d028444..7607b5a 100644 --- a/.speakeasy/workflow.yaml +++ b/.speakeasy/workflow.yaml @@ -1,8 +1,11 @@ workflowVersion: 1.0.0 +speakeasyVersion: latest sources: my-source: inputs: - location: ./openapi.yaml + registry: + location: registry.speakeasyapi.dev/openapi-example/openapi-example/my-source targets: python-template: target: python diff --git a/README.md b/README.md index ca4fcb6..e81bad2 100755 --- a/README.md +++ b/README.md @@ -78,11 +78,10 @@ from speakeasybar.models import operations s = speakeasybar.Speakeasybar() -req = operations.LoginRequestBody( - type=operations.Type.API_KEY, -) -res = s.authentication.login(req, operations.LoginSecurity( +res = s.authentication.login(request=operations.LoginRequestBody( + type=operations.Type.API_KEY, +), security=operations.LoginSecurity( password="", username="", )) @@ -159,11 +158,10 @@ s = speakeasybar.Speakeasybar( ), ) -req = [ - operations.RequestBody(), -] -res = s.config.subscribe_to_webhooks(req) +res = s.config.subscribe_to_webhooks(request=[ + operations.RequestBody(), +]) if res is not None: # handle response @@ -216,11 +214,10 @@ s = speakeasybar.Speakeasybar( ), ) -req = [ - operations.RequestBody(), -] -res = s.config.subscribe_to_webhooks(req, +res = s.config.subscribe_to_webhooks(request=[ + operations.RequestBody(), +], RetryConfig('backoff', BackoffStrategy(1, 50, 1.1, 100), False)) if res is not None: @@ -236,17 +233,16 @@ from speakeasybar.models import operations, shared from speakeasybar.utils import BackoffStrategy, RetryConfig s = speakeasybar.Speakeasybar( - retry_config=RetryConfig('backoff', BackoffStrategy(1, 50, 1.1, 100), False) + retry_config=RetryConfig('backoff', BackoffStrategy(1, 50, 1.1, 100), False), security=shared.Security( api_key="", ), ) -req = [ - operations.RequestBody(), -] -res = s.config.subscribe_to_webhooks(req) +res = s.config.subscribe_to_webhooks(request=[ + operations.RequestBody(), +]) if res is not None: # handle response @@ -280,13 +276,12 @@ s = speakeasybar.Speakeasybar( ), ) -req = [ - operations.RequestBody(), -] - res = None try: - res = s.config.subscribe_to_webhooks(req) + res = s.config.subscribe_to_webhooks(request=[ + operations.RequestBody(), +]) + except errors.BadRequest as e: # handle exception raise(e) @@ -388,7 +383,7 @@ s = speakeasybar.Speakeasybar( ) -res = s.drinks.list_drinks(server_url="https://speakeasy.bar", drink_type=shared.DrinkType.SPIRIT) +res = s.drinks.list_drinks(drink_type=shared.DrinkType.SPIRIT, server_url="https://speakeasy.bar") if res.classes is not None: # handle response @@ -460,11 +455,10 @@ from speakeasybar.models import operations s = speakeasybar.Speakeasybar() -req = operations.LoginRequestBody( - type=operations.Type.API_KEY, -) -res = s.authentication.login(req, operations.LoginSecurity( +res = s.authentication.login(request=operations.LoginRequestBody( + type=operations.Type.API_KEY, +), security=operations.LoginSecurity( password="", username="", )) diff --git a/RELEASES.md b/RELEASES.md index 0c3359f..3430f74 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -382,4 +382,12 @@ Based on: - OpenAPI Doc 1.0.0 - Speakeasy CLI 1.207.1 (2.280.6) https://github.com/speakeasy-api/speakeasy ### Generated -- [python v4.2.1] . \ No newline at end of file +- [python v4.2.1] . + +## 2024-06-21 09:36:42 +### Changes +Based on: +- OpenAPI Doc +- Speakeasy CLI 1.313.1 (2.347.8) https://github.com/speakeasy-api/speakeasy +### Generated +- [python v4.4.0] . \ No newline at end of file diff --git a/USAGE.md b/USAGE.md index c11d07c..49f55da 100644 --- a/USAGE.md +++ b/USAGE.md @@ -11,11 +11,10 @@ from speakeasybar.models import operations s = speakeasybar.Speakeasybar() -req = operations.LoginRequestBody( - type=operations.Type.API_KEY, -) -res = s.authentication.login(req, operations.LoginSecurity( +res = s.authentication.login(request=operations.LoginRequestBody( + type=operations.Type.API_KEY, +), security=operations.LoginSecurity( password="", username="", )) @@ -92,11 +91,10 @@ s = speakeasybar.Speakeasybar( ), ) -req = [ - operations.RequestBody(), -] -res = s.config.subscribe_to_webhooks(req) +res = s.config.subscribe_to_webhooks(request=[ + operations.RequestBody(), +]) if res is not None: # handle response diff --git a/docs/sdks/authentication/README.md b/docs/sdks/authentication/README.md index 730142e..3064b41 100644 --- a/docs/sdks/authentication/README.md +++ b/docs/sdks/authentication/README.md @@ -21,11 +21,10 @@ from speakeasybar.models import operations s = speakeasybar.Speakeasybar() -req = operations.LoginRequestBody( - type=operations.Type.API_KEY, -) -res = s.authentication.login(req, operations.LoginSecurity( +res = s.authentication.login(request=operations.LoginRequestBody( + type=operations.Type.API_KEY, +), security=operations.LoginSecurity( password="", username="", )) diff --git a/docs/sdks/config/README.md b/docs/sdks/config/README.md index 8dc7da0..5664279 100644 --- a/docs/sdks/config/README.md +++ b/docs/sdks/config/README.md @@ -21,11 +21,10 @@ s = speakeasybar.Speakeasybar( ), ) -req = [ - operations.RequestBody(), -] -res = s.config.subscribe_to_webhooks(req) +res = s.config.subscribe_to_webhooks(request=[ + operations.RequestBody(), +]) if res is not None: # handle response diff --git a/gen.yaml b/gen.yaml index 499eeb8..d1aaa30 100644 --- a/gen.yaml +++ b/gen.yaml @@ -12,7 +12,7 @@ generation: baseServerURL: "" telemetryEnabled: false python: - version: 4.3.1 + version: 4.4.0 additionalDependencies: dependencies: {} extraDependencies: @@ -31,7 +31,9 @@ python: webhooks: models/webhooks inputModelSuffix: input maxMethodParams: 4 + methodArguments: require-security-and-request outputModelSuffix: output packageName: speakeasybar projectUrls: {} responseFormat: envelope + templateVersion: v1 diff --git a/scripts/publish.sh b/scripts/publish.sh new file mode 100755 index 0000000..ed45d8a --- /dev/null +++ b/scripts/publish.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +export TWINE_USERNAME=__token__ +export TWINE_PASSWORD=${PYPI_TOKEN} + +python -m pip install --upgrade pip +pip install setuptools wheel twine +python setup.py sdist bdist_wheel +twine upload dist/* diff --git a/setup.py b/setup.py index f1bc449..fe1a70c 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setuptools.setup( name='speakeasybar', - version='4.3.1', + version='4.4.0', author='Speakeasy', description='Python Client SDK Generated by Speakeasy', url='https://github.com/speakeasy-sdks/template-sdk.git', diff --git a/src/speakeasybar/_hooks/__init__.py b/src/speakeasybar/_hooks/__init__.py index 5fd985a..b2ab14b 100644 --- a/src/speakeasybar/_hooks/__init__.py +++ b/src/speakeasybar/_hooks/__init__.py @@ -2,4 +2,3 @@ from .sdkhooks import * from .types import * -from .registration import * diff --git a/src/speakeasybar/_hooks/sdkhooks.py b/src/speakeasybar/_hooks/sdkhooks.py index 17750b6..1bd70b2 100644 --- a/src/speakeasybar/_hooks/sdkhooks.py +++ b/src/speakeasybar/_hooks/sdkhooks.py @@ -2,7 +2,6 @@ import requests from .types import SDKInitHook, BeforeRequestContext, BeforeRequestHook, AfterSuccessContext, AfterSuccessHook, AfterErrorContext, AfterErrorHook, Hooks -from .registration import init_hooks from typing import List, Optional, Tuple @@ -12,7 +11,6 @@ def __init__(self): self.before_request_hooks: List[BeforeRequestHook] = [] self.after_success_hooks: List[AfterSuccessHook] = [] self.after_error_hooks: List[AfterErrorHook] = [] - init_hooks(self) def register_sdk_init_hook(self, hook: SDKInitHook) -> None: self.sdk_init_hooks.append(hook) diff --git a/src/speakeasybar/authentication.py b/src/speakeasybar/authentication.py index ffb4898..376713a 100644 --- a/src/speakeasybar/authentication.py +++ b/src/speakeasybar/authentication.py @@ -57,6 +57,7 @@ def login(self, request: operations.LoginRequestBody, security: operations.Login res = operations.LoginResponse(status_code=http_res.status_code, content_type=http_res.headers.get('Content-Type') or '', raw_response=http_res) if http_res.status_code == 200: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[operations.LoginResponseBody]) res.object = out @@ -66,6 +67,7 @@ def login(self, request: operations.LoginRequestBody, security: operations.Login elif http_res.status_code == 401 or http_res.status_code >= 400 and http_res.status_code < 500: raise errors.SDKError('API error occurred', http_res.status_code, http_res.text, http_res) elif http_res.status_code >= 500 and http_res.status_code < 600: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, errors.APIError) raise out @@ -73,6 +75,7 @@ def login(self, request: operations.LoginRequestBody, security: operations.Login content_type = http_res.headers.get('Content-Type') raise errors.SDKError(f'unknown content-type received: {content_type}', http_res.status_code, http_res.text, http_res) else: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[shared.Error]) res.error = out diff --git a/src/speakeasybar/config.py b/src/speakeasybar/config.py index e4cd430..f4592e3 100644 --- a/src/speakeasybar/config.py +++ b/src/speakeasybar/config.py @@ -44,7 +44,7 @@ def subscribe_to_webhooks(self, request: List[operations.RequestBody], retries: if global_retry_config: retry_config = global_retry_config else: - retry_config = utils.RetryConfig('backoff', utils.BackoffStrategy(10, 200, 1.5, 1000), True) + retry_config = utils.RetryConfig('backoff', utils.BackoffStrategy(10, 200, 1.5, 1000), False) req = None def do_request(): @@ -54,9 +54,10 @@ def do_request(): req = self.sdk_configuration.get_hooks().before_request(BeforeRequestContext(hook_ctx), req) http_res = client.send(req) except Exception as e: - _, e = self.sdk_configuration.get_hooks().after_error(AfterErrorContext(hook_ctx), None, e) - if e is not None: - raise e + _, err = self.sdk_configuration.get_hooks().after_error(AfterErrorContext(hook_ctx), None, e) + if err is not None: + raise err from e + raise e if utils.match_status_codes(['400','4XX','5XX'], http_res.status_code): result, e = self.sdk_configuration.get_hooks().after_error(AfterErrorContext(hook_ctx), http_res, None) @@ -64,6 +65,8 @@ def do_request(): raise e if result is not None: http_res = result + else: + raise errors.SDKError('Unexpected error occurred', -1, '', None) else: http_res = self.sdk_configuration.get_hooks().after_success(AfterSuccessContext(hook_ctx), http_res) @@ -79,6 +82,7 @@ def do_request(): if http_res.status_code == 200: pass elif http_res.status_code == 400: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, errors.BadRequest) raise out @@ -88,6 +92,7 @@ def do_request(): elif http_res.status_code >= 400 and http_res.status_code < 500: raise errors.SDKError('API error occurred', http_res.status_code, http_res.text, http_res) elif http_res.status_code >= 500 and http_res.status_code < 600: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, errors.APIError) raise out @@ -95,6 +100,7 @@ def do_request(): content_type = http_res.headers.get('Content-Type') raise errors.SDKError(f'unknown content-type received: {content_type}', http_res.status_code, http_res.text, http_res) else: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[shared.Error]) res.error = out diff --git a/src/speakeasybar/drinks.py b/src/speakeasybar/drinks.py index 835288b..4108445 100644 --- a/src/speakeasybar/drinks.py +++ b/src/speakeasybar/drinks.py @@ -27,7 +27,7 @@ def get_drink(self, name: str) -> operations.GetDrinkResponse: base_url = utils.template_url(*self.sdk_configuration.get_server_details()) - url = utils.generate_url(operations.GetDrinkRequest, base_url, '/drink/{name}', request) + url = utils.generate_url(base_url, '/drink/{name}', request) if callable(self.sdk_configuration.security): headers, query_params = utils.get_security(self.sdk_configuration.security()) @@ -61,6 +61,7 @@ def get_drink(self, name: str) -> operations.GetDrinkResponse: res = operations.GetDrinkResponse(status_code=http_res.status_code, content_type=http_res.headers.get('Content-Type') or '', raw_response=http_res) if http_res.status_code == 200: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[shared.Drink]) res.drink = out @@ -70,6 +71,7 @@ def get_drink(self, name: str) -> operations.GetDrinkResponse: elif http_res.status_code >= 400 and http_res.status_code < 500: raise errors.SDKError('API error occurred', http_res.status_code, http_res.text, http_res) elif http_res.status_code >= 500 and http_res.status_code < 600: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, errors.APIError) raise out @@ -77,6 +79,7 @@ def get_drink(self, name: str) -> operations.GetDrinkResponse: content_type = http_res.headers.get('Content-Type') raise errors.SDKError(f'unknown content-type received: {content_type}', http_res.status_code, http_res.text, http_res) else: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[shared.Error]) res.error = out @@ -109,7 +112,7 @@ def list_drinks(self, drink_type: Optional[shared.DrinkType] = None, server_url: else: headers, query_params = utils.get_security(self.sdk_configuration.security) - query_params = { **utils.get_query_params(operations.ListDrinksRequest, request), **query_params } + query_params = { **utils.get_query_params(request), **query_params } headers['Accept'] = 'application/json' headers['user-agent'] = self.sdk_configuration.user_agent client = self.sdk_configuration.client @@ -137,6 +140,7 @@ def list_drinks(self, drink_type: Optional[shared.DrinkType] = None, server_url: res = operations.ListDrinksResponse(status_code=http_res.status_code, content_type=http_res.headers.get('Content-Type') or '', raw_response=http_res) if http_res.status_code == 200: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[List[shared.Drink]]) res.classes = out @@ -146,6 +150,7 @@ def list_drinks(self, drink_type: Optional[shared.DrinkType] = None, server_url: elif http_res.status_code >= 400 and http_res.status_code < 500: raise errors.SDKError('API error occurred', http_res.status_code, http_res.text, http_res) elif http_res.status_code >= 500 and http_res.status_code < 600: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, errors.APIError) raise out @@ -153,6 +158,7 @@ def list_drinks(self, drink_type: Optional[shared.DrinkType] = None, server_url: content_type = http_res.headers.get('Content-Type') raise errors.SDKError(f'unknown content-type received: {content_type}', http_res.status_code, http_res.text, http_res) else: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[shared.Error]) res.error = out diff --git a/src/speakeasybar/ingredients.py b/src/speakeasybar/ingredients.py index 54bab50..9ff019b 100644 --- a/src/speakeasybar/ingredients.py +++ b/src/speakeasybar/ingredients.py @@ -34,7 +34,7 @@ def list_ingredients(self, ingredients: Optional[List[str]] = None) -> operation else: headers, query_params = utils.get_security(self.sdk_configuration.security) - query_params = { **utils.get_query_params(operations.ListIngredientsRequest, request), **query_params } + query_params = { **utils.get_query_params(request), **query_params } headers['Accept'] = 'application/json' headers['user-agent'] = self.sdk_configuration.user_agent client = self.sdk_configuration.client @@ -62,6 +62,7 @@ def list_ingredients(self, ingredients: Optional[List[str]] = None) -> operation res = operations.ListIngredientsResponse(status_code=http_res.status_code, content_type=http_res.headers.get('Content-Type') or '', raw_response=http_res) if http_res.status_code == 200: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[List[shared.Ingredient]]) res.classes = out @@ -71,6 +72,7 @@ def list_ingredients(self, ingredients: Optional[List[str]] = None) -> operation elif http_res.status_code >= 400 and http_res.status_code < 500: raise errors.SDKError('API error occurred', http_res.status_code, http_res.text, http_res) elif http_res.status_code >= 500 and http_res.status_code < 600: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, errors.APIError) raise out @@ -78,6 +80,7 @@ def list_ingredients(self, ingredients: Optional[List[str]] = None) -> operation content_type = http_res.headers.get('Content-Type') raise errors.SDKError(f'unknown content-type received: {content_type}', http_res.status_code, http_res.text, http_res) else: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[shared.Error]) res.error = out diff --git a/src/speakeasybar/models/operations/login.py b/src/speakeasybar/models/operations/login.py index 6f0b1bb..6222859 100644 --- a/src/speakeasybar/models/operations/login.py +++ b/src/speakeasybar/models/operations/login.py @@ -17,6 +17,7 @@ class LoginSecurity: + class Type(str, Enum): API_KEY = 'apiKey' JWT = 'JWT' diff --git a/src/speakeasybar/models/operations/subscribetowebhooks.py b/src/speakeasybar/models/operations/subscribetowebhooks.py index 681be62..13240b7 100644 --- a/src/speakeasybar/models/operations/subscribetowebhooks.py +++ b/src/speakeasybar/models/operations/subscribetowebhooks.py @@ -9,6 +9,7 @@ from speakeasybar import utils from typing import Optional + class Webhook(str, Enum): STOCK_UPDATE = 'stockUpdate' diff --git a/src/speakeasybar/models/shared/drinktype.py b/src/speakeasybar/models/shared/drinktype.py index dc4eb26..328d01e 100644 --- a/src/speakeasybar/models/shared/drinktype.py +++ b/src/speakeasybar/models/shared/drinktype.py @@ -3,6 +3,7 @@ from __future__ import annotations from enum import Enum + class DrinkType(str, Enum): r"""The type of drink.""" COCKTAIL = 'cocktail' diff --git a/src/speakeasybar/models/shared/ingredienttype.py b/src/speakeasybar/models/shared/ingredienttype.py index 35a56e3..a3262b2 100644 --- a/src/speakeasybar/models/shared/ingredienttype.py +++ b/src/speakeasybar/models/shared/ingredienttype.py @@ -3,6 +3,7 @@ from __future__ import annotations from enum import Enum + class IngredientType(str, Enum): r"""The type of ingredient.""" FRESH = 'fresh' diff --git a/src/speakeasybar/models/shared/order.py b/src/speakeasybar/models/shared/order.py index 5e5e831..acbde3e 100644 --- a/src/speakeasybar/models/shared/order.py +++ b/src/speakeasybar/models/shared/order.py @@ -7,6 +7,7 @@ from enum import Enum from speakeasybar import utils + class Status(str, Enum): r"""The status of the order.""" PENDING = 'pending' diff --git a/src/speakeasybar/models/shared/ordertype.py b/src/speakeasybar/models/shared/ordertype.py index 8861ba2..5796c48 100644 --- a/src/speakeasybar/models/shared/ordertype.py +++ b/src/speakeasybar/models/shared/ordertype.py @@ -3,6 +3,7 @@ from __future__ import annotations from enum import Enum + class OrderType(str, Enum): r"""The type of order.""" DRINK = 'drink' diff --git a/src/speakeasybar/orders.py b/src/speakeasybar/orders.py index c37b179..9b843a8 100644 --- a/src/speakeasybar/orders.py +++ b/src/speakeasybar/orders.py @@ -40,7 +40,7 @@ def create_order(self, request_body: List[shared.OrderInput], callback_url: Opti headers['content-type'] = req_content_type if data is None and form is None: raise Exception('request body is required') - query_params = { **utils.get_query_params(operations.CreateOrderRequest, request), **query_params } + query_params = { **utils.get_query_params(request), **query_params } headers['Accept'] = 'application/json' headers['user-agent'] = self.sdk_configuration.user_agent client = self.sdk_configuration.client @@ -68,6 +68,7 @@ def create_order(self, request_body: List[shared.OrderInput], callback_url: Opti res = operations.CreateOrderResponse(status_code=http_res.status_code, content_type=http_res.headers.get('Content-Type') or '', raw_response=http_res) if http_res.status_code == 200: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[shared.Order]) res.order = out @@ -77,6 +78,7 @@ def create_order(self, request_body: List[shared.OrderInput], callback_url: Opti elif http_res.status_code >= 400 and http_res.status_code < 500: raise errors.SDKError('API error occurred', http_res.status_code, http_res.text, http_res) elif http_res.status_code >= 500 and http_res.status_code < 600: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, errors.APIError) raise out @@ -84,6 +86,7 @@ def create_order(self, request_body: List[shared.OrderInput], callback_url: Opti content_type = http_res.headers.get('Content-Type') raise errors.SDKError(f'unknown content-type received: {content_type}', http_res.status_code, http_res.text, http_res) else: + # pylint: disable=no-else-return if utils.match_content_type(http_res.headers.get('Content-Type') or '', 'application/json'): out = utils.unmarshal_json(http_res.text, Optional[shared.Error]) res.error = out diff --git a/src/speakeasybar/sdk.py b/src/speakeasybar/sdk.py index a0ba66e..504ff7b 100644 --- a/src/speakeasybar/sdk.py +++ b/src/speakeasybar/sdk.py @@ -76,6 +76,7 @@ def __init__(self, 'organization': organization or 'api', }, } + self.sdk_configuration = SDKConfiguration( client, @@ -94,7 +95,7 @@ def __init__(self, self.sdk_configuration.server_url = server_url # pylint: disable=protected-access - self.sdk_configuration._hooks = hooks + self.sdk_configuration.__dict__['_hooks'] = hooks self._init_sdks() diff --git a/src/speakeasybar/sdkconfiguration.py b/src/speakeasybar/sdkconfiguration.py index 86fa8e1..54aa852 100644 --- a/src/speakeasybar/sdkconfiguration.py +++ b/src/speakeasybar/sdkconfiguration.py @@ -25,6 +25,7 @@ """Contains the list of servers available to the SDK""" + class ServerEnvironment(str, Enum): r"""The environment name. Defaults to the production environment.""" PROD = 'prod' @@ -41,11 +42,13 @@ class SDKConfiguration: server_defaults: Dict[str, Dict[str, str]] = field(default_factory=Dict) language: str = 'python' openapi_doc_version: str = '1.0.0' - sdk_version: str = '4.3.1' - gen_version: str = '2.300.0' - user_agent: str = 'speakeasy-sdk/python 4.3.1 2.300.0 1.0.0 speakeasybar' + sdk_version: str = '4.4.0' + gen_version: str = '2.347.8' + user_agent: str = 'speakeasy-sdk/python 4.4.0 2.347.8 1.0.0 speakeasybar' retry_config: Optional[RetryConfig] = None - _hooks: Optional[SDKHooks] = None + + def __post_init__(self): + self._hooks = SDKHooks() def get_server_details(self) -> Tuple[str, Dict[str, str]]: if self.server_url is not None and self.server_url != '': diff --git a/src/speakeasybar/utils/utils.py b/src/speakeasybar/utils/utils.py index ddd914c..b6a5542 100644 --- a/src/speakeasybar/utils/utils.py +++ b/src/speakeasybar/utils/utils.py @@ -9,8 +9,17 @@ from decimal import Decimal from email.message import Message from enum import Enum -from typing import (Any, Callable, Dict, List, Optional, Tuple, Union, - get_args, get_origin) +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + get_args, + get_origin, +) from xmlrpc.client import boolean from typing_inspect import is_optional_type import dateutil.parser @@ -30,13 +39,13 @@ def get_security(security: Any) -> Tuple[Dict[str, str], Dict[str, str]]: if value is None: continue - metadata = sec_field.metadata.get('security') + metadata = sec_field.metadata.get("security") if metadata is None: continue - if metadata.get('option'): + if metadata.get("option"): _parse_security_option(headers, query_params, value) return headers, query_params - if metadata.get('scheme'): + if metadata.get("scheme"): # Special case for basic auth which could be a flattened struct if metadata.get("sub_type") == "basic" and not is_dataclass(value): _parse_security_scheme(headers, query_params, metadata, security) @@ -46,69 +55,85 @@ def get_security(security: Any) -> Tuple[Dict[str, str], Dict[str, str]]: return headers, query_params -def _parse_security_option(headers: Dict[str, str], query_params: Dict[str, str], option: Any): +def _parse_security_option( + headers: Dict[str, str], query_params: Dict[str, str], option: Any +): opt_fields: Tuple[Field, ...] = fields(option) for opt_field in opt_fields: - metadata = opt_field.metadata.get('security') - if metadata is None or metadata.get('scheme') is None: + metadata = opt_field.metadata.get("security") + if metadata is None or metadata.get("scheme") is None: continue _parse_security_scheme( - headers, query_params, metadata, getattr(option, opt_field.name)) + headers, query_params, metadata, getattr(option, opt_field.name) + ) -def _parse_security_scheme(headers: Dict[str, str], query_params: Dict[str, str], scheme_metadata: Dict, scheme: Any): - scheme_type = scheme_metadata.get('type') - sub_type = scheme_metadata.get('sub_type') +def _parse_security_scheme( + headers: Dict[str, str], + query_params: Dict[str, str], + scheme_metadata: Dict, + scheme: Any, +): + scheme_type = scheme_metadata.get("type") + sub_type = scheme_metadata.get("sub_type") if is_dataclass(scheme): - if scheme_type == 'http' and sub_type == 'basic': + if scheme_type == "http" and sub_type == "basic": _parse_basic_auth_scheme(headers, scheme) return scheme_fields: Tuple[Field, ...] = fields(scheme) for scheme_field in scheme_fields: - metadata = scheme_field.metadata.get('security') - if metadata is None or metadata.get('field_name') is None: + metadata = scheme_field.metadata.get("security") + if metadata is None or metadata.get("field_name") is None: continue value = getattr(scheme, scheme_field.name) _parse_security_scheme_value( - headers, query_params, scheme_metadata, metadata, value) + headers, query_params, scheme_metadata, metadata, value + ) else: _parse_security_scheme_value( - headers, query_params, scheme_metadata, scheme_metadata, scheme) + headers, query_params, scheme_metadata, scheme_metadata, scheme + ) -def _parse_security_scheme_value(headers: Dict[str, str], query_params: Dict[str, str], scheme_metadata: Dict, security_metadata: Dict, value: Any): - scheme_type = scheme_metadata.get('type') - sub_type = scheme_metadata.get('sub_type') +def _parse_security_scheme_value( + headers: Dict[str, str], + query_params: Dict[str, str], + scheme_metadata: Dict, + security_metadata: Dict, + value: Any, +): + scheme_type = scheme_metadata.get("type") + sub_type = scheme_metadata.get("sub_type") - header_name = str(security_metadata.get('field_name')) + header_name = str(security_metadata.get("field_name")) if scheme_type == "apiKey": - if sub_type == 'header': + if sub_type == "header": headers[header_name] = value - elif sub_type == 'query': + elif sub_type == "query": query_params[header_name] = value else: - raise Exception('not supported') + raise Exception("not supported") elif scheme_type == "openIdConnect": headers[header_name] = _apply_bearer(value) - elif scheme_type == 'oauth2': - if sub_type != 'client_credentials': + elif scheme_type == "oauth2": + if sub_type != "client_credentials": headers[header_name] = _apply_bearer(value) - elif scheme_type == 'http': - if sub_type == 'bearer': + elif scheme_type == "http": + if sub_type == "bearer": headers[header_name] = _apply_bearer(value) else: - raise Exception('not supported') + raise Exception("not supported") else: - raise Exception('not supported') + raise Exception("not supported") def _apply_bearer(token: str) -> str: - return token.lower().startswith('bearer ') and token or f'Bearer {token}' + return token.lower().startswith("bearer ") and token or f"Bearer {token}" def _parse_basic_auth_scheme(headers: Dict[str, str], scheme: Any): @@ -117,101 +142,130 @@ def _parse_basic_auth_scheme(headers: Dict[str, str], scheme: Any): scheme_fields: Tuple[Field, ...] = fields(scheme) for scheme_field in scheme_fields: - metadata = scheme_field.metadata.get('security') - if metadata is None or metadata.get('field_name') is None: + metadata = scheme_field.metadata.get("security") + if metadata is None or metadata.get("field_name") is None: continue - field_name = metadata.get('field_name') + field_name = metadata.get("field_name") value = getattr(scheme, scheme_field.name) - if field_name == 'username': + if field_name == "username": username = value - if field_name == 'password': + if field_name == "password": password = value - data = f'{username}:{password}'.encode() - headers['Authorization'] = f'Basic {base64.b64encode(data).decode()}' + data = f"{username}:{password}".encode() + headers["Authorization"] = f"Basic {base64.b64encode(data).decode()}" -def generate_url(clazz: type, server_url: str, path: str, path_params: Any, - gbls: Optional[Dict[str, Dict[str, Dict[str, Any]]]] = None) -> str: - path_param_fields: Tuple[Field, ...] = fields(clazz) +def generate_url( + server_url: str, + path: str, + path_params: Any, + gbls: Optional[Any] = None, +) -> str: + path_param_values: Dict[str, str] = {} + + globals_already_populated = _populate_path_params( + path_params, gbls, path_param_values, [] + ) + if gbls is not None: + _populate_path_params(gbls, None, path_param_values, globals_already_populated) + + for key, value in path_param_values.items(): + path = path.replace("{" + key + "}", value, 1) + + return remove_suffix(server_url, "/") + path + + +def _populate_path_params( + path_params: Any, + gbls: Any, + path_param_values: Dict[str, str], + skip_fields: List[str], +) -> List[str]: + globals_already_populated: List[str] = [] + + path_param_fields: Tuple[Field, ...] = fields(path_params) for field in path_param_fields: - request_metadata = field.metadata.get('request') - if request_metadata is not None: + if field.name in skip_fields: continue - param_metadata = field.metadata.get('path_param') + param_metadata = field.metadata.get("path_param") if param_metadata is None: continue - param = getattr( - path_params, field.name) if path_params is not None else None - param = _populate_from_globals( - field.name, param, 'pathParam', gbls) + param = getattr(path_params, field.name) if path_params is not None else None + param, global_found = _populate_from_globals( + field.name, param, "path_param", gbls + ) + if global_found: + globals_already_populated.append(field.name) if param is None: continue f_name = param_metadata.get("field_name", field.name) - serialization = param_metadata.get('serialization', '') - if serialization != '': + serialization = param_metadata.get("serialization", "") + if serialization != "": serialized_params = _get_serialized_params( - param_metadata, field.type, f_name, param) + param_metadata, field.type, f_name, param + ) for key, value in serialized_params.items(): - path = path.replace( - '{' + key + '}', value, 1) + path_param_values[key] = value else: - if param_metadata.get('style', 'simple') == 'simple': + if param_metadata.get("style", "simple") == "simple": if isinstance(param, List): pp_vals: List[str] = [] for pp_val in param: if pp_val is None: continue pp_vals.append(_val_to_string(pp_val)) - path = path.replace( - '{' + param_metadata.get('field_name', field.name) + '}', ",".join(pp_vals), 1) + path_param_values[param_metadata.get("field_name", field.name)] = ( + ",".join(pp_vals) + ) elif isinstance(param, Dict): pp_vals: List[str] = [] for pp_key in param: if param[pp_key] is None: continue - if param_metadata.get('explode'): - pp_vals.append( - f"{pp_key}={_val_to_string(param[pp_key])}") + if param_metadata.get("explode"): + pp_vals.append(f"{pp_key}={_val_to_string(param[pp_key])}") else: - pp_vals.append( - f"{pp_key},{_val_to_string(param[pp_key])}") - path = path.replace( - '{' + param_metadata.get('field_name', field.name) + '}', ",".join(pp_vals), 1) + pp_vals.append(f"{pp_key},{_val_to_string(param[pp_key])}") + path_param_values[param_metadata.get("field_name", field.name)] = ( + ",".join(pp_vals) + ) elif not isinstance(param, (str, int, float, complex, bool, Decimal)): pp_vals: List[str] = [] param_fields: Tuple[Field, ...] = fields(param) for param_field in param_fields: - param_value_metadata = param_field.metadata.get( - 'path_param') + param_value_metadata = param_field.metadata.get("path_param") if not param_value_metadata: continue - parm_name = param_value_metadata.get( - 'field_name', field.name) + param_name = param_value_metadata.get("field_name", field.name) param_field_val = getattr(param, param_field.name) if param_field_val is None: continue - if param_metadata.get('explode'): + if param_metadata.get("explode"): pp_vals.append( - f"{parm_name}={_val_to_string(param_field_val)}") + f"{param_name}={_val_to_string(param_field_val)}" + ) else: pp_vals.append( - f"{parm_name},{_val_to_string(param_field_val)}") - path = path.replace( - '{' + param_metadata.get('field_name', field.name) + '}', ",".join(pp_vals), 1) + f"{param_name},{_val_to_string(param_field_val)}" + ) + path_param_values[param_metadata.get("field_name", field.name)] = ( + ",".join(pp_vals) + ) else: - path = path.replace( - '{' + param_metadata.get('field_name', field.name) + '}', _val_to_string(param), 1) + path_param_values[param_metadata.get("field_name", field.name)] = ( + _val_to_string(param) + ) - return remove_suffix(server_url, '/') + path + return globals_already_populated def is_optional(field): @@ -220,100 +274,145 @@ def is_optional(field): def template_url(url_with_params: str, params: Dict[str, str]) -> str: for key, value in params.items(): - url_with_params = url_with_params.replace( - '{' + key + '}', value) + url_with_params = url_with_params.replace("{" + key + "}", value) return url_with_params -def get_query_params(clazz: type, query_params: Any, gbls: Optional[Dict[str, Dict[str, Dict[str, Any]]]] = None) -> Dict[ - str, List[str]]: +def get_query_params( + query_params: Any, + gbls: Optional[Any] = None, +) -> Dict[str, List[str]]: params: Dict[str, List[str]] = {} - param_fields: Tuple[Field, ...] = fields(clazz) + globals_already_populated = _populate_query_params(query_params, gbls, params, []) + if gbls is not None: + _populate_query_params(gbls, None, params, globals_already_populated) + + return params + + +def _populate_query_params( + query_params: Any, + gbls: Any, + query_param_values: Dict[str, List[str]], + skip_fields: List[str], +) -> List[str]: + globals_already_populated: List[str] = [] + + param_fields: Tuple[Field, ...] = fields(query_params) for field in param_fields: - request_metadata = field.metadata.get('request') - if request_metadata is not None: + if field.name in skip_fields: continue - metadata = field.metadata.get('query_param') + metadata = field.metadata.get("query_param") if not metadata: continue param_name = field.name - value = getattr( - query_params, param_name) if query_params is not None else None + value = getattr(query_params, param_name) if query_params is not None else None - value = _populate_from_globals(param_name, value, 'queryParam', gbls) + value, global_found = _populate_from_globals( + param_name, value, "query_param", gbls + ) + if global_found: + globals_already_populated.append(param_name) f_name = metadata.get("field_name") - serialization = metadata.get('serialization', '') - if serialization != '': + serialization = metadata.get("serialization", "") + if serialization != "": serialized_parms = _get_serialized_params( - metadata, field.type, f_name, value) + metadata, field.type, f_name, value + ) for key, value in serialized_parms.items(): - if key in params: - params[key].extend(value) + if key in query_param_values: + query_param_values[key].extend(value) else: - params[key] = [value] + query_param_values[key] = [value] else: - style = metadata.get('style', 'form') - if style == 'deepObject': - params = {**params, **_get_deep_object_query_params( - metadata, f_name, value)} - elif style == 'form': - params = {**params, **_get_delimited_query_params( - metadata, f_name, value, ",")} - elif style == 'pipeDelimited': - params = {**params, **_get_delimited_query_params( - metadata, f_name, value, "|")} + style = metadata.get("style", "form") + if style == "deepObject": + _populate_deep_object_query_params( + metadata, f_name, value, query_param_values + ) + elif style == "form": + _populate_delimited_query_params( + metadata, f_name, value, ",", query_param_values + ) + elif style == "pipeDelimited": + _populate_delimited_query_params( + metadata, f_name, value, "|", query_param_values + ) else: - raise Exception('not yet implemented') - return params + raise Exception("not yet implemented") + return globals_already_populated -def get_headers(headers_params: Any, gbls: Optional[Dict[str, Dict[str, Dict[str, Any]]]] = None) -> Dict[str, str]: - if headers_params is None: - return {} +def get_headers(headers_params: Any, gbls: Optional[Any] = None) -> Dict[str, str]: headers: Dict[str, str] = {} + globals_already_populated = [] + if headers_params is not None: + globals_already_populated = _populate_headers(headers_params, gbls, headers, []) + if gbls is not None: + _populate_headers(gbls, None, headers, globals_already_populated) + + return headers + + +def _populate_headers( + headers_params: Any, + gbls: Any, + header_values: Dict[str, str], + skip_fields: List[str], +) -> List[str]: + globals_already_populated: List[str] = [] + param_fields: Tuple[Field, ...] = fields(headers_params) for field in param_fields: - metadata = field.metadata.get('header') + if field.name in skip_fields: + continue + + metadata = field.metadata.get("header") if not metadata: continue - value = _populate_from_globals(field.name, getattr(headers_params, field.name), 'header', gbls) - value = _serialize_header(metadata.get('explode', False), value) + value, global_found = _populate_from_globals( + field.name, getattr(headers_params, field.name), "header", gbls + ) + if global_found: + globals_already_populated.append(field.name) + value = _serialize_header(metadata.get("explode", False), value) - if value != '': - headers[metadata.get('field_name', field.name)] = value + if value != "": + header_values[metadata.get("field_name", field.name)] = value - return headers + return globals_already_populated -def _get_serialized_params(metadata: Dict, field_type: type, field_name: str, obj: Any) -> Dict[str, str]: +def _get_serialized_params( + metadata: Dict, field_type: type, field_name: str, obj: Any +) -> Dict[str, str]: params: Dict[str, str] = {} - serialization = metadata.get('serialization', '') - if serialization == 'json': - params[metadata.get("field_name", field_name) - ] = marshal_json(obj, field_type) + serialization = metadata.get("serialization", "") + if serialization == "json": + params[metadata.get("field_name", field_name)] = marshal_json(obj, field_type) return params -def _get_deep_object_query_params(metadata: Dict, field_name: str, obj: Any) -> Dict[str, List[str]]: - params: Dict[str, List[str]] = {} - +def _populate_deep_object_query_params( + metadata: Dict, field_name: str, obj: Any, params: Dict[str, List[str]] +): if obj is None: - return params + return if is_dataclass(obj): obj_fields: Tuple[Field, ...] = fields(obj) for obj_field in obj_fields: - obj_param_metadata = obj_field.metadata.get('query_param') + obj_param_metadata = obj_field.metadata.get("query_param") if not obj_param_metadata: continue @@ -326,19 +425,23 @@ def _get_deep_object_query_params(metadata: Dict, field_name: str, obj: Any) -> if val is None: continue - if params.get( - f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]') is None: + if ( + params.get( + f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]' + ) + is None + ): params[ - f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]'] = [ - ] + f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]' + ] = [] params[ - f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]'].append( - _val_to_string(val)) + f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]' + ].append(_val_to_string(val)) else: params[ - f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]'] = [ - _val_to_string(obj_val)] + f'{metadata.get("field_name", field_name)}[{obj_param_metadata.get("field_name", obj_field.name)}]' + ] = [_val_to_string(obj_val)] elif isinstance(obj, Dict): for key, value in obj.items(): if value is None: @@ -349,20 +452,23 @@ def _get_deep_object_query_params(metadata: Dict, field_name: str, obj: Any) -> if val is None: continue - if params.get(f'{metadata.get("field_name", field_name)}[{key}]') is None: - params[f'{metadata.get("field_name", field_name)}[{key}]'] = [ - ] + if ( + params.get(f'{metadata.get("field_name", field_name)}[{key}]') + is None + ): + params[f'{metadata.get("field_name", field_name)}[{key}]'] = [] - params[ - f'{metadata.get("field_name", field_name)}[{key}]'].append(_val_to_string(val)) + params[f'{metadata.get("field_name", field_name)}[{key}]'].append( + _val_to_string(val) + ) else: params[f'{metadata.get("field_name", field_name)}[{key}]'] = [ - _val_to_string(value)] - return params + _val_to_string(value) + ] def _get_query_param_field_name(obj_field: Field) -> str: - obj_param_metadata = obj_field.metadata.get('query_param') + obj_param_metadata = obj_field.metadata.get("query_param") if not obj_param_metadata: return "" @@ -370,29 +476,53 @@ def _get_query_param_field_name(obj_field: Field) -> str: return obj_param_metadata.get("field_name", obj_field.name) -def _get_delimited_query_params(metadata: Dict, field_name: str, obj: Any, delimiter: str) -> Dict[ - str, List[str]]: - return _populate_form(field_name, metadata.get("explode", True), obj, _get_query_param_field_name, delimiter) +def _populate_delimited_query_params( + metadata: Dict, + field_name: str, + obj: Any, + delimiter: str, + query_param_values: Dict[str, List[str]], +): + _populate_form( + field_name, + metadata.get("explode", True), + obj, + _get_query_param_field_name, + delimiter, + query_param_values, + ) SERIALIZATION_METHOD_TO_CONTENT_TYPE = { - 'json': 'application/json', - 'form': 'application/x-www-form-urlencoded', - 'multipart': 'multipart/form-data', - 'raw': 'application/octet-stream', - 'string': 'text/plain', + "json": "application/json", + "form": "application/x-www-form-urlencoded", + "multipart": "multipart/form-data", + "raw": "application/octet-stream", + "string": "text/plain", } -def serialize_request_body(request: Any, request_type: type, request_field_name: str, nullable: bool, optional: bool, serialization_method: str, encoder=None) -> Tuple[ - Optional[str], Optional[Any], Optional[Any]]: +def serialize_request_body( + request: Any, + request_type: type, + request_field_name: str, + nullable: bool, + optional: bool, + serialization_method: str, + encoder=None, +) -> Tuple[Optional[str], Optional[Any], Optional[Any]]: if request is None: if not nullable and optional: return None, None, None if not is_dataclass(request) or not hasattr(request, request_field_name): - return serialize_content_type(request_field_name, request_type, SERIALIZATION_METHOD_TO_CONTENT_TYPE[serialization_method], - request, encoder) + return serialize_content_type( + request_field_name, + request_type, + SERIALIZATION_METHOD_TO_CONTENT_TYPE[serialization_method], + request, + encoder, + ) request_val = getattr(request, request_field_name) @@ -405,22 +535,28 @@ def serialize_request_body(request: Any, request_type: type, request_field_name: for field in request_fields: if field.name == request_field_name: - request_metadata = field.metadata.get('request') + request_metadata = field.metadata.get("request") break if request_metadata is None: - raise Exception('invalid request type') + raise Exception("invalid request type") - return serialize_content_type(request_field_name, request_type, request_metadata.get('media_type', 'application/octet-stream'), - request_val) + return serialize_content_type( + request_field_name, + request_type, + request_metadata.get("media_type", "application/octet-stream"), + request_val, + ) -def serialize_content_type(field_name: str, request_type: Any, media_type: str, request: Any, encoder=None) -> Tuple[Optional[str], Optional[Any], Optional[List[List[Any]]]]: - if re.match(r'(application|text)\/.*?\+*json.*', media_type) is not None: +def serialize_content_type( + field_name: str, request_type: Any, media_type: str, request: Any, encoder=None +) -> Tuple[Optional[str], Optional[Any], Optional[List[List[Any]]]]: + if re.match(r"(application|text)\/.*?\+*json.*", media_type) is not None: return media_type, marshal_json(request, request_type, encoder), None - if re.match(r'multipart\/.*', media_type) is not None: + if re.match(r"multipart\/.*", media_type) is not None: return serialize_multipart_form(media_type, request) - if re.match(r'application\/x-www-form-urlencoded.*', media_type) is not None: + if re.match(r"application\/x-www-form-urlencoded.*", media_type) is not None: return media_type, serialize_form_data(field_name, request), None if isinstance(request, (bytes, bytearray)): return media_type, request, None @@ -428,10 +564,13 @@ def serialize_content_type(field_name: str, request_type: Any, media_type: str, return media_type, request, None raise Exception( - f"invalid request body type {type(request)} for mediaType {media_type}") + f"invalid request body type {type(request)} for mediaType {media_type}" + ) -def serialize_multipart_form(media_type: str, request: Any) -> Tuple[str, Any, List[List[Any]]]: +def serialize_multipart_form( + media_type: str, request: Any +) -> Tuple[str, Any, List[List[Any]]]: form: List[List[Any]] = [] request_fields = fields(request) @@ -440,7 +579,7 @@ def serialize_multipart_form(media_type: str, request: Any) -> Tuple[str, Any, L if val is None: continue - field_metadata = field.metadata.get('multipart_form') + field_metadata = field.metadata.get("multipart_form") if not field_metadata: continue @@ -452,40 +591,40 @@ def serialize_multipart_form(media_type: str, request: Any) -> Tuple[str, Any, L content = bytes() for file_field in file_fields: - file_metadata = file_field.metadata.get('multipart_form') + file_metadata = file_field.metadata.get("multipart_form") if file_metadata is None: continue if file_metadata.get("content") is True: content = getattr(val, file_field.name) else: - field_name = file_metadata.get( - "field_name", file_field.name) + field_name = file_metadata.get("field_name", file_field.name) file_name = getattr(val, file_field.name) if field_name == "" or file_name == "" or content == bytes(): - raise Exception('invalid multipart/form-data file') + raise Exception("invalid multipart/form-data file") form.append([field_name, [file_name, content]]) elif field_metadata.get("json") is True: - to_append = [field_metadata.get("field_name", field.name), [ - None, marshal_json(val, field.type), "application/json"]] + to_append = [ + field_metadata.get("field_name", field.name), + [None, marshal_json(val, field.type), "application/json"], + ] form.append(to_append) else: - field_name = field_metadata.get( - "field_name", field.name) + field_name = field_metadata.get("field_name", field.name) if isinstance(val, List): for value in val: if value is None: continue - form.append( - [field_name + "[]", [None, _val_to_string(value)]]) + form.append([field_name + "[]", [None, _val_to_string(value)]]) else: form.append([field_name, [None, _val_to_string(val)]]) return media_type, None, form -def serialize_dict(original: Dict, explode: bool, field_name, existing: Optional[Dict[str, List[str]]]) -> Dict[ - str, List[str]]: +def serialize_dict( + original: Dict, explode: bool, field_name, existing: Optional[Dict[str, List[str]]] +) -> Dict[str, List[str]]: if existing is None: existing = {} @@ -514,32 +653,37 @@ def serialize_form_data(field_name: str, data: Any) -> Dict[str, Any]: if val is None: continue - metadata = field.metadata.get('form') + metadata = field.metadata.get("form") if metadata is None: continue - field_name = metadata.get('field_name', field.name) + field_name = metadata.get("field_name", field.name) - if metadata.get('json'): + if metadata.get("json"): form[field_name] = [marshal_json(val, field.type)] else: - if metadata.get('style', 'form') == 'form': - form = {**form, **_populate_form( - field_name, metadata.get('explode', True), val, _get_form_field_name, ",")} + if metadata.get("style", "form") == "form": + _populate_form( + field_name, + metadata.get("explode", True), + val, + _get_form_field_name, + ",", + form, + ) else: - raise Exception( - f'Invalid form style for field {field.name}') + raise Exception(f"Invalid form style for field {field.name}") elif isinstance(data, Dict): for key, value in data.items(): form[key] = [_val_to_string(value)] else: - raise Exception(f'Invalid request body type for field {field_name}') + raise Exception(f"Invalid request body type for field {field_name}") return form def _get_form_field_name(obj_field: Field) -> str: - obj_param_metadata = obj_field.metadata.get('form') + obj_param_metadata = obj_field.metadata.get("form") if not obj_param_metadata: return "" @@ -547,12 +691,16 @@ def _get_form_field_name(obj_field: Field) -> str: return obj_param_metadata.get("field_name", obj_field.name) -def _populate_form(field_name: str, explode: boolean, obj: Any, get_field_name_func: Callable, delimiter: str) -> \ - Dict[str, List[str]]: - params: Dict[str, List[str]] = {} - +def _populate_form( + field_name: str, + explode: boolean, + obj: Any, + get_field_name_func: Callable, + delimiter: str, + form: Dict[str, List[str]], +): if obj is None: - return params + return form if is_dataclass(obj): items = [] @@ -560,7 +708,7 @@ def _populate_form(field_name: str, explode: boolean, obj: Any, get_field_name_f obj_fields: Tuple[Field, ...] = fields(obj) for obj_field in obj_fields: obj_field_name = get_field_name_func(obj_field) - if obj_field_name == '': + if obj_field_name == "": continue val = getattr(obj, obj_field.name) @@ -568,13 +716,12 @@ def _populate_form(field_name: str, explode: boolean, obj: Any, get_field_name_f continue if explode: - params[obj_field_name] = [_val_to_string(val)] + form[obj_field_name] = [_val_to_string(val)] else: - items.append( - f'{obj_field_name}{delimiter}{_val_to_string(val)}') + items.append(f"{obj_field_name}{delimiter}{_val_to_string(val)}") if len(items) > 0: - params[field_name] = [delimiter.join(items)] + form[field_name] = [delimiter.join(items)] elif isinstance(obj, Dict): items = [] for key, value in obj.items(): @@ -582,12 +729,12 @@ def _populate_form(field_name: str, explode: boolean, obj: Any, get_field_name_f continue if explode: - params[key] = [_val_to_string(value)] + form[key] = [_val_to_string(value)] else: - items.append(f'{key}{delimiter}{_val_to_string(value)}') + items.append(f"{key}{delimiter}{_val_to_string(value)}") if len(items) > 0: - params[field_name] = [delimiter.join(items)] + form[field_name] = [delimiter.join(items)] elif isinstance(obj, List): items = [] @@ -596,37 +743,35 @@ def _populate_form(field_name: str, explode: boolean, obj: Any, get_field_name_f continue if explode: - if not field_name in params: - params[field_name] = [] - params[field_name].append(_val_to_string(value)) + if not field_name in form: + form[field_name] = [] + form[field_name].append(_val_to_string(value)) else: items.append(_val_to_string(value)) if len(items) > 0: - params[field_name] = [delimiter.join( - [str(item) for item in items])] + form[field_name] = [delimiter.join([str(item) for item in items])] else: - params[field_name] = [_val_to_string(obj)] + form[field_name] = [_val_to_string(obj)] - return params + return form def _serialize_header(explode: bool, obj: Any) -> str: if obj is None: - return '' + return "" if is_dataclass(obj): items = [] obj_fields: Tuple[Field, ...] = fields(obj) for obj_field in obj_fields: - obj_param_metadata = obj_field.metadata.get('header') + obj_param_metadata = obj_field.metadata.get("header") if not obj_param_metadata: continue - obj_field_name = obj_param_metadata.get( - 'field_name', obj_field.name) - if obj_field_name == '': + obj_field_name = obj_param_metadata.get("field_name", obj_field.name) + if obj_field_name == "": continue val = getattr(obj, obj_field.name) @@ -634,14 +779,13 @@ def _serialize_header(explode: bool, obj: Any) -> str: continue if explode: - items.append( - f'{obj_field_name}={_val_to_string(val)}') + items.append(f"{obj_field_name}={_val_to_string(val)}") else: items.append(obj_field_name) items.append(_val_to_string(val)) if len(items) > 0: - return ','.join(items) + return ",".join(items) elif isinstance(obj, Dict): items = [] @@ -650,13 +794,13 @@ def _serialize_header(explode: bool, obj: Any) -> str: continue if explode: - items.append(f'{key}={_val_to_string(value)}') + items.append(f"{key}={_val_to_string(value)}") else: items.append(key) items.append(_val_to_string(value)) if len(items) > 0: - return ','.join([str(item) for item in items]) + return ",".join([str(item) for item in items]) elif isinstance(obj, List): items = [] @@ -667,38 +811,36 @@ def _serialize_header(explode: bool, obj: Any) -> str: items.append(_val_to_string(value)) if len(items) > 0: - return ','.join(items) + return ",".join(items) else: - return f'{_val_to_string(obj)}' + return f"{_val_to_string(obj)}" - return '' + return "" -def unmarshal_json(data, typ, decoder=None): - unmarshal = make_dataclass('Unmarshal', [('res', typ)], - bases=(DataClassJsonMixin,)) +def unmarshal_json(data, typ, decoder=None, infer_missing=False): + unmarshal = make_dataclass("Unmarshal", [("res", typ)], bases=(DataClassJsonMixin,)) json_dict = json.loads(data) try: - out = unmarshal.from_dict({"res": json_dict}) + out = unmarshal.from_dict({"res": json_dict}, infer_missing=infer_missing) except AttributeError as attr_err: raise AttributeError( - f'unable to unmarshal {data} as {typ} - {attr_err}') from attr_err + f"unable to unmarshal {data} as {typ} - {attr_err}" + ) from attr_err return out.res if decoder is None else decoder(out.res) def marshal_json(val, typ, encoder=None): if not is_optional_type(typ) and val is None: - raise ValueError( - f"Could not marshal None into non-optional type: {typ}") + raise ValueError(f"Could not marshal None into non-optional type: {typ}") - marshal = make_dataclass('Marshal', [('res', typ)], - bases=(DataClassJsonMixin,)) + marshal = make_dataclass("Marshal", [("res", typ)], bases=(DataClassJsonMixin,)) marshaller = marshal(res=val) json_dict = marshaller.to_dict() val = json_dict["res"] if encoder is None else encoder(json_dict["res"]) - return json.dumps(val, separators=(',', ':'), sort_keys=True) + return json.dumps(val, separators=(",", ":"), sort_keys=True) def match_content_type(content_type: str, pattern: str) -> boolean: @@ -706,7 +848,7 @@ def match_content_type(content_type: str, pattern: str) -> boolean: return True msg = Message() - msg['content-type'] = content_type + msg["content-type"] = content_type media_type = msg.get_content_type() if media_type == pattern: @@ -714,7 +856,7 @@ def match_content_type(content_type: str, pattern: str) -> boolean: parts = media_type.split("/") if len(parts) == 2: - if pattern in (f'{parts[0]}/*', f'*/{parts[1]}'): + if pattern in (f"{parts[0]}/*", f"*/{parts[1]}"): return True return False @@ -766,6 +908,33 @@ def bigintdecoder(val): raise ValueError(f"{val} is a float") return int(val) +def integerstrencoder(optional: bool): + def integerstrencode(val: int): + if optional and val is None: + return None + return str(val) + + return integerstrencode + + +def integerstrdecoder(val): + if isinstance(val, float): + raise ValueError(f"{val} is a float") + return int(val) + + +def numberstrencoder(optional: bool): + def numberstrencode(val: float): + if optional and val is None: + return None + return str(val) + + return numberstrencode + + +def numberstrdecoder(val): + return float(val) + def decimalencoder(optional: bool, as_str: bool): def decimalencode(val: Decimal): @@ -839,6 +1008,7 @@ def selective_encoder(val: Any): if type(val) in all_encoders: return all_encoders[type(val)](val) return val + return selective_encoder @@ -852,6 +1022,7 @@ def selective_decoder(val: Any): except (TypeError, ValueError): continue return decoded + return selective_decoder @@ -866,33 +1037,51 @@ def _val_to_string(val) -> str: if isinstance(val, bool): return str(val).lower() if isinstance(val, datetime): - return str(val.isoformat().replace('+00:00', 'Z')) + return str(val.isoformat().replace("+00:00", "Z")) if isinstance(val, Enum): return str(val.value) return str(val) -def _populate_from_globals(param_name: str, value: Any, param_type: str, gbls: Optional[Dict[str, Dict[str, Dict[str, Any]]]]): - if value is None and gbls is not None: - if 'parameters' in gbls: - if param_type in gbls['parameters']: - if param_name in gbls['parameters'][param_type]: - global_value = gbls['parameters'][param_type][param_name] - if global_value is not None: - value = global_value +def _populate_from_globals( + param_name: str, value: Any, param_type: str, gbls: Any +) -> Tuple[Any, bool]: + if gbls is None: + return value, False + + global_fields = fields(gbls) + + found = False + for field in global_fields: + if field.name is not param_name: + continue - return value + found = True + + if value is not None: + return value, True + + global_value = getattr(gbls, field.name) + + param_metadata = field.metadata.get(param_type) + if param_metadata is None: + return value, True + + return global_value, True + + return value, found def decoder_with_discriminator(field_name): def decode_fx(obj): - kls = getattr(sys.modules['sdk.models.shared'], obj[field_name]) + kls = getattr(sys.modules["sdk.models.shared"], obj[field_name]) return unmarshal_json(json.dumps(obj), kls) + return decode_fx def remove_suffix(input_string, suffix): if suffix and input_string.endswith(suffix): - return input_string[:-len(suffix)] + return input_string[: -len(suffix)] return input_string diff --git a/tests/helpers.py b/tests/helpers.py deleted file mode 100644 index b3d0950..0000000 --- a/tests/helpers.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Code generated by Speakeasy (https://speakeasyapi.dev). DO NOT EDIT.""" - -import re - - -def sort_query_parameters(url): - parts = url.split("?") - - if len(parts) == 1: - return url - - query = parts[1] - params = query.split("&") - - params.sort(key=lambda x: x.split('=')[0]) - - return parts[0] + "?" + "&".join(params) - - -def sort_serialized_maps(inp: any, regex: str, delim: str): - - def sort_map(m): - entire_match = m.group(0) - - groups = m.groups() - - for group in groups: - pairs = [] - if '=' in group: - pairs = group.split(delim) - - pairs.sort(key=lambda x: x.split('=')[0]) - else: - values = group.split(delim) - - if len(values) == 1: - pairs = values - else: - pairs = [''] * int(len(values)/2) - # loop though every 2nd item - for i in range(0, len(values), 2): - pairs[int(i/2)] = values[i] + delim + values[i+1] - - pairs.sort(key=lambda x: x.split(delim)[0]) - - entire_match = entire_match.replace(group, delim.join(pairs)) - - return entire_match - - if isinstance(inp, str): - return re.sub(regex, sort_map, inp) - elif isinstance(inp, list): - for i, v in enumerate(inp): - inp[i] = sort_serialized_maps(v, regex, delim) - return inp - elif isinstance(inp, dict): - for k, v in inp.items(): - inp[k] = sort_serialized_maps(v, regex, delim) - return inp - else: - raise Exception("Unsupported type")