Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mockssh/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


SAMPLE_USER_KEY = os.path.join(os.path.dirname(__file__), "sample-user-key")

SAMPLE_USER_PASSWORD = "greeneggs&spam"

@fixture
def user_key_path() -> str:
Expand All @@ -27,6 +27,8 @@ def user_key_path() -> str:
def server() -> Iterator[mockssh.server.Server]:
users = {
"sample-user": SAMPLE_USER_KEY,
"sample-user2": {"type": "password", "password": SAMPLE_USER_PASSWORD},
"sample-user3": {"type": "key", "private_key_path": SAMPLE_USER_KEY},
}
with Server(users) as s:
yield s
Expand Down
208 changes: 186 additions & 22 deletions mockssh/server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections.abc
import errno
import logging
import os
Expand All @@ -11,7 +12,6 @@

from mockssh import sftp
from mockssh.streaming import StreamTransfer
from paramiko.client import SSHClient
from typing import Dict

__all__ = [
Expand All @@ -21,6 +21,97 @@
SERVER_KEY_PATH = os.path.join(os.path.dirname(__file__), "server-key")


class UserData:
"""
parameters:
data: needs to either be a string (the path to the private key), or
a dictionary of one of the following forms:
{"type": "key",
"private_key_path": <some appropriate value>,
"key_type": # Optional. One of: "ssh-rsa", "ssh-dss",
# ECDSAKey format identifiers, or "ssh-ed25519"
}
or
{"type": "password",
"password": <some appropriate value>
}
public_key: The public key to use (rather than calculate).
Useful if assigning to _users directly.

"""
allowed_credential_types = ('key', 'password')

def __init__(self, data, public_key=None):
self.private_key_path = None
self.key_type = None
self.public_key = None
self.password = None

if isinstance(data, collections.abc.Mapping):
if 'type' in data:
if data['type'] in self.allowed_credential_types:
self.credential_type = data['type']
else:
raise ValueError('Unrecognized credential type.')
else:
raise ValueError(
"users dictionary value is missing key 'type'."
)
else:
# backwards-compatible, assume data is a path to private key
self.credential_type = "key"
data = {"type": "key",
"private_key_path": data,
"key_type": "ssh-rsa"
}

if self.credential_type == 'key':
try:
self.private_key_path = data["private_key_path"]
except KeyError:
raise ValueError(
"users dictionary value is missing key 'private_key_path'"
)
self.key_type = data.get("key_type", "ssh-rsa")
if public_key is None:
self.public_key = self.calculate_public_key(
self.private_key_path,
self.key_type
)
else:
self.public_key = public_key # supports 'server._users = '
# assignments...
elif self.credential_type == 'password':
try:
self.password = data['password']
except KeyError:
raise ValueError(
"users dictionary value is missing key 'password'"
)

@staticmethod
def calculate_public_key(private_key_path, key_type="ssh-rsa"):
if key_type == "ssh-rsa":
public_key = paramiko.RSAKey.from_private_key_file(
private_key_path
)
elif key_type == "ssh-dss":
public_key = paramiko.DSSKey.from_private_key_file(
private_key_path
)
elif key_type in paramiko.ECDSAKey.supported_key_format_identifiers():
public_key = paramiko.ECDSAKey.from_private_key_file(
private_key_path
)
elif key_type == "ssh-ed25519":
public_key = paramiko.Ed25519Key.from_private_key_file(
private_key_path
)
else:
raise Exception("Unable to handle key of type {}".format(key_type))
return public_key


class Handler(paramiko.ServerInterface):
log = logging.getLogger(__name__)

Expand Down Expand Up @@ -64,16 +155,44 @@ def handle_client(self, channel):
except EOFError:
self.log.debug("Tried to close already closed channel")

def check_auth_password(self, username, password):
try:
user_data = self.server._userdata[username]
except KeyError:
self.log.debug("Unknown user '%s'", username)
return paramiko.AUTH_FAILED

if user_data.credential_type != 'password':
self.log.debug("User data for user '%s' is not of type "
"'password'; rejecting password."
)
return paramiko.AUTH_FAILED

if user_data.password == password:
self.log.debug("Accepting password for user '%s'", username)
return paramiko.AUTH_SUCCESSFUL

self.log.debug("Rejecting password for user '%s'", username)
return paramiko.AUTH_FAILED

def check_auth_publickey(self, username, key):
try:
_, known_public_key = self.server._users[username]
user_data = self.server._userdata[username]
except KeyError:
self.log.debug("Unknown user '%s'", username)
return paramiko.AUTH_FAILED
if known_public_key == key:

if user_data.credential_type != 'key':
self.log.debug("User data for user '%s' is not of type "
"'key'; rejecting public key."
)
return paramiko.AUTH_FAILED

if user_data.public_key == key:
self.log.debug("Accepting public key for user '%s'", username)
return paramiko.AUTH_SUCCESSFUL
self.log.debug("Rejecting public ley for user '%s'", username)

self.log.debug("Rejecting public key for user '%s'", username)
return paramiko.AUTH_FAILED

def check_channel_exec_request(self, channel, command):
Expand All @@ -86,7 +205,11 @@ def check_channel_request(self, kind, chanid):
return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED

def get_allowed_auths(self, username):
return "publickey"
ud = self.server._userdata[username]
if ud.credential_type == 'key':
return "publickey"
else:
return "password"


class Server(object):
Expand All @@ -98,23 +221,56 @@ class Server(object):
def __init__(self, users: Dict[str, str]) -> None:
self._socket = None
self._thread = None
self._users = {}
for uid, private_key_path in users.items():
self.add_user(uid, private_key_path)
self._userdata = {}
self._users_cached = None
for uid, credential in users.items():
user_data = UserData(credential)
self._userdata[uid] = user_data


@property
def _users(self):
if self._users_cached is None:
self._users_cached = {
k:(ud.private_key_path, ud.public_key)
for k, ud in self._userdata.items() if ud.credential_type == 'key'
}
return self._users_cached

@_users.setter
def _users(self, value):
# Questionable, but backwards compatible if someone ever set
# this directly. Obviously only supports using private keys.
self._users_cached = None
self._userdata = {}
for uid, data in value.items():
private_key_path, public_key = data
self._userdata[uid] = UserData(
private_key_path,
public_key=public_key
)


def add_user(self, uid: str, private_key_path: str, keytype: str="ssh-rsa") -> None:
if keytype == "ssh-rsa":
key = paramiko.RSAKey.from_private_key_file(private_key_path)
paramiko.RSAKey.from_private_key_file(private_key_path)
elif keytype == "ssh-dss":
key = paramiko.DSSKey.from_private_key_file(private_key_path)
paramiko.DSSKey.from_private_key_file(private_key_path)
elif keytype in paramiko.ECDSAKey.supported_key_format_identifiers():
key = paramiko.ECDSAKey.from_private_key_file(private_key_path)
paramiko.ECDSAKey.from_private_key_file(private_key_path)
elif keytype == "ssh-ed25519":
key = paramiko.Ed25519Key.from_private_key_file(private_key_path)
paramiko.Ed25519Key.from_private_key_file(private_key_path)
else:
raise Exception("Unable to handle key of type {}".format(keytype))

self._users_cached = None # invalidate cache
ud = {"type": "key",
"private_key_path": private_key_path,
"key_type": keytype
}

self._users[uid] = (private_key_path, key)
user_data = UserData(ud)
self._userdata[uid] = user_data

def __enter__(self) -> "Server":
self._socket = s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand Down Expand Up @@ -154,20 +310,28 @@ def __exit__(self, *exc_info) -> None:
self._socket = None
self._thread = None

def client(self, uid: str) -> SSHClient:
private_key_path, _ = self._users[uid]
def client(self, uid):
ud = self._userdata[uid]
c = paramiko.SSHClient()
host_keys = c.get_host_keys()

key = paramiko.RSAKey.from_private_key_file(SERVER_KEY_PATH)
host_keys.add(self.host, "ssh-rsa", key)
host_keys.add("[%s]:%d" % (self.host, self.port), "ssh-rsa", key)
c.set_missing_host_key_policy(paramiko.RejectPolicy())
c.connect(hostname=self.host,
port=self.port,
username=uid,
key_filename=private_key_path,
allow_agent=False,
look_for_keys=False)
conn_kwargs = {
"hostname": self.host,
"port": self.port,
"username": uid,
"allow_agent": False,
"look_for_keys": False
}
if ud.credential_type == 'key':
conn_kwargs["key_filename"] = ud.private_key_path
else:
conn_kwargs["password"] = ud.password

c.connect(**conn_kwargs)
return c

@property
Expand All @@ -176,4 +340,4 @@ def port(self) -> int:

@property
def users(self):
return self._users.keys()
return self._userdata.keys()
3 changes: 3 additions & 0 deletions mockssh/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

def test_ssh_session(server: Server):
for uid in server.users:
print('Testing multiple connections with user', uid)
print('=================================================')
with server.client(uid) as c:
_, stdout, _ = c.exec_command("ls /")
assert "etc" in (codecs.decode(bit, "utf8")
Expand Down Expand Up @@ -58,6 +60,7 @@ def _test_multiple_connections(server: Server):
user, private_key = list(server._users.items())[0]
open(pkey_path, 'w').write(open(private_key[0]).read())
ssh_command = 'ssh -oStrictHostKeyChecking=no '
ssh_command += '-oUserKnownHostsFile=/dev/null '
ssh_command += "-i %s -p %s %s@localhost " % (pkey_path, server.port, user)
ssh_command += 'echo hello'
p = subprocess.check_output(ssh_command, shell=True)
Expand Down