From f080053f48a687a2c48e1eb888d427cbdbbd5b53 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sun, 9 Mar 2025 18:40:11 -0400 Subject: [PATCH] feat: Support Python subprocesses Pytest-socket should be able to block socket calls in Python subprocesses created by tests (e.g., pip's test suite). To hook into new Python subprocesses, we can use a .pth file to run code during Python startup. This is what pytest-cov does to automagically support subprocess coverage tracking. State is passed to the .pth file via the _PYTEST_SOCKET_SUBPROCESS environment variable with JSON as the encoding format. Some refactoring was necessary to allow for the right pytest-socket state to be easily passed down to the .pth file (without recalculating or rerunning the entirety of pytest_socket.py). Testing-wise, the majority of the tests contained in test_socket.py and test_restrict_hosts.py were copied as subprocess tests. While this doesn't cover every single surface, this should be sufficient to ensure the subprocess support is working properly. --- README.md | 1 + pyproject.toml | 1 + pytest_socket.embed | 48 +++++ pytest_socket.pth | 1 + pytest_socket.py | 63 ++++++- tests/common.py | 6 + tests/test_restrict_hosts.py | 8 +- tests/test_subprocess.py | 350 +++++++++++++++++++++++++++++++++++ 8 files changed, 461 insertions(+), 17 deletions(-) create mode 100644 pytest_socket.embed create mode 100644 pytest_socket.pth create mode 100644 tests/test_subprocess.py diff --git a/README.md b/README.md index 34b5bdc..c942c0f 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ tests to ensure network calls are prevented. ## Features - Disables all network calls flowing through Python\'s `socket` interface. 
+- Python subprocesses are supported ## Requirements diff --git a/pyproject.toml b/pyproject.toml index c1ffc2a..86069ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ include = [ { path = "README.md", format = "sdist" }, { path = "tests", format = "sdist" }, { path = ".flake8", format = "sdist" }, + { path = "pytest_socket.pth", format = ["sdist", "wheel"] } ] classifiers = [ "Development Status :: 4 - Beta", diff --git a/pytest_socket.embed b/pytest_socket.embed new file mode 100644 index 0000000..19df7e9 --- /dev/null +++ b/pytest_socket.embed @@ -0,0 +1,48 @@ +"""Inject pytest-socket into Python subprocesses via a .pth file. + +To update the .pth file, simply run this script which will write the code +below to a .pth file in the same directory as a single line. +""" + +import os +os.environ["_PYTEST_SOCKET_SUBPROCESS"] = "" + +# .PTH START +if config := os.getenv("_PYTEST_SOCKET_SUBPROCESS", None): + import json + state = None + try: + import socket + from pytest_socket import disable_socket, _create_guarded_connect + state = json.loads(config) + + if state["mode"] == "disable": + disable_socket(allow_unix_socket=state["allow_unix_socket"]) + elif state["mode"] == "allow-hosts": + socket.socket.connect = _create_guarded_connect( + allowed_hosts=state["allowed_hosts"], + allow_unix_socket=state["allow_unix_socket"], + _pretty_allowed_list=state["_pretty_allowed_list"] + ) + + except Exception as exc: + import sys + sys.stderr.write( + "pytest-socket: Failed to set up subprocess socket patching.\n" + f"Configuration: {state}\n" + f"{exc.__class__.__name__}: {exc}\n" + ) +# .PTH END + +if __name__ == "__main__": + from pathlib import Path + + src = Path(__file__) + dst = src.with_suffix(".pth") + lines = src.read_text().splitlines() + code = "\n".join(lines[lines.index("# .PTH START") + 1 : lines.index("# .PTH END")]) + + print(f"Writing to {dst}") + # Only lines beginning with an import will be executed. 
+ # https://docs.python.org/3/library/site.html + dst.write_text(f"import os; exec({code!r})\n") diff --git a/pytest_socket.pth b/pytest_socket.pth new file mode 100644 index 0000000..28f1bcd --- /dev/null +++ b/pytest_socket.pth @@ -0,0 +1 @@ +import os; exec('if config := os.getenv("_PYTEST_SOCKET_SUBPROCESS", None):\n import json\n state = None\n try:\n import socket\n from pytest_socket import disable_socket, _create_guarded_connect\n state = json.loads(config)\n\n if state["mode"] == "disable":\n disable_socket(allow_unix_socket=state["allow_unix_socket"])\n elif state["mode"] == "allow-hosts":\n socket.socket.connect = _create_guarded_connect(\n allowed_hosts=state["allowed_hosts"],\n allow_unix_socket=state["allow_unix_socket"],\n _pretty_allowed_list=state["_pretty_allowed_list"]\n )\n\n except Exception as exc:\n import sys\n sys.stderr.write(\n "pytest-socket: Failed to set up subprocess socket patching.\\n"\n f"Configuration: {state}\\n"\n f"{exc.__class__.__name__}: {exc}\\n"\n )') diff --git a/pytest_socket.py b/pytest_socket.py index 22bdc71..40b4696 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -1,5 +1,7 @@ import ipaddress import itertools +import json +import os import socket import typing from collections import defaultdict @@ -7,10 +9,26 @@ import pytest +_SUBPROCESS_ENVVAR = "_PYTEST_SOCKET_SUBPROCESS" _true_socket = socket.socket _true_connect = socket.socket.connect +def update_subprocess_config(config: typing.Dict[str, object]) -> None: + """Enable pytest-socket in Python subprocesses. + + The configuration will be read by the .pth file to mirror the + restrictions in the main process. 
+ """ + os.environ[_SUBPROCESS_ENVVAR] = json.dumps(config) + + +def delete_subprocess_config() -> None: + """Disable pytest-socket in Python subprocesses.""" + if _SUBPROCESS_ENVVAR in os.environ: + del os.environ[_SUBPROCESS_ENVVAR] + + class SocketBlockedError(RuntimeError): def __init__(self, *_args, **_kwargs): super().__init__("A test tried to use socket.socket.") @@ -103,11 +121,15 @@ def __new__(cls, family=-1, type=-1, proto=-1, fileno=None): raise SocketBlockedError() socket.socket = GuardedSocket + update_subprocess_config( + {"mode": "disable", "allow_unix_socket": allow_unix_socket} + ) def enable_socket(): """re-enable socket.socket to enable the Internet. useful in testing.""" socket.socket = _true_socket + delete_subprocess_config() def pytest_configure(config): @@ -249,6 +271,25 @@ def normalize_allowed_hosts( return ip_hosts +def _create_guarded_connect( + allowed_hosts: typing.Sequence[str], + allow_unix_socket: bool, + _pretty_allowed_list: typing.Sequence[str], +) -> typing.Callable: + """Create a function to replace socket.connect.""" + + def guarded_connect(inst, *args): + host = host_from_connect_args(args) + if host in allowed_hosts or ( + _is_unix_socket(inst.family) and allow_unix_socket + ): + return _true_connect(inst, *args) + + raise SocketConnectBlockedError(_pretty_allowed_list, host) + + return guarded_connect + + def socket_allow_hosts( allowed: typing.Union[str, typing.List[str], None] = None, allow_unix_socket: bool = False, @@ -276,19 +317,21 @@ def socket_allow_hosts( ] ) - def guarded_connect(inst, *args): - host = host_from_connect_args(args) - if host in allowed_ip_hosts_and_hostnames or ( - _is_unix_socket(inst.family) and allow_unix_socket - ): - return _true_connect(inst, *args) - - raise SocketConnectBlockedError(allowed_list, host) - - socket.socket.connect = guarded_connect + socket.socket.connect = _create_guarded_connect( + allowed_ip_hosts_and_hostnames, allow_unix_socket, allowed_list + ) + 
update_subprocess_config( + { + "mode": "allow-hosts", + "allowed_hosts": list(allowed_ip_hosts_and_hostnames), + "allow_unix_socket": allow_unix_socket, + "_pretty_allowed_list": allowed_list, + } + ) def _remove_restrictions(): """restore socket.socket.* to allow access to the Internet. useful in testing.""" socket.socket = _true_socket socket.socket.connect = _true_connect + delete_subprocess_config() diff --git a/tests/common.py b/tests/common.py index fe731f0..21dab22 100644 --- a/tests/common.py +++ b/tests/common.py @@ -14,3 +14,9 @@ def assert_socket_blocked(result, passed=0, skipped=0, failed=1): result.stdout.fnmatch_lines( "*Socket*Blocked*Error: A test tried to use socket.socket.*" ) + + +def assert_host_blocked(result, host): + result.stdout.fnmatch_lines( + f'*A test tried to use socket.socket.connect() with host "{host}"*' + ) diff --git a/tests/test_restrict_hosts.py b/tests/test_restrict_hosts.py index 26c8ae9..8dd76a8 100644 --- a/tests/test_restrict_hosts.py +++ b/tests/test_restrict_hosts.py @@ -4,7 +4,8 @@ import pytest from pytest_socket import normalize_allowed_hosts +from tests.common import assert_host_blocked localhost = "127.0.0.1" @@ -46,12 +47,6 @@ def {2}(): """ -def assert_host_blocked(result, host): - result.stdout.fnmatch_lines( - f'*A test tried to use socket.socket.connect() with host "{host}"*' - ) - - @pytest.fixture def assert_connect(httpbin, testdir): def assert_socket_connect(should_pass, **kwargs): diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py new file mode 100644 index 0000000..4f1f4e6 --- /dev/null +++ b/tests/test_subprocess.py @@ -0,0 +1,350 @@ +"""Verify that pytest-socket works in Python subprocess. 
+ +These tests are based off test_socket.py and test_restrict_hosts.py +""" + +import inspect +import json +import os +import textwrap +from importlib.metadata import Distribution +from typing import List, Optional, Union + +import pytest + +from conftest import unix_sockets_only +from tests.common import assert_host_blocked, assert_socket_blocked + + +def is_pytest_socket_editably_installed() -> bool: + pytest_socket_package = Distribution.from_name("pytest_socket") + # read_text() returns None when direct_url.json is absent (index installs). + direct_url = json.loads(pytest_socket_package.read_text("direct_url.json") or "{}") + editable = direct_url.get("dir_info", {}).get("editable", False) + if editable and "CI" in os.environ: + raise RuntimeError("CI should be testing against a normal install") + return editable + + +pytestmark = pytest.mark.skipif( + is_pytest_socket_editably_installed(), + reason=".pth files don't work under an editable install", +) + + +localhost = "127.0.0.1" + +# These templates are only used by the AllowHosts tests. +connect_code_template = """ + {3} + def {2}(): + run("import socket; socket.socket().connect(('{0}', {1}))") +""" +urlopen_code_template = """ + {3} + def {2}(): + run(multiline_code(''' + from urllib.request import urlopen + assert urlopen('http://{0}:{1}/').getcode() == 200 + ''')) +""" +urlopen_hostname_code_template = """ + {3} + def {2}(): + # Skip {{1}} as we expect {{0}} to be the full hostname with or without port + run(multiline_code(''' + from urllib.request import urlopen + assert urlopen('http://{0}').getcode() == 200 + ''')) +""" + + +def make_test_file(testdir, code): + template = textwrap.dedent( + """ + import textwrap + import subprocess + import sys + import pytest + + def run(code): + subprocess.run([sys.executable, '-c', code], check=True) + + def multiline_code(code): + return textwrap.dedent(code) + + SOCKET_CODE = 'import socket; socket.socket(socket.AF_INET, socket.SOCK_STREAM)' + UNIX_SOCKET_CODE = multiline_code(''' + import socket + socket.socket(socket.AF_UNIX, 
socket.SOCK_STREAM) + ''') + URLLIB_CODE = multiline_code(''' + try: + from urllib.request import urlopen + except ImportError: + from urllib2 import urlopen + assert urlopen('https://httpstat.us/200').getcode() == 200 + ''') + + {0} + """ + ) + code = textwrap.dedent(code) + testdir.makepyfile(template.format(code)) + + +@pytest.fixture +def assert_connect(httpbin, testdir): + def assert_socket_connect( + should_pass: bool, + *, + host: str = httpbin.host, + cli_arg: Optional[str] = None, + mark_arg: Optional[Union[str, List[str]]] = None, + code_template: str = connect_code_template, + ): + # get the name of the calling function + test_name = inspect.stack()[1][3] + mark = "" + + if mark_arg: + if isinstance(mark_arg, str): + mark = f'@pytest.mark.allow_hosts("{mark_arg}")' + elif isinstance(mark_arg, list): + hosts = '","'.join(mark_arg) + mark = f'@pytest.mark.allow_hosts(["{hosts}"])' + code = code_template.format(host, httpbin.port, test_name, mark) + make_test_file(testdir, code) + + if cli_arg: + result = testdir.runpytest(f"--allow-hosts={cli_arg}") + else: + result = testdir.runpytest() + + if should_pass: + result.assert_outcomes(passed=1, skipped=0, failed=0) + else: + result.assert_outcomes(passed=0, skipped=0, failed=1) + assert_host_blocked(result, host) + + return result + + return assert_socket_connect + + +def test_disable_socket_fixture(testdir): + make_test_file( + testdir, + """ + def test_socket(socket_disabled): + run(SOCKET_CODE) + """, + ) + result = testdir.runpytest() + assert_socket_blocked(result) + + +def test_disable_socket_marker(testdir): + make_test_file( + testdir, + """ + @pytest.mark.disable_socket + def test_socket(): + run(SOCKET_CODE) + """, + ) + result = testdir.runpytest() + assert_socket_blocked(result) + + +def test_enable_socket_fixture(testdir): + make_test_file( + testdir, + """ + def test_socket(socket_enabled): + run(SOCKET_CODE) + """, + ) + result = testdir.runpytest("--disable-socket") + result.assert_outcomes(passed=1, 
skipped=0, failed=0) + with pytest.raises(BaseException): + assert_socket_blocked(result) + + +def test_enable_socket_marker(testdir): + make_test_file( + testdir, + """ + @pytest.mark.enable_socket + def test_socket(): + run(SOCKET_CODE) + """, + ) + result = testdir.runpytest("--disable-socket") + result.assert_outcomes(passed=1, skipped=0, failed=0) + with pytest.raises(BaseException): + assert_socket_blocked(result) + + +@pytest.mark.parametrize("mode", ["default", "enabled"]) +def test_urllib_succeeds(testdir, mode: str): + fixture = "socket_enabled" if mode == "enabled" else "" + make_test_file( + testdir, + f""" + def test_socket_urllib({fixture}): + run(URLLIB_CODE) + """, + ) + result = testdir.runpytest() + result.assert_outcomes(passed=1, skipped=0, failed=0) + with pytest.raises(BaseException): + assert_socket_blocked(result) + + +def test_disabled_urllib_fails(testdir): + make_test_file( + testdir, + """ + @pytest.mark.disable_socket + def test_disable_socket_urllib(): + run(URLLIB_CODE) + """, + ) + result = testdir.runpytest() + assert_socket_blocked(result) + + +def test_mix_and_match(testdir, socket_enabled): + make_test_file( + testdir, + """ + def test_socket1(): + run(SOCKET_CODE) + def test_socket_enabled(socket_enabled): + run(SOCKET_CODE) + def test_socket2(): + run(SOCKET_CODE) + """, + ) + result = testdir.runpytest("--disable-socket") + result.assert_outcomes(passed=1, skipped=0, failed=2) + + +def test_socket_subclass_is_still_blocked(testdir): + make_test_file( + testdir, + """ + code = multiline_code(''' + import socket + class MySocket(socket.socket): + pass + MySocket(socket.AF_INET, socket.SOCK_STREAM) + ''') + + @pytest.mark.disable_socket + def test_subclass_is_still_blocked(): + run(code) + """, + ) + result = testdir.runpytest() + assert_socket_blocked(result) + + +@unix_sockets_only +@pytest.mark.parametrize("state", ["blocked", "allowed"]) +def test_unix_domain_sockets_with_disable_socket(testdir, state: str): + make_test_file( 
+ testdir, + """ + def test_inet(): + run(SOCKET_CODE) + def test_unix_socket(): + run(UNIX_SOCKET_CODE) + """, + ) + if state == "allowed": + result = testdir.runpytest("--disable-socket", "--allow-unix-socket") + result.assert_outcomes(passed=1, skipped=0, failed=1) + else: + result = testdir.runpytest("--disable-socket") + result.assert_outcomes(passed=0, skipped=0, failed=2) + + +class TestAllowHosts: + def test_single_cli_arg_connect_enabled(self, assert_connect): + assert_connect(True, cli_arg=localhost) + + def test_single_cli_arg_connect_enabled_localhost_resolved(self, assert_connect): + assert_connect(True, cli_arg="localhost") + + def test_single_cli_arg_127_0_0_1_hostname_localhost_connect_disabled( + self, assert_connect + ): + assert_connect(False, cli_arg=localhost, host="localhost") + + def test_single_cli_arg_localhost_hostname_localhost_connect_enabled( + self, assert_connect + ): + assert_connect(True, cli_arg="localhost", host="localhost") + + def test_single_cli_arg_connect_disabled_hostname_resolved(self, assert_connect): + result = assert_connect( + False, + cli_arg="localhost", + host="1.2.3.4", + code_template=urlopen_hostname_code_template, + ) + result.stdout.fnmatch_lines( + '*A test tried to use socket.socket.connect() with host "1.2.3.4" ' + '(allowed: "localhost (127.0.0.1*' + ) + + def test_single_cli_arg_connect_enabled_hostname_unresolvable(self, assert_connect): + assert_connect(False, cli_arg="unresolvable") + + def test_multiple_cli_arg_connect_enabled(self, assert_connect): + assert_connect(True, cli_arg=localhost + ",1.2.3.4") + + def test_single_mark_arg_connect_enabled(self, assert_connect): + assert_connect(True, mark_arg=localhost) + + def test_multiple_mark_arg_csv_connect_enabled(self, assert_connect): + assert_connect(True, mark_arg=localhost + ",1.2.3.4") + + def test_multiple_mark_arg_list_connect_enabled(self, assert_connect): + assert_connect(True, mark_arg=[localhost, "1.2.3.4"]) + + def 
test_mark_cli_conflict_mark_wins_connect_enabled(self, assert_connect): + assert_connect(True, mark_arg=[localhost], cli_arg="1.2.3.4") + + def test_single_cli_arg_connect_disabled(self, assert_connect): + assert_connect(False, cli_arg="1.2.3.4") + + def test_multiple_cli_arg_connect_disabled(self, assert_connect): + assert_connect(False, cli_arg="5.6.7.8,1.2.3.4") + + def test_single_mark_arg_connect_disabled(self, assert_connect): + assert_connect(False, mark_arg="1.2.3.4") + + def test_multiple_mark_arg_csv_connect_disabled(self, assert_connect): + assert_connect(False, mark_arg="5.6.7.8,1.2.3.4") + + def test_multiple_mark_arg_list_connect_disabled(self, assert_connect): + assert_connect(False, mark_arg=["5.6.7.8", "1.2.3.4"]) + + def test_mark_cli_conflict_mark_wins_connect_disabled(self, assert_connect): + assert_connect(False, mark_arg=["1.2.3.4"], cli_arg=localhost) + + def test_default_urlopen_succeeds_by_default(self, assert_connect): + assert_connect(True, code_template=urlopen_code_template) + + def test_single_cli_arg_urlopen_enabled(self, assert_connect): + assert_connect( + True, cli_arg=localhost + ",1.2.3.4", code_template=urlopen_code_template + ) + + def test_single_mark_arg_urlopen_enabled(self, assert_connect): + assert_connect( + True, mark_arg=[localhost, "1.2.3.4"], code_template=urlopen_code_template + )