Skip to content

feat: Support Python subprocesses #409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ tests to ensure network calls are prevented.
## Features

- Disables all network calls flowing through Python\'s `socket` interface.
- Python subprocesses are supported

## Requirements

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ include = [
{ path = "README.md", format = "sdist" },
{ path = "tests", format = "sdist" },
{ path = ".flake8", format = "sdist" },
{ path = "pytest_socket.pth", format = ["sdist", "wheel"] }
]
classifiers = [
"Development Status :: 4 - Beta",
Expand Down
48 changes: 48 additions & 0 deletions pytest_socket.embed
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Inject pytest-socket into Python subprocesses via a .pth file.

To update the .pth file, simply run this script which will write the code
below to a .pth file in the same directory as a single line.
"""

import os
os.environ["_PYTEST_SOCKET_SUBPROCESS"] = ""

# .PTH START
if config := os.getenv("_PYTEST_SOCKET_SUBPROCESS", None):
import json
state = None
try:
import socket
from pytest_socket import disable_socket, _create_guarded_connect
state = json.loads(config)

if state["mode"] == "disable":
disable_socket(allow_unix_socket=state["allow_unix_socket"])
elif state["mode"] == "allow-hosts":
socket.socket.connect = _create_guarded_connect(
allowed_hosts=state["allowed_hosts"],
allow_unix_socket=state["allow_unix_socket"],
_pretty_allowed_list=state["_pretty_allowed_list"]
)

except Exception as exc:
import sys
sys.stderr.write(
"pytest-socket: Failed to set up subprocess socket patching.\n"
f"Configuration: {state}\n"
f"{exc.__class__.__name__}: {exc}\n"
)
# .PTH END

if __name__ == "__main__":
from pathlib import Path

src = Path(__file__)
dst = src.with_suffix(".pth")
lines = src.read_text().splitlines()
code = "\n".join(lines[lines.index("# .PTH START") + 1 : lines.index("# .PTH END")])

print(f"Writing to {dst}")
# Only lines beginning with an import will be executed.
# https://docs.python.org/3/library/site.html
dst.write_text(f"import os; exec({code!r})\n")
1 change: 1 addition & 0 deletions pytest_socket.pth
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import os; exec('if config := os.getenv("_PYTEST_SOCKET_SUBPROCESS", None):\n import json\n state = None\n try:\n import socket\n from pytest_socket import disable_socket, _create_guarded_connect\n state = json.loads(config)\n\n if state["mode"] == "disable":\n disable_socket(allow_unix_socket=state["allow_unix_socket"])\n elif state["mode"] == "allow-hosts":\n socket.socket.connect = _create_guarded_connect(\n allowed_hosts=state["allowed_hosts"],\n allow_unix_socket=state["allow_unix_socket"],\n _pretty_allowed_list=state["_pretty_allowed_list"]\n )\n\n except Exception as exc:\n import sys\n sys.stderr.write(\n "pytest-socket: Failed to set up subprocess socket patching.\\n"\n f"Configuration: {state}\\n"\n f"{exc.__class__.__name__}: {exc}\\n"\n )')
63 changes: 53 additions & 10 deletions pytest_socket.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
import ipaddress
import itertools
import json
import os
import socket
import typing
from collections import defaultdict
from dataclasses import dataclass, field

import pytest

_SUBPROCESS_ENVVAR = "_PYTEST_SOCKET_SUBPROCESS"
_true_socket = socket.socket
_true_connect = socket.socket.connect


def update_subprocess_config(config: typing.Dict[str, object]) -> None:
"""Enable pytest-socket in Python subprocesses.

The configuration will be read by the .pth file to mirror the
restrictions in the main process.
"""
os.environ[_SUBPROCESS_ENVVAR] = json.dumps(config)


def delete_subprocess_config() -> None:
"""Disable pytest-socket in Python subprocesses."""
if _SUBPROCESS_ENVVAR in os.environ:
del os.environ[_SUBPROCESS_ENVVAR]


class SocketBlockedError(RuntimeError):
def __init__(self, *_args, **_kwargs):
super().__init__("A test tried to use socket.socket.")
Expand Down Expand Up @@ -103,11 +121,15 @@ def __new__(cls, family=-1, type=-1, proto=-1, fileno=None):
raise SocketBlockedError()

socket.socket = GuardedSocket
update_subprocess_config(
{"mode": "disable", "allow_unix_socket": allow_unix_socket}
)


def enable_socket():
"""re-enable socket.socket to enable the Internet. useful in testing."""
socket.socket = _true_socket
delete_subprocess_config()


def pytest_configure(config):
Expand Down Expand Up @@ -249,6 +271,25 @@ def normalize_allowed_hosts(
return ip_hosts


def _create_guarded_connect(
allowed_hosts: typing.Sequence[str],
allow_unix_socket: bool,
_pretty_allowed_list: typing.Sequence[str],
) -> typing.Callable:
"""Create a function to replace socket.connect."""

def guarded_connect(inst, *args):
host = host_from_connect_args(args)
if host in allowed_hosts or (
_is_unix_socket(inst.family) and allow_unix_socket
):
return _true_connect(inst, *args)

raise SocketConnectBlockedError(_pretty_allowed_list, host)

return guarded_connect


def socket_allow_hosts(
allowed: typing.Union[str, typing.List[str], None] = None,
allow_unix_socket: bool = False,
Expand Down Expand Up @@ -276,19 +317,21 @@ def socket_allow_hosts(
]
)

def guarded_connect(inst, *args):
host = host_from_connect_args(args)
if host in allowed_ip_hosts_and_hostnames or (
_is_unix_socket(inst.family) and allow_unix_socket
):
return _true_connect(inst, *args)

raise SocketConnectBlockedError(allowed_list, host)

socket.socket.connect = guarded_connect
socket.socket.connect = _create_guarded_connect(
allowed_ip_hosts_and_hostnames, allow_unix_socket, allowed_list
)
update_subprocess_config(
{
"mode": "allow-hosts",
"allowed_hosts": list(allowed_ip_hosts_and_hostnames),
"allow_unix_socket": allow_unix_socket,
"_pretty_allowed_list": allowed_list,
}
)


def _remove_restrictions():
"""restore socket.socket.* to allow access to the Internet. useful in testing."""
socket.socket = _true_socket
socket.socket.connect = _true_connect
delete_subprocess_config()
6 changes: 6 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,9 @@ def assert_socket_blocked(result, passed=0, skipped=0, failed=1):
result.stdout.fnmatch_lines(
"*Socket*Blocked*Error: A test tried to use socket.socket.*"
)


def assert_host_blocked(result, host):
result.stdout.fnmatch_lines(
f'*A test tried to use socket.socket.connect() with host "{host}"*'
)
8 changes: 1 addition & 7 deletions tests/test_restrict_hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from pytest_socket import normalize_allowed_hosts
from pytest_socket import assert_host_blocked, normalize_allowed_hosts

localhost = "127.0.0.1"

Expand Down Expand Up @@ -46,12 +46,6 @@ def {2}():
"""


def assert_host_blocked(result, host):
result.stdout.fnmatch_lines(
f'*A test tried to use socket.socket.connect() with host "{host}"*'
)


@pytest.fixture
def assert_connect(httpbin, testdir):
def assert_socket_connect(should_pass, **kwargs):
Expand Down
Loading
Loading