Skip to content

Commit ec2b14b

Browse files
committed
Improve typing in the module
Fixes pallets-eco#65 - make it pass mypy tests (i.e., improve type hinting) - add mypy test to CI - add py.typed to setup so that other packages recognize it's typed - enable async view function decoration (via ensure_sync() as noted on https://flask.palletsprojects.com/en/latest/async-await/#extensions) Signed-off-by: Marek Pikuła <[email protected]>
1 parent 064976a commit ec2b14b

File tree

7 files changed

+191
-115
lines changed

7 files changed

+191
-115
lines changed

Diff for: .github/workflows/tests.yml

+20
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,26 @@ name: Tests
33
on: [push, pull_request]
44

55
jobs:
6+
mypy:
7+
runs-on: ubuntu-latest
8+
strategy:
9+
fail-fast: false
10+
11+
steps:
12+
- uses: actions/checkout@v1
13+
- name: Set up Python 3.7
14+
uses: actions/setup-python@v1
15+
with:
16+
python-version: "3.7"
17+
- name: Install dependencies
18+
if: steps.cache-pip.outputs.cache-hit != 'true'
19+
run: |
20+
python -m pip install --upgrade pip
21+
pip install -r requirements/test.pip
22+
- name: Run mypy satic check
23+
run: |
24+
python3 -m mypy flask_pydantic/
25+
626
build:
727
runs-on: ${{ matrix.os }}
828
strategy:

Diff for: flask_pydantic/converters.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Type
1+
from typing import Dict, List, Type, Union
22

33
from pydantic import BaseModel
4-
from werkzeug.datastructures import ImmutableMultiDict
4+
from werkzeug.datastructures import MultiDict
55

66

77
def convert_query_params(
8-
query_params: ImmutableMultiDict, model: Type[BaseModel]
9-
) -> dict:
8+
query_params: MultiDict[str, str], model: Type[BaseModel]
9+
) -> Dict[str, Union[str, List[str]]]:
1010
"""
1111
group query parameters into lists if model defines them
1212

Diff for: flask_pydantic/core.py

+146-97
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,28 @@
1+
from collections.abc import Iterable
12
from functools import wraps
2-
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
3+
from typing import (
4+
TYPE_CHECKING,
5+
Any,
6+
Awaitable,
7+
Callable,
8+
Dict,
9+
List,
10+
Optional,
11+
Tuple,
12+
Type,
13+
Union,
14+
)
15+
16+
from flask import Response, current_app, jsonify, request
17+
from flask.typing import ResponseReturnValue, RouteCallable
18+
19+
try:
20+
from flask_restful import ( # type: ignore
21+
original_flask_make_response as make_response,
22+
)
23+
except ImportError:
24+
from flask import make_response
325

4-
from flask import Response, current_app, jsonify, make_response, request
526
from pydantic import BaseModel, ValidationError
627
from pydantic.tools import parse_obj_as
728

@@ -13,22 +34,33 @@
1334
)
1435
from .exceptions import ValidationError as FailedValidation
1536

16-
try:
17-
from flask_restful import original_flask_make_response as make_response
18-
except ImportError:
19-
pass
37+
if TYPE_CHECKING:
38+
from pydantic.error_wrappers import ErrorDict
39+
40+
41+
ModelResponseReturnValue = Union[ResponseReturnValue, BaseModel]
42+
ModelRouteCallable = Union[
43+
Callable[..., ModelResponseReturnValue],
44+
Callable[..., Awaitable[ModelResponseReturnValue]],
45+
]
2046

2147

2248
def make_json_response(
2349
content: Union[BaseModel, Iterable[BaseModel]],
2450
status_code: int,
2551
by_alias: bool,
2652
exclude_none: bool = False,
27-
many: bool = False,
2853
) -> Response:
2954
"""serializes model, creates JSON response with given status code"""
30-
if many:
31-
js = f"[{', '.join([model.json(exclude_none=exclude_none, by_alias=by_alias) for model in content])}]"
55+
if not isinstance(content, BaseModel):
56+
js = "["
57+
js += ", ".join(
58+
[
59+
model.json(exclude_none=exclude_none, by_alias=by_alias)
60+
for model in content
61+
]
62+
)
63+
js += "]"
3264
else:
3365
js = content.json(exclude_none=exclude_none, by_alias=by_alias)
3466
response = make_response(js, status_code)
@@ -56,9 +88,9 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel
5688
return [model(**fields) for fields in content]
5789
except TypeError:
5890
# iteration through `content` fails
59-
err = [
91+
err: List["ErrorDict"] = [
6092
{
61-
"loc": ["root"],
93+
"loc": ("root",),
6294
"msg": "is not an array of objects",
6395
"type": "type_error.array",
6496
}
@@ -68,30 +100,53 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel
68100
raise ManyModelValidationError(ve.errors())
69101

70102

71-
def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]:
72-
errors = []
103+
def validate_path_params(
104+
func: ModelRouteCallable, kwargs: Dict[str, Any]
105+
) -> Tuple[Dict[str, Any], List["ErrorDict"]]:
106+
errors: List["ErrorDict"] = []
73107
validated = {}
74108
for name, type_ in func.__annotations__.items():
75109
if name in {"query", "body", "form", "return"}:
76110
continue
77111
try:
78112
value = parse_obj_as(type_, kwargs.get(name))
79113
validated[name] = value
80-
except ValidationError as e:
81-
err = e.errors()[0]
82-
err["loc"] = [name]
114+
except ValidationError as error:
115+
err = error.errors()[0]
116+
err["loc"] = (name,)
83117
errors.append(err)
84118
kwargs = {**kwargs, **validated}
85119
return kwargs, errors
86120

87121

88-
def get_body_dict(**params):
89-
data = request.get_json(**params)
122+
def get_body_dict(**params: Dict[str, Any]) -> Any:
123+
data = request.get_json(**params) # type: ignore
90124
if data is None and params.get("silent"):
91125
return {}
92126
return data
93127

94128

129+
def _ensure_model_kwarg(
130+
kwarg_name: str,
131+
from_validate: Optional[Type[BaseModel]],
132+
func: ModelRouteCallable,
133+
) -> Tuple[Optional[Type[BaseModel]], bool]:
134+
"""Get model information either from wrapped function or validate kwargs."""
135+
in_func_kwargs = func.__annotations__.get(kwarg_name)
136+
if in_func_kwargs is None:
137+
return from_validate, False
138+
assert isinstance(in_func_kwargs, type) and issubclass(
139+
in_func_kwargs, BaseModel
140+
), "Model in function arguments needs to be a BaseModel."
141+
142+
# Ensure that the most "detailed" model is used.
143+
if from_validate is None:
144+
return in_func_kwargs, True
145+
if issubclass(in_func_kwargs, from_validate):
146+
return in_func_kwargs, True
147+
return from_validate, True
148+
149+
95150
def validate(
96151
body: Optional[Type[BaseModel]] = None,
97152
query: Optional[Type[BaseModel]] = None,
@@ -100,7 +155,7 @@ def validate(
100155
response_many: bool = False,
101156
request_body_many: bool = False,
102157
response_by_alias: bool = False,
103-
get_json_params: Optional[dict] = None,
158+
get_json_params: Optional[Dict[str, Any]] = None,
104159
form: Optional[Type[BaseModel]] = None,
105160
):
106161
"""
@@ -163,105 +218,93 @@ def test_route_kwargs(query:Query, body:Body, form:Form):
163218
-> that will render JSON response with serialized MyModel instance
164219
"""
165220

166-
def decorate(func: Callable) -> Callable:
221+
def decorate(func: ModelRouteCallable) -> RouteCallable:
167222
@wraps(func)
168-
def wrapper(*args, **kwargs):
169-
q, b, f, err = None, None, None, {}
170-
kwargs, path_err = validate_path_params(func, kwargs)
171-
if path_err:
172-
err["path_params"] = path_err
173-
query_in_kwargs = func.__annotations__.get("query")
174-
query_model = query_in_kwargs or query
175-
if query_model:
223+
def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> ResponseReturnValue:
224+
q, b, f, err = None, None, None, FailedValidation()
225+
func_kwargs, path_err = validate_path_params(func, kwargs)
226+
if len(path_err) > 0:
227+
err.path_params = path_err
228+
query_model, query_in_kwargs = _ensure_model_kwarg("query", query, func)
229+
if query_model is not None:
176230
query_params = convert_query_params(request.args, query_model)
177231
try:
178232
q = query_model(**query_params)
179233
except ValidationError as ve:
180-
err["query_params"] = ve.errors()
181-
body_in_kwargs = func.__annotations__.get("body")
182-
body_model = body_in_kwargs or body
183-
if body_model:
234+
err.query_params = ve.errors()
235+
body_model, body_in_kwargs = _ensure_model_kwarg("body", body, func)
236+
if body_model is not None:
184237
body_params = get_body_dict(**(get_json_params or {}))
185-
if "__root__" in body_model.__fields__:
186-
try:
187-
b = body_model(__root__=body_params).__root__
188-
except ValidationError as ve:
189-
err["body_params"] = ve.errors()
190-
elif request_body_many:
191-
try:
238+
try:
239+
if "__root__" in body_model.__fields__:
240+
b = body_model(__root__=body_params).__root__ # type: ignore
241+
elif request_body_many:
192242
b = validate_many_models(body_model, body_params)
193-
except ManyModelValidationError as e:
194-
err["body_params"] = e.errors()
195-
else:
196-
try:
243+
else:
197244
b = body_model(**body_params)
198-
except TypeError:
199-
content_type = request.headers.get("Content-Type", "").lower()
200-
media_type = content_type.split(";")[0]
201-
if media_type != "application/json":
202-
return unsupported_media_type_response(content_type)
203-
else:
204-
raise JsonBodyParsingError()
205-
except ValidationError as ve:
206-
err["body_params"] = ve.errors()
207-
form_in_kwargs = func.__annotations__.get("form")
208-
form_model = form_in_kwargs or form
209-
if form_model:
245+
except (ValidationError, ManyModelValidationError) as error:
246+
err.body_params = error.errors()
247+
except TypeError as error:
248+
content_type = request.headers.get("Content-Type", "").lower()
249+
media_type = content_type.split(";")[0]
250+
if media_type != "application/json":
251+
return unsupported_media_type_response(content_type)
252+
else:
253+
raise JsonBodyParsingError() from error
254+
form_model, form_in_kwargs = _ensure_model_kwarg("form", form, func)
255+
if form_model is not None:
210256
form_params = request.form
211-
if "__root__" in form_model.__fields__:
212-
try:
213-
f = form_model(__root__=form_params).__root__
214-
except ValidationError as ve:
215-
err["form_params"] = ve.errors()
216-
else:
217-
try:
257+
try:
258+
if "__root__" in form_model.__fields__:
259+
f = form_model(__root__=form_params).__root__ # type: ignore
260+
else:
218261
f = form_model(**form_params)
219-
except TypeError:
220-
content_type = request.headers.get("Content-Type", "").lower()
221-
media_type = content_type.split(";")[0]
222-
if media_type != "multipart/form-data":
223-
return unsupported_media_type_response(content_type)
224-
else:
225-
raise JsonBodyParsingError
226-
except ValidationError as ve:
227-
err["form_params"] = ve.errors()
228-
request.query_params = q
229-
request.body_params = b
230-
request.form_params = f
262+
except TypeError as error:
263+
content_type = request.headers.get("Content-Type", "").lower()
264+
media_type = content_type.split(";")[0]
265+
if media_type != "multipart/form-data":
266+
return unsupported_media_type_response(content_type)
267+
else:
268+
raise JsonBodyParsingError() from error
269+
except ValidationError as ve:
270+
err.form_params = ve.errors()
271+
request.query_params = q # type: ignore
272+
request.body_params = b # type: ignore
273+
request.form_params = f # type: ignore
231274
if query_in_kwargs:
232-
kwargs["query"] = q
275+
func_kwargs["query"] = q
233276
if body_in_kwargs:
234-
kwargs["body"] = b
277+
func_kwargs["body"] = b
235278
if form_in_kwargs:
236-
kwargs["form"] = f
279+
func_kwargs["form"] = f
237280

238-
if err:
281+
if err.check():
239282
if current_app.config.get(
240283
"FLASK_PYDANTIC_VALIDATION_ERROR_RAISE", False
241284
):
242-
raise FailedValidation(**err)
285+
raise err
243286
else:
244287
status_code = current_app.config.get(
245288
"FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE", 400
246289
)
247290
return make_response(
248-
jsonify({"validation_error": err}),
249-
status_code
291+
jsonify({"validation_error": err.to_dict()}), status_code
250292
)
251-
res = func(*args, **kwargs)
293+
res: ModelResponseReturnValue = current_app.ensure_sync(func)(
294+
*args, **func_kwargs
295+
)
252296

253297
if response_many:
254-
if is_iterable_of_models(res):
255-
return make_json_response(
256-
res,
257-
on_success_status,
258-
by_alias=response_by_alias,
259-
exclude_none=exclude_none,
260-
many=True,
261-
)
262-
else:
298+
if not is_iterable_of_models(res):
263299
raise InvalidIterableOfModelsException(res)
264300

301+
return make_json_response(
302+
res, # type: ignore # Iterability and type is ensured above.
303+
on_success_status,
304+
by_alias=response_by_alias,
305+
exclude_none=exclude_none,
306+
)
307+
265308
if isinstance(res, BaseModel):
266309
return make_json_response(
267310
res,
@@ -275,23 +318,29 @@ def wrapper(*args, **kwargs):
275318
and len(res) in [2, 3]
276319
and isinstance(res[0], BaseModel)
277320
):
278-
headers = None
321+
headers: Optional[
322+
Union[Dict[str, Any], Tuple[Any, ...], List[Any]]
323+
] = None
279324
status = on_success_status
280325
if isinstance(res[1], (dict, tuple, list)):
281326
headers = res[1]
282-
elif len(res) == 3 and isinstance(res[2], (dict, tuple, list)):
283-
status = res[1]
284-
headers = res[2]
285-
else:
327+
elif isinstance(res[1], int):
286328
status = res[1]
287329

330+
# Following type ignores should be fixed once
331+
# https://github.com/python/mypy/issues/1178 is fixed.
332+
if len(res) == 3 and isinstance(
333+
res[2], (dict, tuple, list) # type: ignore[misc]
334+
):
335+
headers = res[2] # type: ignore[misc]
336+
288337
ret = make_json_response(
289338
res[0],
290339
status,
291340
exclude_none=exclude_none,
292341
by_alias=response_by_alias,
293342
)
294-
if headers:
343+
if headers is not None:
295344
ret.headers.update(headers)
296345
return ret
297346

0 commit comments

Comments
 (0)