diff --git a/flask_accepts/decorators/decorators.py b/flask_accepts/decorators/decorators.py index 84d4f18..fe1b00a 100644 --- a/flask_accepts/decorators/decorators.py +++ b/flask_accepts/decorators/decorators.py @@ -135,12 +135,13 @@ def inner(*args, **kwargs): f"Error parsing request body: {schema_error}" ) if hasattr(error, "data"): - error.data["errors"].update({"schema_errors": schema_error}) + error.data["errors"].setdefault('schema_errors', {}).update(schema_error) else: error.data = {"schema_errors": schema_error} # Handle Marshmallow schema for query params if query_params_schema: + schema_error = None request_args = _convert_multidict_values_to_schema( request.args, query_params_schema) @@ -155,12 +156,13 @@ def inner(*args, **kwargs): f"Error parsing query params: {schema_error}" ) if hasattr(error, "data"): - error.data["errors"].update({"schema_errors": schema_error}) + error.data["errors"].setdefault('schema_errors', {}).update(schema_error) else: error.data = {"schema_errors": schema_error} # Handle Marshmallow schema for headers if headers_schema: + schema_error = None request_headers = _convert_multidict_values_to_schema( request.headers, headers_schema) @@ -175,12 +177,13 @@ def inner(*args, **kwargs): f"Error parsing headers: {schema_error}" ) if hasattr(error, "data"): - error.data["errors"].update({"schema_errors": schema_error}) + error.data["errors"].setdefault('schema_errors', {}).update(schema_error) else: error.data = {"schema_errors": schema_error} # Handle Marshmallow schema for form data if form_schema: + schema_error = None request_form = _convert_multidict_values_to_schema( request.form, form_schema) @@ -195,7 +198,7 @@ def inner(*args, **kwargs): f"Error parsing form data: {schema_error}" ) if hasattr(error, "data"): - error.data["errors"].update({"schema_errors": schema_error}) + error.data["errors"].setdefault('schema_errors', {}).update(schema_error) else: error.data = {"schema_errors": schema_error} diff --git a/flask_accepts/decorators/decorators_test.py b/flask_accepts/decorators/decorators_test.py index bbdfb0a..7447455 100644 --- a/flask_accepts/decorators/decorators_test.py +++ b/flask_accepts/decorators/decorators_test.py @@ -1,3 +1,4 @@ +from pytest import mark from flask import request from flask_restx import Resource, Api from marshmallow import Schema, fields @@ -303,6 +304,39 @@ def test(): assert resp.status_code == 400 +@mark.parametrize('url,data,expected_code,expected_body', [ + ('/test?foo=12', '{"bar": "yo"}', 200, {'success': True}), + ('/test', '{"bar": "yo"}', 200, {'success': True}), + ('/test', '{"bar": 1}', 400, {'schema_errors': {'bar': ['Not a valid string.']}}), + ('/test?foo=yeah', '{"bar": 1}', 400, {'errors': {'foo': "invalid literal for int() with base 10: 'yeah'", + 'schema_errors': {'bar': ['Not a valid string.'], + 'foo': ['Not a valid integer.']}}, + 'message': 'Input payload validation failed'}), +]) +def test_failure_when_both_query_param_and_request_schema_are_provided( + url, data, expected_code, expected_body, app, client, +): # noqa + api = Api(app) + + class TestSchema(Schema): + bar = fields.String(required=True) + + class TestQuerySchema(Schema): + foo = fields.Integer(required=False, missing=None, default=None) + + @api.route("/test") + class TestResource(Resource): + @accepts(schema=TestSchema, query_params_schema=TestQuerySchema, api=api) + def post(self): + return {'success': True} + + with client as cl: + # + resp = cl.post(url, data=data, content_type='application/json') + assert resp.status_code == expected_code + assert resp.json == expected_body + + def test_accepts_with_header_schema_single_value(app, client): # noqa class TestSchema(Schema): Foo = fields.Integer(required=True)