Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tests/unit/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,23 @@ def test_replace_json_post_data_parameters():
assert request_data == expected_data


def test_replace_recursive_post_data_parameters():
body = b'{"nested": "change", "one": {"key": "secret", "nested": {"key": "secret"}}}'
request = Request("POST", "http://google.com", body, {})
request.headers["Content-Type"] = "application/json"
replace_post_data_parameters(
request,
[
("key", None),
("nested", "aboba"),
],
recursive=True
)
request_data = json.loads(request.body)
expected_data = json.loads('{"nested": "aboba", "one": {"nested": "aboba"}}')
assert request_data == expected_data


def test_remove_json_post_data_parameters():
# Test the backward-compatible API wrapper.
body = b'{"id": "secret", "foo": "bar", "baz": "qux"}'
Expand Down
7 changes: 6 additions & 1 deletion vcr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,14 @@ def _build_before_record_request(self, options):
functools.partial(filters.replace_query_parameters, replacements=replacements),
)
if filter_post_data_parameters:
recursive = options.get("recursive_filter", False)
replacements = [p if isinstance(p, tuple) else (p, None) for p in filter_post_data_parameters]
filter_functions.append(
functools.partial(filters.replace_post_data_parameters, replacements=replacements),
functools.partial(
filters.replace_post_data_parameters,
replacements=replacements,
recursive=recursive,
),
)

hosts_to_ignore = set(ignore_hosts)
Expand Down
52 changes: 37 additions & 15 deletions vcr/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,35 @@ def remove_query_parameters(request, query_parameters_to_remove):
return replace_query_parameters(request, replacements)


def replace_post_data_parameters(request, replacements):
def filtering_body(request, body_data, replacements):
"""Filtering the request body by default to only top level keys"""
for k, rv in replacements.items():
if k in body_data:
ov = body_data.pop(k)
if callable(rv):
rv = rv(key=k, value=ov, request=request)
if rv is not None:
body_data[k] = rv


def recursive_filtering_body(request, body_data, replacements):
"""Recursive filtering the request body with nested keys"""
for k, ov in list(body_data.items()):
if isinstance(ov, dict):
recursive_filtering_body(request, ov, replacements)
if not ov:
body_data.pop(k)
if k in replacements:
rv = replacements[k]
if callable(rv):
rv = rv(key=k, value=ov, request=request)
if rv is not None:
body_data[k] = rv
elif k in body_data:
body_data.pop(k)


def replace_post_data_parameters(request, replacements, recursive=False):
"""Replace post data in request--either form data or json--according to replacements.

The replacements should be a list of (key, value) pairs where the value can be any of:
Expand All @@ -86,23 +114,17 @@ def replace_post_data_parameters(request, replacements):
if request.method == "POST" and not isinstance(request.body, BytesIO):
if isinstance(request.body, dict):
new_body = request.body.copy()
for k, rv in replacements.items():
if k in new_body:
ov = new_body.pop(k)
if callable(rv):
rv = rv(key=k, value=ov, request=request)
if rv is not None:
new_body[k] = rv
if recursive:
recursive_filtering_body(request, new_body, replacements)
else:
filtering_body(request, new_body, replacements)
request.body = new_body
elif request.headers.get("Content-Type") == "application/json":
json_data = json.loads(request.body)
for k, rv in replacements.items():
if k in json_data:
ov = json_data.pop(k)
if callable(rv):
rv = rv(key=k, value=ov, request=request)
if rv is not None:
json_data[k] = rv
if recursive:
recursive_filtering_body(request, json_data, replacements)
else:
filtering_body(request, json_data, replacements)
request.body = json.dumps(json_data).encode("utf-8")
else:
if isinstance(request.body, str):
Expand Down