diff --git a/tests/test_socket.py b/tests/test_socket.py index f0015b8..08de1e7 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -106,6 +106,29 @@ def test_client_socket_close(): server_socket.close() +def test_client_socket_closed(): + server_socket = TServerSocket(host="localhost", port=12345) + server_socket.listen() + + client_socket = TSocket(host="localhost", port=12345) + client_socket.open() + + conn = server_socket.accept() + client_socket.close() + assert not client_socket.is_open() + + with pytest.raises(TTransportException) as e: + client_socket.read(1024) + assert "Could not read from closed socket" in e.value.message + + with pytest.raises(TTransportException) as e: + client_socket.write(b"world") + assert "Could not write into closed socket" in e.value.message + + conn.close() + server_socket.close() + + def test_server_socket_close(): server_socket = TServerSocket(host="localhost", port=12345) server_socket.listen() @@ -124,6 +147,17 @@ def test_server_socket_close(): server_socket.close() +def test_server_socket_closed(): + server_socket = TServerSocket(host="localhost", port=12345) + server_socket.listen() + + server_socket.close() + + with pytest.raises(TTransportException) as e: + server_socket.accept() + assert "Could not accept on closed socket" in e.value.message + + def test_client_socket_set_timeout(): server_socket = TServerSocket(host="localhost", port=12345, client_timeout=100) diff --git a/tests/test_sslsocket.py b/tests/test_sslsocket.py index 4948a12..5742bdd 100644 --- a/tests/test_sslsocket.py +++ b/tests/test_sslsocket.py @@ -96,3 +96,17 @@ def test_persist_ssl_context(): ssl_context=client_ssl_context) _test_socket(server_socket, client_socket) + + +def test_server_socket_closed(): + server_ssl_context = create_thriftpy_context(server_side=True) + server_ssl_context.load_cert_chain(certfile="ssl/server.pem") + server_socket = TSSLServerSocket(host="localhost", port=12345, + ssl_context=server_ssl_context) + server_socket.listen() + + server_socket.close() + + with pytest.raises(TTransportException) as e: + server_socket.accept() + assert "Could not accept on closed socket" in e.value.message diff --git a/thriftpy/transport/socket.py b/thriftpy/transport/socket.py index 51a5283..63eabb0 100644 --- a/thriftpy/transport/socket.py +++ b/thriftpy/transport/socket.py @@ -104,6 +104,11 @@ def open(self): message="Could not connect to %s" % str(addr)) def read(self, sz): + if self.sock is None: + raise TTransportException( + type=TTransportException.NOT_OPEN, + message="Could not read from closed socket") + try: buff = self.sock.recv(sz) except socket.error as e: @@ -126,6 +131,11 @@ def read(self, sz): return buff def write(self, buff): + if self.sock is None: + raise TTransportException( + type=TTransportException.NOT_OPEN, + message="Could not write into closed socket") + self.sock.sendall(buff) def flush(self): @@ -209,6 +219,11 @@ def listen(self): self.sock.listen(self.backlog) def accept(self): + if self.sock is None: + raise TTransportException( + type=TTransportException.NOT_OPEN, + message="Could not accept on closed socket") + client, _ = self.sock.accept() if self.client_timeout: client.settimeout(self.client_timeout) @@ -221,5 +236,6 @@ def close(self): try: self.sock.shutdown(socket.SHUT_RDWR) self.sock.close() + self.sock = None except (socket.error, OSError): pass diff --git a/thriftpy/transport/sslsocket.py b/thriftpy/transport/sslsocket.py index 3972e34..0f52cd4 100644 --- a/thriftpy/transport/sslsocket.py +++ b/thriftpy/transport/sslsocket.py @@ -7,6 +7,7 @@ import ssl import struct +from . import TTransportException from ._ssl import ( create_thriftpy_context, RESTRICTED_SERVER_CIPHERS, @@ -109,6 +110,11 @@ def __init__(self, host, port, socket_family=socket.AF_INET, self.ssl_context.load_cert_chain(certfile=certfile) def accept(self): + if self.sock is None: + raise TTransportException( + type=TTransportException.NOT_OPEN, + message="Could not accept on closed socket") + sock, _ = self.sock.accept() try: ssl_sock = self.ssl_context.wrap_socket(sock, server_side=True)