diff --git a/mockssh/conftest.py b/mockssh/conftest.py index 4510446..e0fae4d 100644 --- a/mockssh/conftest.py +++ b/mockssh/conftest.py @@ -16,7 +16,7 @@ SAMPLE_USER_KEY = os.path.join(os.path.dirname(__file__), "sample-user-key") - +SAMPLE_USER_PASSWORD = "greeneggs&spam" @fixture def user_key_path() -> str: @@ -27,6 +27,8 @@ def user_key_path() -> str: def server() -> Iterator[mockssh.server.Server]: users = { "sample-user": SAMPLE_USER_KEY, + "sample-user2": {"type": "password", "password": SAMPLE_USER_PASSWORD}, + "sample-user3": {"type": "key", "private_key_path": SAMPLE_USER_KEY}, } with Server(users) as s: yield s diff --git a/mockssh/server.py b/mockssh/server.py index 9bb157e..58fa763 100644 --- a/mockssh/server.py +++ b/mockssh/server.py @@ -1,3 +1,4 @@ +import collections.abc import errno import logging import os @@ -11,7 +12,6 @@ from mockssh import sftp from mockssh.streaming import StreamTransfer -from paramiko.client import SSHClient from typing import Dict __all__ = [ @@ -21,6 +21,97 @@ SERVER_KEY_PATH = os.path.join(os.path.dirname(__file__), "server-key") +class UserData: + """ + parameters: + data: needs to either be a string (the path to the private key), or + a dictionary of one of the following forms: + {"type": "key", + "private_key_path": , + "key_type": # Optional. One of: "ssh-rsa", "ssh-dss", + # ECDSAKey format identifiers, or "ssh-ed25519" + } + or + {"type": "password", + "password": + } + public_key: The public key to use (rather than calculate). + Useful if assigning to _users directly. + + """ + allowed_credential_types = ('key', 'password') + + def __init__(self, data, public_key=None): + self.private_key_path = None + self.key_type = None + self.public_key = None + self.password = None + + if isinstance(data, collections.abc.Mapping): + if 'type' in data: + if data['type'] in self.allowed_credential_types: + self.credential_type = data['type'] + else: + raise ValueError('Unrecognized credential type.') + else: + raise ValueError( + "users dictionary value is missing key 'type'." + ) + else: + # backwards-compatible, assume data is a path to private key + self.credential_type = "key" + data = {"type": "key", + "private_key_path": data, + "key_type": "ssh-rsa" + } + + if self.credential_type == 'key': + try: + self.private_key_path = data["private_key_path"] + except KeyError: + raise ValueError( + "users dictionary value is missing key 'private_key_path'" + ) + self.key_type = data.get("key_type", "ssh-rsa") + if public_key is None: + self.public_key = self.calculate_public_key( + self.private_key_path, + self.key_type + ) + else: + self.public_key = public_key # supports 'server._users = ' + # assignments... + elif self.credential_type == 'password': + try: + self.password = data['password'] + except KeyError: + raise ValueError( + "users dictionary value is missing key 'password'" + ) + + @staticmethod + def calculate_public_key(private_key_path, key_type="ssh-rsa"): + if key_type == "ssh-rsa": + public_key = paramiko.RSAKey.from_private_key_file( + private_key_path + ) + elif key_type == "ssh-dss": + public_key = paramiko.DSSKey.from_private_key_file( + private_key_path + ) + elif key_type in paramiko.ECDSAKey.supported_key_format_identifiers(): + public_key = paramiko.ECDSAKey.from_private_key_file( + private_key_path + ) + elif key_type == "ssh-ed25519": + public_key = paramiko.Ed25519Key.from_private_key_file( + private_key_path + ) + else: + raise Exception("Unable to handle key of type {}".format(key_type)) + return public_key + + class Handler(paramiko.ServerInterface): log = logging.getLogger(__name__) @@ -64,16 +155,44 @@ def handle_client(self, channel): except EOFError: self.log.debug("Tried to close already closed channel") + def check_auth_password(self, username, password): + try: + user_data = self.server._userdata[username] + except KeyError: + self.log.debug("Unknown user '%s'", username) + return paramiko.AUTH_FAILED + + if user_data.credential_type != 'password': + self.log.debug("User data for user '%s' is not of type " + "'password'; rejecting password." + ) + return paramiko.AUTH_FAILED + + if user_data.password == password: + self.log.debug("Accepting password for user '%s'", username) + return paramiko.AUTH_SUCCESSFUL + + self.log.debug("Rejecting password for user '%s'", username) + return paramiko.AUTH_FAILED + def check_auth_publickey(self, username, key): try: - _, known_public_key = self.server._users[username] + user_data = self.server._userdata[username] except KeyError: self.log.debug("Unknown user '%s'", username) return paramiko.AUTH_FAILED - if known_public_key == key: + + if user_data.credential_type != 'key': + self.log.debug("User data for user '%s' is not of type " + "'key'; rejecting public key." + ) + return paramiko.AUTH_FAILED + + if user_data.public_key == key: self.log.debug("Accepting public key for user '%s'", username) return paramiko.AUTH_SUCCESSFUL - self.log.debug("Rejecting public ley for user '%s'", username) + + self.log.debug("Rejecting public key for user '%s'", username) return paramiko.AUTH_FAILED def check_channel_exec_request(self, channel, command): @@ -86,7 +205,11 @@ def check_channel_request(self, kind, chanid): return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED def get_allowed_auths(self, username): - return "publickey" + ud = self.server._userdata[username] + if ud.credential_type == 'key': + return "publickey" + else: + return "password" class Server(object): @@ -98,23 +221,56 @@ class Server(object): def __init__(self, users: Dict[str, str]) -> None: self._socket = None self._thread = None - self._users = {} - for uid, private_key_path in users.items(): - self.add_user(uid, private_key_path) + self._userdata = {} + self._users_cached = None + for uid, credential in users.items(): + user_data = UserData(credential) + self._userdata[uid] = user_data + + + @property + def _users(self): + if self._users_cached is None: + self._users_cached = { + k:(ud.private_key_path, ud.public_key) + for k, ud in self._userdata.items() if ud.credential_type == 'key' + } + return self._users_cached + + @_users.setter + def _users(self, value): + # Questionable, but backwards compatible if someone ever set + # this directly. Obviously only supports using private keys. + self._users_cached = None + self._userdata = {} + for uid, data in value.items(): + private_key_path, public_key = data + self._userdata[uid] = UserData( + private_key_path, + public_key=public_key + ) + def add_user(self, uid: str, private_key_path: str, keytype: str="ssh-rsa") -> None: if keytype == "ssh-rsa": - key = paramiko.RSAKey.from_private_key_file(private_key_path) + paramiko.RSAKey.from_private_key_file(private_key_path) elif keytype == "ssh-dss": - key = paramiko.DSSKey.from_private_key_file(private_key_path) + paramiko.DSSKey.from_private_key_file(private_key_path) elif keytype in paramiko.ECDSAKey.supported_key_format_identifiers(): - key = paramiko.ECDSAKey.from_private_key_file(private_key_path) + paramiko.ECDSAKey.from_private_key_file(private_key_path) elif keytype == "ssh-ed25519": - key = paramiko.Ed25519Key.from_private_key_file(private_key_path) + paramiko.Ed25519Key.from_private_key_file(private_key_path) else: raise Exception("Unable to handle key of type {}".format(keytype)) + + self._users_cached = None # invalidate cache + ud = {"type": "key", + "private_key_path": private_key_path, + "key_type": keytype + } - self._users[uid] = (private_key_path, key) + user_data = UserData(ud) + self._userdata[uid] = user_data def __enter__(self) -> "Server": self._socket = s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -154,20 +310,28 @@ def __exit__(self, *exc_info) -> None: self._socket = None self._thread = None - def client(self, uid: str) -> SSHClient: - private_key_path, _ = self._users[uid] + def client(self, uid): + ud = self._userdata[uid] c = paramiko.SSHClient() host_keys = c.get_host_keys() + key = paramiko.RSAKey.from_private_key_file(SERVER_KEY_PATH) host_keys.add(self.host, "ssh-rsa", key) host_keys.add("[%s]:%d" % (self.host, self.port), "ssh-rsa", key) c.set_missing_host_key_policy(paramiko.RejectPolicy()) - c.connect(hostname=self.host, - port=self.port, - username=uid, - key_filename=private_key_path, - allow_agent=False, - look_for_keys=False) + conn_kwargs = { + "hostname": self.host, + "port": self.port, + "username": uid, + "allow_agent": False, + "look_for_keys": False + } + if ud.credential_type == 'key': + conn_kwargs["key_filename"] = ud.private_key_path + else: + conn_kwargs["password"] = ud.password + + c.connect(**conn_kwargs) return c @property @@ -176,4 +340,4 @@ def port(self) -> int: @property def users(self): - return self._users.keys() + return self._userdata.keys() diff --git a/mockssh/test_server.py b/mockssh/test_server.py index 00e2af2..6f34dad 100644 --- a/mockssh/test_server.py +++ b/mockssh/test_server.py @@ -13,6 +13,8 @@ def test_ssh_session(server: Server): for uid in server.users: + print('Testing multiple connections with user', uid) + print('=================================================') with server.client(uid) as c: _, stdout, _ = c.exec_command("ls /") assert "etc" in (codecs.decode(bit, "utf8") @@ -58,6 +60,7 @@ def _test_multiple_connections(server: Server): user, private_key = list(server._users.items())[0] open(pkey_path, 'w').write(open(private_key[0]).read()) ssh_command = 'ssh -oStrictHostKeyChecking=no ' + ssh_command += '-oUserKnownHostsFile=/dev/null ' ssh_command += "-i %s -p %s %s@localhost " % (pkey_path, server.port, user) ssh_command += 'echo hello' p = subprocess.check_output(ssh_command, shell=True)