Skip to content

feat: support cidr blocks in --allow-hosts #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 135 additions & 41 deletions pytest_socket.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# -*- coding: utf-8 -*-
import socket
import ipaddress
import pytest
import re


_true_socket = socket.socket
_true_connect = socket.socket.connect
_cached_domain_lookups = {}


class SocketBlockedError(RuntimeError):
Expand All @@ -14,31 +18,33 @@ def __init__(self, *args, **kwargs):
class SocketConnectBlockedError(RuntimeError):
def __init__(self, allowed, host, *args, **kwargs):
if allowed:
allowed = ','.join(allowed)
allowed = ",".join(allowed)
super(SocketConnectBlockedError, self).__init__(
'A test tried to use socket.socket.connect() with host "{0}" (allowed: "{1}").'.format(host, allowed)
'A test tried to use socket.socket.connect() with host "{0}" (allowed: "{1}").'.format(
host, allowed
)
)


def pytest_addoption(parser):
group = parser.getgroup('socket')
group = parser.getgroup("socket")
group.addoption(
'--disable-socket',
action='store_true',
dest='disable_socket',
help='Disable socket.socket by default to block network calls.'
"--disable-socket",
action="store_true",
dest="disable_socket",
help="Disable socket.socket by default to block network calls.",
)
group.addoption(
'--allow-hosts',
dest='allow_hosts',
metavar='ALLOWED_HOSTS_CSV',
help='Only allow specified hosts through socket.socket.connect((host, port)).'
"--allow-hosts",
dest="allow_hosts",
metavar="ALLOWED_HOSTS_CSV",
help="Only allow specified hosts through socket.socket.connect((host, port)).",
)
group.addoption(
'--allow-unix-socket',
action='store_true',
dest='allow_unix_socket',
help='Allow calls if they are to Unix domain sockets'
"--allow-unix-socket",
action="store_true",
dest="allow_unix_socket",
help="Allow calls if they are to Unix domain sockets",
)


Expand All @@ -51,39 +57,39 @@ def _socket_marker(request):
The expected behavior is that higher granularity options should override
lower granularity options.
"""
if request.config.getoption('--disable-socket'):
request.getfixturevalue('socket_disabled')
if request.config.getoption("--disable-socket"):
request.getfixturevalue("socket_disabled")

if request.node.get_closest_marker('disable_socket'):
request.getfixturevalue('socket_disabled')
if request.node.get_closest_marker('enable_socket'):
request.getfixturevalue('socket_enabled')
if request.node.get_closest_marker("disable_socket"):
request.getfixturevalue("socket_disabled")
if request.node.get_closest_marker("enable_socket"):
request.getfixturevalue("socket_enabled")


@pytest.fixture
def socket_disabled(pytestconfig):
""" disable socket.socket for duration of this test function """
allow_unix_socket = pytestconfig.getoption('--allow-unix-socket')
"""disable socket.socket for duration of this test function"""
allow_unix_socket = pytestconfig.getoption("--allow-unix-socket")
disable_socket(allow_unix_socket)
yield
enable_socket()


@pytest.fixture
def socket_enabled(pytestconfig):
""" enable socket.socket for duration of this test function """
"""enable socket.socket for duration of this test function"""
enable_socket()
yield
allow_unix_socket = pytestconfig.getoption('--allow-unix-socket')
allow_unix_socket = pytestconfig.getoption("--allow-unix-socket")
disable_socket(allow_unix_socket)


def disable_socket(allow_unix_socket=False):
""" disable socket.socket to disable the Internet. useful in testing.
"""
"""disable socket.socket to disable the Internet. useful in testing."""

class GuardedSocket(socket.socket):
""" socket guard to disable socket creation (from pytest-socket) """
"""socket guard to disable socket creation (from pytest-socket)"""

def __new__(cls, *args, **kwargs):
try:
is_unix_socket = args[0] == socket.AF_UNIX
Expand All @@ -100,20 +106,26 @@ def __new__(cls, *args, **kwargs):


def enable_socket():
""" re-enable socket.socket to enable the Internet. useful in testing.
"""
"""re-enable socket.socket to enable the Internet. useful in testing."""
socket.socket = _true_socket


def pytest_configure(config):
config.addinivalue_line("markers", "disable_socket(): Disable socket connections for a specific test")
config.addinivalue_line("markers", "enable_socket(): Enable socket connections for a specific test")
config.addinivalue_line("markers", "allow_hosts([hosts]): Restrict socket connection to defined list of hosts")
config.addinivalue_line(
"markers", "disable_socket(): Disable socket connections for a specific test"
)
config.addinivalue_line(
"markers", "enable_socket(): Enable socket connections for a specific test"
)
config.addinivalue_line(
"markers",
"allow_hosts([hosts]): Restrict socket connection to defined list of hosts",
)


def pytest_runtest_setup(item):
mark_restrictions = item.get_closest_marker('allow_hosts')
cli_restrictions = item.config.getoption('--allow-hosts')
mark_restrictions = item.get_closest_marker("allow_hosts")
cli_restrictions = item.config.getoption("--allow-hosts")
hosts = None
if mark_restrictions:
hosts = mark_restrictions.args[0]
Expand All @@ -140,23 +152,105 @@ def host_from_connect_args(args):


def socket_allow_hosts(allowed=None):
""" disable socket.socket.connect() to disable the Internet. useful in testing.
"""
"""disable socket.socket.connect() to disable the Internet. useful in testing."""
if isinstance(allowed, str):
allowed = allowed.split(',')
allowed = allowed.split(",")
if not isinstance(allowed, list):
return

ips = [a for a in allowed if is_ipaddress(a)]
cidrs = parse_cidrs_from_allowed(allowed)
domain_names = [a for a in parse_domains_from_allow(allowed)]

def guarded_connect(inst, *args):
host = host_from_connect_args(args)
if host and host in allowed:
if host and host in ips:
return _true_connect(inst, *args)
elif host_in_cidr_block(host, cidrs):
return _true_connect(inst, *args)
elif host_is_domain(host, domain_names):
return _true_connect(inst, *args)
raise SocketConnectBlockedError(allowed, host)

socket.socket.connect = guarded_connect


def remove_host_restrictions():
""" restore socket.socket.connect() to allow access to the Internet. useful in testing.
def host_in_cidr_block(host, cidrs):
if not host or len(cidrs) == 0:
return False
for cidr in cidrs:
if address_in_network(host, cidr):
return True
return False


def is_valid_cidr(network):
try:
ipaddress.ip_network(network)
except ValueError:
return False
return True


def is_ipaddress(address: str):
"""
Determine if the address is a valid IPv4 address.
"""
try:
socket.inet_aton(address)
return True
except socket.error:
return False


def host_is_domain(host, domains):
if not host or len(domains) == 0:
return False
for domain in domains:
if address_is_domain(host, domain):
return True
return False


def is_valid_domain(dn):
if dn.endswith("."):
dn = dn[:-1]
if len(dn) < 1 or len(dn) > 253:
return False
ldh_re = re.compile("^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?$", re.IGNORECASE)
return all(ldh_re.match(x) for x in dn.split("."))


def parse_cidrs_from_allowed(allowed):
return [x for x in allowed if is_valid_cidr(x)]


def parse_domains_from_allow(allowed):
return [x for x in allowed if is_valid_domain(x)]


def address_in_network(ip, net):
return ipaddress.ip_address(ip) in ipaddress.ip_network(net)


def cache_ip_for_domain(ip, domain):
if domain not in _cached_domain_lookups:
_cached_domain_lookups[domain] = set()
_cached_domain_lookups[domain].add(ip)


def ip_is_cached_for_domain(ip, domain):
if domain in _cached_domain_lookups:
return ip in _cached_domain_lookups[domain]
return False


def address_is_domain(ip, domain):
ip_for_domain = socket.gethostbyname(domain)
cache_ip_for_domain(ip_for_domain, domain)
return ip_for_domain == ip or ip_is_cached_for_domain(ip, domain)


def remove_host_restrictions():
"""restore socket.socket.connect() to allow access to the Internet. useful in testing."""
socket.socket.connect = _true_connect
78 changes: 78 additions & 0 deletions tests/test_restrict_hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,81 @@ def test_fail_2():
result.assert_outcomes(1, 0, 2)
assert_host_blocked(result, '2.2.2.2')
assert_host_blocked(result, test_url.hostname)


def test_cidr_allow(testdir, httpbin):
test_url = urlparse(httpbin.url)
testdir.makepyfile(
"""
import pytest
import socket

@pytest.mark.allow_hosts('127.0.0.0/8')
def test_pass():
socket.socket().connect(('{0}', {1}))

@pytest.mark.allow_hosts('127.0.0.0/16')
def test_pass_2():
socket.socket().connect(('{0}', {1}))

def test_fail():
socket.socket().connect(('2.2.2.2', {1}))

def test_fail_2():
socket.socket().connect(('192.168.1.10', {1}))

@pytest.mark.allow_hosts('172.20.0.0/16')
def test_fail_3():
socket.socket().connect(('{0}', {1}))
""".format(
test_url.hostname, test_url.port
)
)
result = testdir.runpytest("--verbose", "--allow-hosts=1.2.3.4")
result.assert_outcomes(2, 0, 3)
assert_host_blocked(result, "2.2.2.2")
assert_host_blocked(result, "192.168.1.10")
assert_host_blocked(result, test_url.hostname)


def test_domain_allow(testdir, httpbin):
test_url = urlparse(httpbin.url)
testdir.makepyfile(
"""
import pytest
import socket

@pytest.mark.allow_hosts('127.0.0.0/8')
def test_pass():
socket.socket().connect(('{0}', {1}))

@pytest.mark.allow_hosts('127.0.0.0/16')
def test_pass_2():
socket.socket().connect(('{0}', {1}))

@pytest.mark.allow_hosts('{0}')
def test_pass_3():
socket.socket().connect(('{0}', {1}))

@pytest.mark.allow_hosts('example.com')
def test_pass_4():
socket.socket().connect(('93.184.216.34', 443))

def test_fail():
socket.socket().connect(('2.2.2.2', {1}))

def test_fail_2():
socket.socket().connect(('192.168.1.10', {1}))

@pytest.mark.allow_hosts('172.20.0.0/16')
def test_fail_3():
socket.socket().connect(('{0}', {1}))
""".format(
test_url.hostname, test_url.port
)
)
result = testdir.runpytest("--verbose", "--allow-hosts=1.2.3.4")
result.assert_outcomes(4, 0, 3)
assert_host_blocked(result, "2.2.2.2")
assert_host_blocked(result, "192.168.1.10")
assert_host_blocked(result, test_url.hostname)