From f69936f52e579790cf97ce062e723b620655fcb3 Mon Sep 17 00:00:00 2001 From: Neha Das Date: Mon, 16 Feb 2026 14:33:36 +0000 Subject: [PATCH] gNSI: Add backend support for Credentialz Signed-off-by: Pattela JAYARAGINI --- host_modules/glome.py | 99 +++++ host_modules/gnsi_console.py | 207 +++++++++ host_modules/ssh_mgmt.py | 299 +++++++++++++ scripts/sonic-host-server | 10 +- tests/glome_test.py | 131 ++++++ tests/gnsi_console_test.py | 445 +++++++++++++++++++ tests/ssh_mgmt_test.py | 815 +++++++++++++++++++++++++++++++++++ 7 files changed, 2004 insertions(+), 2 deletions(-) create mode 100644 host_modules/glome.py create mode 100644 host_modules/gnsi_console.py create mode 100644 host_modules/ssh_mgmt.py create mode 100644 tests/glome_test.py create mode 100644 tests/gnsi_console_test.py create mode 100644 tests/ssh_mgmt_test.py diff --git a/host_modules/glome.py b/host_modules/glome.py new file mode 100644 index 00000000..f1af89eb --- /dev/null +++ b/host_modules/glome.py @@ -0,0 +1,99 @@ +"""GLOME DBus operations handler.""" + +import configparser +import json +import logging +import os +from pathlib import Path +import shutil +import stat + +from host_modules import host_service + +MOD_NAME = 'glome' +logger = logging.getLogger(__name__) + +class Glome(host_service.HostModule): + """DBus endpoint that executes GLOME operations on switch hosts.""" + + _GLOME_PATH = '/host/glome/glome.conf' + _GLOME_BACKUP_PATH = '/host/glome/glome_backup.conf' + + def _remove_config_file(self, file_path: str) -> None: + try: + os.remove(file_path) + except FileNotFoundError: + pass + + def _write_config_file(self, payload: dict[str, str]) -> None: + file_path = Path(self._GLOME_PATH) + file_path.parent.mkdir(parents=True, exist_ok=True) + config = configparser.ConfigParser(interpolation=None) + config.add_section('service') + config.set('service', 'key', str(payload['key'])) + config.set('service', 'key-version', str(payload['key_version'])) + config.set('service', 'url-prefix', str(payload['url_prefix'])) + with file_path.open('w') as f: + config.write(f) + os.chmod(file_path, stat.S_IRUSR | stat.S_IWUSR) + + @host_service.method( + host_service.bus_name(MOD_NAME), in_signature='s', out_signature='is' + ) + def push_config(self, request: str) -> tuple[int, str]: + """Backs up and updates the GLOME configuration file. + + Creates a backup of the GLOME configuration file, then stores the request + in the GLOME configuration file on the switch host. + """ + try: + # create a checkpoint first + if os.path.exists(self._GLOME_PATH): + # copy() copies the file data and the file’s permission mode. + shutil.copy(src=self._GLOME_PATH, dst=self._GLOME_BACKUP_PATH) + else: + self._remove_config_file(self._GLOME_BACKUP_PATH) + # process the request + payload = json.loads(request) + if payload['enabled']: + self._write_config_file(payload) + else: + self._remove_config_file(self._GLOME_PATH) + except PermissionError as e: + logger.error('PermissionError: %s\nrequest: %s', request, e) + return 1, f'A PermissionError error occurred: {e}' + except OSError as e: + logger.error('OSError: %s\nrequest: %s', request, e) + return 2, f'An OSError error occurred: {e}' + except json.decoder.JSONDecodeError as e: + logger.error('JSONDecodeError: %s\nrequest: %s', request, e) + return 3, f'A JSONDecodeError error occurred: {e}' + except KeyError as e: + logger.error('KeyError: %s\nrequest: %s', request, e) + return 4, f'A KeyError error occurred: {e}' + return 0, '' + + @host_service.method( + host_service.bus_name(MOD_NAME), in_signature='', out_signature='is' + ) + def restore_checkpoint(self) -> tuple[int, str]: + """Restores the GLOME configuration file to the backup file.""" + try: + if os.path.exists(self._GLOME_BACKUP_PATH): + # copy() copies the file data and the file’s permission mode. + shutil.copy(src=self._GLOME_BACKUP_PATH, dst=self._GLOME_PATH) + else: + self._remove_config_file(self._GLOME_PATH) + except PermissionError as e: + logger.error('PermissionError: %s', e) + return 1, f'A PermissionError error occurred: {e}' + except OSError as e: + logger.error('OSError: %s', e) + return 2, f'An OSError error occurred: {e}' + return 0, '' + + +def register(): + """Return class name.""" + return Glome, MOD_NAME + diff --git a/host_modules/gnsi_console.py b/host_modules/gnsi_console.py new file mode 100644 index 00000000..9c90ef6c --- /dev/null +++ b/host_modules/gnsi_console.py @@ -0,0 +1,207 @@ +"""gNSI console module used to manage console credentials""" + +import json +import os +import shutil +import logging + +from host_modules import host_service +from utils.run_cmd import _run_command + +MOD_NAME = 'gnsi_console' + +# File path which consists of console password +PASSWD_FILE = "/etc/shadow" +PASSWD_FILE_CHECKPOINT_FILE = PASSWD_FILE + "_checkpoint" +PASSWD_FILE_TEMP = PASSWD_FILE + "_temp" + +# Openssl command to generate hashed password using SHA512-based algorithm +OPENSSL_COMMAND = "openssl passwd -6 " + +# Constant trailing info regarding each password in the password file +TRAILING_PASSWORD_INFO = ":12215:0:99999:7:::\n" + +logger = logging.getLogger(__name__) + +class GnsiConsole(host_service.HostModule): + """DBus endpoint used to update console credentials for an existing user + """ + + @host_service.method(host_service.bus_name(MOD_NAME), in_signature='as', out_signature='is') + def create_checkpoint(self, options): + """Creates checkpoint for console password file so that the current + state can be restored later using restore_checkpoint(). create_checkpoint() will be + invoked when gNSI client starts the password change process.""" + try: + shutil.copy(PASSWD_FILE, PASSWD_FILE_CHECKPOINT_FILE) + except Exception as error: + return 1, "Failed to create checkpoint with error: " + str(error) + return 0, "Successfully created checkpoint" + + @host_service.method(host_service.bus_name(MOD_NAME), in_signature='as', out_signature='is') + def restore_checkpoint(self, options): + """Restore the state of the console password file to the state when + create_checkpoint() is called, i.e., to the state when the password change process has started. + Here, a move operation is performed as move is an atomic operation.""" + if not os.path.isfile(PASSWD_FILE_CHECKPOINT_FILE): + return 1, "Checkpoint file is not present" + + # Update the /etc/shadow with the checkpoint file + result = self.update_password_file(PASSWD_FILE_CHECKPOINT_FILE) + return result[0], "restore_checkpoint: " + result[1] + + @host_service.method(host_service.bus_name(MOD_NAME), in_signature='as', out_signature='is') + def delete_checkpoint(self, options): + """Deletes the checkpoint file created in create_checkpoint(). + delete_checkpoint() is invoked at the end of the successful password + change process.""" + try: + os.remove(PASSWD_FILE_CHECKPOINT_FILE) + except Exception as error: + return 1, "Failed to delete checkpoint with error: " + str(error) + return 0, "Successfully deleted checkpoint" + + def get_hashed_password(self, text_password): + """Generates and returns hashed password for given text password using + SHA-512-based password algorithm. Returns empty string on failure.""" + rc, stdout, stderr = _run_command(OPENSSL_COMMAND + text_password) + if rc: + logger.error("%s: Failed to get hash for given text password " + "with stdout: %s, stderr: %s" + % (MOD_NAME, stdout, stderr)) + return "" + return stdout[0] + + def read_password_file(self): + """Read contents of /etc/shadow password file and return its contents + in the form of a list where each line is an element in the list""" + try: + with open(PASSWD_FILE, 'r') as f: + password_file_content_list = f.readlines() + except IOError as error: + return [], "Failed to read password file with error: " + str(error) + return password_file_content_list, "" + + def update_password_if_user_found(self, user_name, user_password, + password_file_content_list): + """If user with user_name is found in password_file_content_list, then + this function will update password with user_password in + password_file_content_list. Logs an error if user_name is not found""" + found_user = False + for index,each_line in enumerate(password_file_content_list): + if each_line.startswith(user_name): + found_user = True + password_file_content_list[index] = (user_name + ":" + + user_password + + TRAILING_PASSWORD_INFO) + if not found_user: + logger.error("%s: The given user name: %s does not exist in the " + "password file" % (MOD_NAME, user_name)) + + def create_temp_passwd_file(self, password_file_content_list): + """Writes the contents of password_file_content_list into a temporary + file""" + rc = 0 + output = "" + try: + with open(PASSWD_FILE_TEMP, 'w') as f: + f.writelines(password_file_content_list) + except IOError as error: + rc = 1 + output = ("Failed to create temporary password file with error: " + + str(error)) + + # Remove temporary password file if it exists after failing to create + # this file with password_file_content_list + if rc and os.path.isfile(PASSWD_FILE_TEMP): + try: + os.remove(PASSWD_FILE_TEMP) + except Exception as error: + output += (" and also failed to remove temporary file " + "created with error: " + str(error)) + return rc, output + + def update_password_file(self, given_password_file): + """Overwrites /etc/shadow with given_password_file through a move operation """ + rc = 0 + output = "Successfully updated console passwords" + try: + shutil.move(given_password_file, PASSWD_FILE) + except Exception as error: + rc = 1 + output = ("Failed to replace original password file with " + "given password file with error: " + + str(error)) + + # Remove given_password_file if it exists after failing to overwrite + # /etc/shadow with given_password_file + if rc and os.path.isfile(given_password_file): + try: + os.remove(given_password_file) + except Exception as error: + output += (" and also failed to remove given password file " + "with error: " + str(error)) + + return rc, output + + @host_service.method(host_service.bus_name(MOD_NAME), in_signature='as', out_signature='is') + def set(self, options): + """Updates console passwords for exisitng users based on input request. + This API does not support creation or deletion of new user accounts.""" + if not os.path.isfile(PASSWD_FILE_CHECKPOINT_FILE): + return 1, "Trying to update console password without creating checkpoint" + + """Convert input json formatted password set request into python dict. + console_password_info_dict is a python dict with the following format: + { + "ConsolePasswords": [ + { "name": "alice", "password" : "password-alice" }, + { "name": "bob", "password" : "password-bob" } + ] + } + """ + try: + console_password_info_dict = json.loads(options[0]) + except json.JSONDecodeError: + return 1, ("Failed to parse json formatted password change request: " + + options[0]) + + if "ConsolePasswords" not in console_password_info_dict: + return 1, "Received invalid password request: %s" % str(console_password_info_dict) + + # Return on failed to read contents of /etc/shadow file + password_file_content_list, errstr = self.read_password_file() + if not password_file_content_list: + return 1, errstr + + # Iterate over each line in password file and update the passwords for + # the corresponding users in the input request + for index, each_request in enumerate(console_password_info_dict["ConsolePasswords"]): + # Skip processing the current element in input request if + # either "name" or "password" key is missing + if "name" not in each_request or "password" not in each_request: + logger.error("%s: Either name or password is not present at " + "index %d in password change request: %s" + % (MOD_NAME, index, str(console_password_info_dict))) + continue + + hashed_password = self.get_hashed_password(each_request["password"]) + if not hashed_password: + continue + + self.update_password_if_user_found(each_request["name"], hashed_password, + password_file_content_list) + + # Create a temporary password file with new changes + err, errstr = self.create_temp_passwd_file(password_file_content_list) + if err: + return err, errstr + + # Update the contents in /etc/shadow password file + result = self.update_password_file(PASSWD_FILE_TEMP) + return result[0], "set: " + result[1] + + +def register(): + """Return the class name""" + return GnsiConsole, MOD_NAME diff --git a/host_modules/ssh_mgmt.py b/host_modules/ssh_mgmt.py new file mode 100644 index 00000000..b861d797 --- /dev/null +++ b/host_modules/ssh_mgmt.py @@ -0,0 +1,299 @@ +"""SSH Management. + +This host service module implements the backend support for gNSI ssh rotation. +""" + +import json +import logging +import os +import shutil + +from host_modules import host_service + + +MOD_NAME = 'ssh_mgmt' +CHECKPOINT_DIR = '/tmp/ssh_checkpoint' +COPY_TEMP_FILE = 'ssh_mgmt_file_temp' +CA_PUB_KEY_NAME = 'ssh_ca_pub_key' +CA_PUB_KEY_TEMP = 'ssh_ca_pub_key_temp' + +CA_PUB_KEY_DIR = '/etc/sonic/ssh' +PERSISTENT_CA_PUB_KEY_DIRS = [ + CA_PUB_KEY_DIR +] + +ROOT_AUTHORIZED_KEYS_NAME = 'authorized_keys' +ROOT_AUTHORIZED_KEYS_TEMP = 'authorized_keys_temp' + +ROOT_AUTHORIZED_KEYS_DIR = '/etc/sonic/ssh/root' +PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS = [ + ROOT_AUTHORIZED_KEYS_DIR +] + +ROOT_AUTHORIZED_USERS_NAME = 'authorized_users' +ROOT_AUTHORIZED_USERS_TEMP = 'authorized_users_temp' + +ROOT_AUTHORIZED_USERS_DIR = '/etc/sonic/ssh/root' +PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS = [ + ROOT_AUTHORIZED_USERS_DIR +] + +logger = logging.getLogger(__name__) + + +class SshMgmt(host_service.HostModule): + """DBus endpoint that updates ssh related files.""" + + @staticmethod + def _write_options(f, options): + first = True + for option in options: + if 'name' not in option: + continue + if first: + first = False + else: + f.write(',') + if option.get('value'): + f.write(option['name'] + '="' + option['value'] + '"') + else: + f.write(option['name']) + if not first: + f.write(' ') + + @staticmethod + def _copy_file(src, dest): + """This method first copies the source file to a temp file in the + destination directory. Then moves the temp file to the destination file. + If the source file does not exist, it will return success, so that we + can support ssh_mgmt update even if the files are missing. + """ + ret_code = 0 + ret_msg = 'Successfully copy file from %s to %s' % (src, dest) + if not os.path.exists(src): + logger.error('Source file %s does not exist in ssh_mgmt copy.', src) + return ret_code, ret_msg + try: + dir = os.path.dirname(dest) + os.makedirs(dir, exist_ok=True) + shutil.copyfile(src, os.path.join(dir, COPY_TEMP_FILE)) + shutil.move(os.path.join(dir, COPY_TEMP_FILE), dest) + except Exception: + ret_code = 1 + ret_msg = 'Failed to copy file from %s to %s' % (src, dest) + logger.error('%s: %s\n', MOD_NAME, ret_msg) + try: + os.remove(os.path.join(dir, COPY_TEMP_FILE)) + except Exception: + pass + return ret_code, ret_msg + + @staticmethod + def _copy_files(src, dest): + """This method is the same as _copy_file, but copies an array of files. + The length of the src array and the dest array must be the same. + Files will be copied from src to dest with the same index. + """ + if len(src) != len(dest): + return 1, 'Length of src and dest do not match in _copy_files' + ret_code = 0 + ret_msg = '' + for i in range(len(src)): + code, msg = SshMgmt._copy_file(src[i], dest[i]) + ret_code |= code + if ret_msg: + ret_msg += ' & ' + ret_msg += msg + return ret_code, ret_msg + + @host_service.method( + host_service.bus_name(MOD_NAME), in_signature='as', out_signature='is') + def create_checkpoint(self, options): + if os.path.isdir(CHECKPOINT_DIR): + logger.error('%s: ssh_mgmt.create_checkpoint is called while' + 'checkpoint still exists', MOD_NAME) + try: + shutil.rmtree(CHECKPOINT_DIR) + except Exception: + logger.error('%s: Failed to delete old ssh mgmt checkpoint %s ' + 'in ssh_mgmt.create_checkpoint!\n', MOD_NAME, + CHECKPOINT_DIR) + return 1, 'Error in deleting checkpoint' + + os.makedirs(CHECKPOINT_DIR, exist_ok=True) + + code, msg = SshMgmt._copy_files([ + os.path.join(CA_PUB_KEY_DIR, CA_PUB_KEY_NAME), + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, ROOT_AUTHORIZED_KEYS_NAME), + os.path.join(ROOT_AUTHORIZED_USERS_DIR, ROOT_AUTHORIZED_USERS_NAME) + ], [ + os.path.join(CHECKPOINT_DIR, CA_PUB_KEY_NAME), + os.path.join(CHECKPOINT_DIR, ROOT_AUTHORIZED_KEYS_NAME), + os.path.join(CHECKPOINT_DIR, ROOT_AUTHORIZED_USERS_NAME) + ]) + if code != 0: + logger.error('%s: Failed to create ssh mgmt checkpoint!\n', + MOD_NAME) + try: + shutil.rmtree(CHECKPOINT_DIR) + except Exception: + logger.error('%s: Failed to delete ssh mgmt checkpoint %s! This' + ' might block gNSI ssh operations!\n', MOD_NAME, + CHECKPOINT_DIR) + else: + msg = 'Successfully created checkpoint' + return code, msg + + @host_service.method( + host_service.bus_name(MOD_NAME), in_signature='as', out_signature='is') + def restore_checkpoint(self, options): + if not os.path.isdir(CHECKPOINT_DIR): + return 1, 'Checkpoint does not exist' + + # We will restore the checkpoint to the persistent locations as well. + src_files = ([os.path.join(CHECKPOINT_DIR, CA_PUB_KEY_NAME)] * ( + len(PERSISTENT_CA_PUB_KEY_DIRS)+1))+( + [os.path.join(CHECKPOINT_DIR, ROOT_AUTHORIZED_KEYS_NAME)] * ( + len(PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS)+1))+( + [os.path.join( + CHECKPOINT_DIR, ROOT_AUTHORIZED_USERS_NAME)] * ( + len(PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS)+1)) + + dest_files = [os.path.join(x, CA_PUB_KEY_NAME) + for x in (PERSISTENT_CA_PUB_KEY_DIRS + [CA_PUB_KEY_DIR]) + ] + [os.path.join(x, ROOT_AUTHORIZED_KEYS_NAME) + for x in (PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS+[ + ROOT_AUTHORIZED_USERS_DIR])]+[ + os.path.join(x, ROOT_AUTHORIZED_USERS_NAME) + for x in (PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS+[ + ROOT_AUTHORIZED_USERS_DIR])] + + code, msg = SshMgmt._copy_files(src_files, dest_files) + if code != 0: + logger.error('%s: Failed to restore ssh mgmt checkpoint!\n', + MOD_NAME) + + try: + shutil.rmtree(CHECKPOINT_DIR) + except Exception: + logger.error('%s: Failed to delete ssh mgmt checkpoint %s!\n', + MOD_NAME, CHECKPOINT_DIR) + code = 1 + if msg: + msg += ' & ' + msg += 'Error in deleting checkpoint' + + if code == 0: + msg = 'Successfully restored checkpoint' + return code, msg + + @host_service.method( + host_service.bus_name(MOD_NAME), in_signature='as', out_signature='is') + def delete_checkpoint(self, options): + if not os.path.isdir(CHECKPOINT_DIR): + return 1, 'Checkpoint does not exist' + + try: + shutil.rmtree(CHECKPOINT_DIR) + except Exception: + logger.error('%s: Failed to delete ssh mgmt checkpoint %s!\n', + MOD_NAME, CHECKPOINT_DIR) + return 1, 'Error in deleting checkpoint' + return 0, 'Successfully deleted checkpoint' + + @host_service.method( + host_service.bus_name(MOD_NAME), in_signature='as', out_signature='is') + def set(self, options): + if not os.path.isdir(CHECKPOINT_DIR): + return 1, 'Update ssh config before creating checkpoint' + + try: + json_content = json.loads(options[0]) + except json.JSONDecodeError: + return 1, 'Invalid JSON' + + if len(json_content) == 0: + logger.error('%s: Empty request in ssh_mgmt.set.\n', MOD_NAME) + + code = 0 + for ssh_key in json_content: + if ssh_key == 'SshCaPublicKey': + # Write the content to a temp file. + with open(os.path.join(CHECKPOINT_DIR, CA_PUB_KEY_TEMP), 'w') as f: + for key in json_content[ssh_key]: + f.write(key + '\n') + # Copy the temp file. + code, msg = SshMgmt._copy_files( + [os.path.join(CHECKPOINT_DIR, CA_PUB_KEY_TEMP)] * + (len(PERSISTENT_CA_PUB_KEY_DIRS)+1), + [os.path.join(CA_PUB_KEY_DIR, CA_PUB_KEY_NAME)]+[ + os.path.join(x, CA_PUB_KEY_NAME) + for x in PERSISTENT_CA_PUB_KEY_DIRS] + ) + + elif ssh_key == 'SshAccountKeys': + # Write the content to a temp file. + with open( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_TEMP), + 'w') as f: + for account in json_content[ssh_key]: + if account.get('account') != 'root': + continue + if 'keys' not in account: + continue + for key in account['keys']: + if not key.get('key'): + continue + if 'options' in key: + SshMgmt._write_options(f, key['options']) + f.write(key['key'] + '\n') + # Copy the temp file. + code, msg = SshMgmt._copy_files( + [os.path.join(CHECKPOINT_DIR, ROOT_AUTHORIZED_KEYS_TEMP)] * + (len(PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS)+1), + [os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + ROOT_AUTHORIZED_KEYS_NAME)]+[os.path.join( + x, ROOT_AUTHORIZED_KEYS_NAME) + for x in PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS] + ) + + elif ssh_key == 'SshAccountUsers': + # Write the content to a temp file. + with open( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_TEMP), + 'w') as f: + for account in json_content[ssh_key]: + if account.get('account') != 'root': + continue + if 'users' not in account: + continue + for user in account['users']: + if not user.get('name'): + continue + if 'options' in user: + SshMgmt._write_options(f, user['options']) + f.write(user['name'] + '\n') + # Copy the temp file. + code, msg = SshMgmt._copy_files( + [os.path.join(CHECKPOINT_DIR, ROOT_AUTHORIZED_USERS_TEMP)] * + (len(PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS)+1), + [os.path.join(ROOT_AUTHORIZED_USERS_DIR, + ROOT_AUTHORIZED_USERS_NAME)]+[os.path.join( + x, ROOT_AUTHORIZED_USERS_NAME) + for x in PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS] + ) + + else: + logger.error('%s: Invalid key in ssh_mgmt.set: %s.\n', MOD_NAME, + ssh_key) + + if code == 0: + msg = 'Successfully set credentials' + return code, msg + + +def register(): + """Return class name.""" + return SshMgmt, MOD_NAME diff --git a/scripts/sonic-host-server b/scripts/sonic-host-server index 374faf0e..695f8e7c 100755 --- a/scripts/sonic-host-server +++ b/scripts/sonic-host-server @@ -18,12 +18,15 @@ from host_modules import ( debug_service, docker_service, file_service, - gcu, + gcu, + glome, + gnsi_console, gnoi_reset, host_service, image_service, reboot, showtech, + ssh_mgmt, systemd_service ) @@ -41,7 +44,10 @@ def register_dbus(): 'file_stat': file_service.FileService('file'), 'debug_service': debug_service.DebugExecutor('DebugExecutor'), 'debug_info': debug_info.DebugArtifactCollector('debug_info'), - 'gnoi_reset': gnoi_reset.GnoiReset('gnoi_reset') + 'gnoi_reset': gnoi_reset.GnoiReset('gnoi_reset'), + 'ssh_mgmt': ssh_mgmt.SshMgmt('ssh_mgmt'), + 'gnsi_console': gnsi_console.GnsiConsole('gnsi_console'), + 'glome': glome.Glome('glome') } for mod_name, handler_class in mod_dict.items(): handlers[mod_name] = handler_class diff --git a/tests/glome_test.py b/tests/glome_test.py new file mode 100644 index 00000000..7d9f52a5 --- /dev/null +++ b/tests/glome_test.py @@ -0,0 +1,131 @@ +import filecmp +import os +import sys +import tempfile +import unittest +from unittest import mock + +test_path = os.path.dirname(os.path.abspath(__file__)) +sonic_host_service_path = os.path.dirname(test_path) +host_modules_path = os.path.join(sonic_host_service_path, 'host_modules') +sys.path.append(host_modules_path) + +import glome + + +class TestGlome(unittest.TestCase): + + payload = '[service]\nkey = key\nkey-version = 1\nurl-prefix = url_prefix\n\n' + + @classmethod + def setUpClass(cls): + with mock.patch('glome.Glome.__init__', return_value=None): + cls.glome_module = glome.Glome(glome.MOD_NAME) + + def setUp(self): + self.glome_file = tempfile.NamedTemporaryFile(mode='w+', delete=False) + self.glome_backup_file = tempfile.NamedTemporaryFile( + mode='w+', delete=False + ) + glome.Glome._GLOME_PATH = self.glome_file.name + glome.Glome._GLOME_BACKUP_PATH = self.glome_backup_file.name + + def tearDown(self): + self.glome_file.close() + self.glome_backup_file.close() + + def _get_json_payload(self, enabled=True): + if enabled: + return ( + '{"enabled": true, "key": "key", "key_version": 1, "url_prefix":' + ' "url_prefix"}' + ) + else: + return '{"enabled": false}' + + def test_push_config_checkpoint_copy(self): + self.glome_file.write(self.payload) + self.glome_file.flush() + result = self.glome_module.push_config(self._get_json_payload()) + self.assertTrue( + filecmp.cmp(glome.Glome._GLOME_PATH, glome.Glome._GLOME_BACKUP_PATH) + ) + self.assertEqual(result[0], 0) + + def test_push_config_checkpoint_remove(self): + os.remove(self.glome_file.name) + result = self.glome_module.push_config(self._get_json_payload()) + self.assertFalse(os.path.exists(glome.Glome._GLOME_BACKUP_PATH)) + self.assertEqual(result[0], 0) + + def test_push_config_checkpoint_noop(self): + os.remove(self.glome_file.name) + os.remove(self.glome_backup_file.name) + result = self.glome_module.push_config(self._get_json_payload()) + self.assertEqual(result[0], 0) + + def test_push_config_disabled_file_removed(self): + result = self.glome_module.push_config(self._get_json_payload(False)) + self.assertEqual(result[0], 0) + self.assertFalse(os.path.exists(glome.Glome._GLOME_PATH)) + + def test_push_config_disabled_file_noop(self): + result = self.glome_module.push_config(self._get_json_payload(False)) + self.assertEqual(result[0], 0) + self.assertFalse(os.path.exists(glome.Glome._GLOME_PATH)) + + def test_push_config_enabled(self): + result = self.glome_module.push_config(self._get_json_payload()) + self.assertEqual(result[0], 0) + with open(glome.Glome._GLOME_PATH, 'r') as f: + self.assertEqual(f.read(), self.payload) + + def test_push_config_error(self): + result = self.glome_module.push_config('invalid json') + self.assertNotEqual(result[0], 0) + + with mock.patch('glome.os.path.exists', mock.MagicMock(return_value=True)): + with mock.patch('glome.shutil') as mock_shutil: + mock_shutil.copy.side_effect = PermissionError + result = self.glome_module.push_config(self._get_json_payload()) + self.assertNotEqual(result[0], 0) + + with mock.patch('glome.os.path.exists', mock.MagicMock(return_value=False)): + with mock.patch('glome.os.remove') as mock_remove: + mock_remove.side_effect = OSError + result = self.glome_module.push_config(self._get_json_payload()) + self.assertNotEqual(result[0], 0) + + def test_restore_checkpoint_noop(self): + os.remove(self.glome_backup_file.name) + os.remove(self.glome_file.name) + result = self.glome_module.restore_checkpoint() + self.assertEqual(result[0], 0) + + def test_restore_checkpoint_copy(self): + self.glome_backup_file.write(self.payload) + self.glome_backup_file.flush() + result = self.glome_module.restore_checkpoint() + self.assertTrue( + filecmp.cmp(glome.Glome._GLOME_PATH, glome.Glome._GLOME_BACKUP_PATH) + ) + self.assertEqual(result[0], 0) + + def test_restore_checkpoint_remove(self): + os.remove(self.glome_backup_file.name) + result = self.glome_module.restore_checkpoint() + self.assertFalse(os.path.exists(glome.Glome._GLOME_PATH)) + self.assertEqual(result[0], 0) + + def test_restore_checkpoint_error(self): + with mock.patch('glome.os.path.exists', mock.MagicMock(return_value=True)): + with mock.patch('glome.shutil') as mock_shutil: + mock_shutil.copy.side_effect = PermissionError + result = self.glome_module.restore_checkpoint() + self.assertNotEqual(result[0], 0) + + with mock.patch('glome.os.path.exists', mock.MagicMock(return_value=False)): + with mock.patch('glome.os.remove') as mock_remove: + mock_remove.side_effect = OSError + result = self.glome_module.restore_checkpoint() + self.assertNotEqual(result[0], 0) diff --git a/tests/gnsi_console_test.py b/tests/gnsi_console_test.py new file mode 100644 index 00000000..d6356d63 --- /dev/null +++ b/tests/gnsi_console_test.py @@ -0,0 +1,445 @@ +"""Tests for gnsi_console.""" + +import importlib.util +import importlib.machinery +import json +import sys +import os +import pytest + +if sys.version_info >= (3, 3): + from unittest import mock +else: + # Expect the 'mock' package for python 2 + # https://pypi.python.org/pypi/mock + import mock + +test_path = os.path.dirname(os.path.abspath(__file__)) +sonic_host_service_path = os.path.dirname(test_path) +host_modules_path = os.path.join(sonic_host_service_path, "host_modules") +sys.path.insert(0, sonic_host_service_path) + +TEST_EXCEPTION_MESSAGE = "test raise exception message" +TEST_HASHED_PASSWORD = "$6$wNa3DzanMzQ6U.0x$2LAaCYaiAua9muP/Q04sKWMNpHnIOOu2rQ.il.3BOjeKTrxCMqwg2NIamWmhhw3HZgHZGb79RozrKVc.tDnLs1" +TEST_TEXT_PASSWORD = "some_test_password" +TEST_VALID_USER = "root" +TEST_INVALID_USER = "root_test" +TEST_OLD_PASSWORD_FILE_CONTENT = [ + "root:old_hashed_password:12215:0:99999:7:::\n", + "second_user:second_hashed_password:12215:0:99999:7:::\n" +] +TEST_UPDATED_PASSWORD_FILE_CONTENT = [ + TEST_VALID_USER + ":" + TEST_HASHED_PASSWORD + ":12215:0:99999:7:::\n", + TEST_OLD_PASSWORD_FILE_CONTENT[1] +] +TEST_VALID_PASSWORD_CHANGE_REQEST = ( + "{ \"ConsolePasswords\": [ { \"name\": \"root\", \"password\" : " + "\"new_root_text_password\" }, { \"name\": \"second_user\", \"password\" :" + " \"new_second_text_password\"}]}" +) +TEST_INVALID_PASSWORD_CHANGE_REQEST = ( + "\"ConsolePasswords\": " + "[ { \"name\": \"root\", \"password\" : \"new_root_text_password\" }, " + "{ \"name\": \"second_user\", \"password\" : \"new_second_text_password\"}]" +) +TEST_RANDOM_PASSWORD_CHANGE_REQEST = ( + "{ \"Random\": [ { \"name\": \"root\", \"password\" : " + "\"new_root_text_password\" }, { \"name\": \"second_user\", \"password\" :" + " \"new_second_text_password\"}]}" +) + +def load_source(modname, filename): + loader = importlib.machinery.SourceFileLoader(modname, filename) + spec = importlib.util.spec_from_file_location(modname, filename, loader=loader) + module = importlib.util.module_from_spec(spec) + # The module is always executed and not cached in sys.modules. + # Uncomment the following line to cache the module. + sys.modules[module.__name__] = module + loader.exec_module(module) + return module + +load_source("host_service", host_modules_path + "/host_service.py") +#load_source("infra_host", host_modules_path + "/infra_host.py") +load_source("gnsi_console", host_modules_path + "/gnsi_console.py") + +from gnsi_console import * + + +class TestGnsiConsole(object): + @classmethod + def setup_class(cls): + with mock.patch("gnsi_console.GnsiConsole.__init__", return_value=None): + cls.gnsi_console_module = GnsiConsole(MOD_NAME) + + def test_create_checkpoint_success(self): + with mock.patch("gnsi_console.shutil") as mock_shutil: + result = self.gnsi_console_module.create_checkpoint([""]) + assert result[0] == 0 + assert result[1] == "Successfully created checkpoint" + mock_shutil.copy.assert_called_once_with(PASSWD_FILE, PASSWD_FILE_CHECKPOINT_FILE) + + def raise_exception_shutil_test(self, src, dst): + raise OSError(TEST_EXCEPTION_MESSAGE) + + def test_create_checkpoint_fail(self): + with mock.patch("gnsi_console.shutil") as mock_shutil: + mock_shutil.copy = self.raise_exception_shutil_test + result = self.gnsi_console_module.create_checkpoint([""]) + assert result[0] == 1 + assert result[1] == "Failed to create checkpoint with error: " + TEST_EXCEPTION_MESSAGE + + def test_restore_checkpoint_fail_checkpoint_not_present(self): + with mock.patch("gnsi_console.os") as mock_os: + mock_os.path.isfile.return_value = False + result = self.gnsi_console_module.restore_checkpoint([""]) + assert result[0] == 1 + assert result[1] == "Checkpoint file is not present" + mock_os.path.isfile.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + + def test_restore_checkpoint_success(self): + with mock.patch("gnsi_console.os") as mock_os: + with mock.patch("gnsi_console.GnsiConsole.update_password_file") as mock_update_password_file: + mock_os.path.isfile.return_value = True + mock_update_password_file.return_value = (0, "Successfully updated console passwords") + result = self.gnsi_console_module.restore_checkpoint([""]) + assert result[0] == 0 + assert result[1] == "restore_checkpoint: Successfully updated console passwords" + mock_os.path.isfile.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + mock_update_password_file.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + + def raise_exception_os_test(self, src): + raise OSError(TEST_EXCEPTION_MESSAGE) + + def test_delete_checkpoint_success(self): + with mock.patch("gnsi_console.os") as mock_os: + result = self.gnsi_console_module.delete_checkpoint([""]) + assert result[0] == 0 + assert result[1] == "Successfully deleted checkpoint" + mock_os.remove.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + + def test_delete_checkpoint_fail(self): + with mock.patch("gnsi_console.os") as mock_os: + mock_os.remove = self.raise_exception_os_test + result = self.gnsi_console_module.delete_checkpoint([""]) + assert result[0] == 1 + assert result[1] == "Failed to delete checkpoint with error: " + TEST_EXCEPTION_MESSAGE + + def test_get_hashed_password_success(self): + with mock.patch("gnsi_console._run_command") as mock_run_command: + mock_run_command.return_value = (0, [TEST_HASHED_PASSWORD], []) + assert self.gnsi_console_module.get_hashed_password(TEST_TEXT_PASSWORD) == TEST_HASHED_PASSWORD + mock_run_command.assert_called_once_with(OPENSSL_COMMAND + TEST_TEXT_PASSWORD) + + def test_get_hashed_password_fail(self): + with mock.patch("gnsi_console._run_command") as mock_run_command: + with mock.patch("gnsi_console.logger.error") as mock_logerror: + mock_run_command.return_value = (1, ["stdout test message"], ["stderr test message"]) + assert not self.gnsi_console_module.get_hashed_password(TEST_TEXT_PASSWORD) + expected_log_message = "gnsi_console: Failed to get hash for given text password " \ + "with stdout: ['stdout test message'], " \ + "stderr: ['stderr test message']" + mock_logerror.assert_called_once_with(expected_log_message) + mock_run_command.assert_called_once_with(OPENSSL_COMMAND + TEST_TEXT_PASSWORD) + + + def test_read_password_file_success(self): + with mock.patch("gnsi_console.open") as mock_open: + mock_file_handler = mock_open.return_value.__enter__.return_value + mock_file_handler.readlines.return_value = TEST_OLD_PASSWORD_FILE_CONTENT.copy() + result = self.gnsi_console_module.read_password_file() + assert result[0] == TEST_OLD_PASSWORD_FILE_CONTENT + assert not result[1] + mock_file_handler.readlines.assert_called_once_with() + + def raise_ioerror_fh(self): + raise IOError("IOError in unit test") + + def test_read_password_file_fail(self): + with mock.patch("gnsi_console.open") as mock_open: + mock_file_handler = mock_open.return_value.__enter__.return_value + mock_file_handler.readlines = self.raise_ioerror_fh + result = self.gnsi_console_module.read_password_file() + assert not result[0] + assert result[1] == "Failed to read password file with error: " + "IOError in unit test" + + def test_update_password_if_user_found_success(self): + test_password_file_content = TEST_OLD_PASSWORD_FILE_CONTENT.copy() + self.gnsi_console_module.update_password_if_user_found(TEST_VALID_USER, + TEST_HASHED_PASSWORD, + test_password_file_content) + assert test_password_file_content == TEST_UPDATED_PASSWORD_FILE_CONTENT + + def test_update_password_if_user_found_fail(self): + with mock.patch("gnsi_console.logger.error") as mock_logerror: + test_password_file_content = TEST_OLD_PASSWORD_FILE_CONTENT.copy() + self.gnsi_console_module.update_password_if_user_found(TEST_INVALID_USER, + TEST_HASHED_PASSWORD, + test_password_file_content) + assert test_password_file_content == TEST_OLD_PASSWORD_FILE_CONTENT + mock_logerror.assert_called_once_with("gnsi_console: The given user name: %s does " + "not exist in the password file" + % TEST_INVALID_USER) + + def raise_ioerror_fh_with_one_arg(self, first_arg): + raise IOError("IOError in unit test") + + def test_create_temp_passwd_file_success(self): + with mock.patch("gnsi_console.open") as mock_open: + mock_file_handler = mock_open.return_value.__enter__.return_value + result = self.gnsi_console_module.create_temp_passwd_file( + TEST_UPDATED_PASSWORD_FILE_CONTENT.copy()) + assert result[0] == 0 + assert result[1] == "" + mock_file_handler.writelines.assert_called_once_with( + TEST_UPDATED_PASSWORD_FILE_CONTENT) + + def test_create_temp_passwd_file_fail_file_does_not_exist_on_failure(self): + with mock.patch("gnsi_console.open") as mock_open: + with mock.patch("gnsi_console.os") as mock_os: + mock_file_handler = mock_open.return_value.__enter__.return_value + mock_file_handler.writelines = self.raise_ioerror_fh_with_one_arg + mock_os.path.isfile.return_value = False + result = self.gnsi_console_module.create_temp_passwd_file( + TEST_UPDATED_PASSWORD_FILE_CONTENT.copy()) + assert result[0] == 1 + assert result[1] == ( + "Failed to create temporary password file with error: " + + "IOError in unit test") + mock_os.path.isfile.assert_called_once_with( + PASSWD_FILE_TEMP) + + def test_create_temp_passwd_file_fail_file_removed_on_failure(self): + with mock.patch("gnsi_console.open") as mock_open: + with mock.patch("gnsi_console.os") as mock_os: + mock_file_handler = mock_open.return_value.__enter__.return_value + mock_file_handler.writelines = self.raise_ioerror_fh_with_one_arg + mock_os.path.isfile.return_value = True + result = self.gnsi_console_module.create_temp_passwd_file(TEST_UPDATED_PASSWORD_FILE_CONTENT.copy()) + assert result[0] == 1 + assert result[1] == ( + "Failed to create temporary password file with error: " + + "IOError in unit test") + mock_os.path.isfile.assert_called_once_with( + PASSWD_FILE_TEMP) + mock_os.remove.assert_called_once_with(PASSWD_FILE_TEMP) + + def test_create_temp_passwd_file_on_failure_file_remove_also_fail(self): + with mock.patch("gnsi_console.open") as mock_open: + with mock.patch("gnsi_console.os") as mock_os: + mock_file_handler = mock_open.return_value.__enter__.return_value + mock_file_handler.writelines = self.raise_ioerror_fh_with_one_arg + mock_os.path.isfile.return_value = True + mock_os.remove = self.raise_exception_os_test + result = self.gnsi_console_module.create_temp_passwd_file( + TEST_UPDATED_PASSWORD_FILE_CONTENT.copy()) + assert result[0] == 1 + assert result[1] == ( + "Failed to create temporary password file with error: " + + "IOError in unit test" + + " and also failed to remove temporary file created with error: " + + TEST_EXCEPTION_MESSAGE) + mock_os.path.isfile.assert_called_once_with( + PASSWD_FILE_TEMP) + + def test_update_password_file_success(self): + with mock.patch("gnsi_console.shutil") as mock_shutil: + result = self.gnsi_console_module.update_password_file(PASSWD_FILE_TEMP) + assert result[0] == 0 + assert result[1] == "Successfully updated console passwords" + mock_shutil.move.assert_called_once_with(PASSWD_FILE_TEMP, + PASSWD_FILE) + + def test_update_password_file_fail_move_failed(self): + with mock.patch("gnsi_console.shutil") as mock_shutil: + mock_shutil.move = self.raise_exception_shutil_test + result = self.gnsi_console_module.update_password_file(PASSWD_FILE_TEMP) + assert result[0] == 1 + assert result[1] == ( + "Failed to replace original password file " + "with given password file with error: " + + TEST_EXCEPTION_MESSAGE) + + + def test_update_password_file_fail_file_does_not_exist_on_failure(self): + with mock.patch("gnsi_console.shutil") as mock_shutil: + with mock.patch("gnsi_console.os") as mock_os: + mock_shutil.move = self.raise_exception_shutil_test + mock_os.path.isfile.return_value = False + result = self.gnsi_console_module.update_password_file(PASSWD_FILE_TEMP) + assert result[0] == 1 + assert result[1] == ( + "Failed to replace original password file with given password file with error: " + + TEST_EXCEPTION_MESSAGE) + mock_os.path.isfile.assert_called_once_with( + PASSWD_FILE_TEMP) + + def test_update_password_file_fail_file_removed_on_failure(self): + with mock.patch("gnsi_console.shutil") as mock_shutil: + with mock.patch("gnsi_console.os") as mock_os: + mock_shutil.move = self.raise_exception_shutil_test + mock_os.path.isfile.return_value = True + result = self.gnsi_console_module.update_password_file(PASSWD_FILE_TEMP) + assert result[0] == 1 + assert result[1] == ( + "Failed to replace original password file with given password file with error: " + + TEST_EXCEPTION_MESSAGE) + mock_os.path.isfile.assert_called_once_with( + PASSWD_FILE_TEMP) + mock_os.remove.assert_called_once_with(PASSWD_FILE_TEMP) + + def test_update_password_file_fail_on_failure_file_remove_also_fail(self): + with mock.patch("gnsi_console.shutil") as mock_shutil: + with mock.patch("gnsi_console.os") as mock_os: + mock_shutil.move = self.raise_exception_shutil_test + mock_os.path.isfile.return_value = True + mock_os.remove = self.raise_exception_os_test + result = self.gnsi_console_module.update_password_file(PASSWD_FILE_TEMP) + assert result[0] == 1 + assert result[1] == ( + "Failed to replace original password file with given password file with error: " + + TEST_EXCEPTION_MESSAGE + + " and also failed to remove given password file with error: " + + TEST_EXCEPTION_MESSAGE) + mock_os.path.isfile.assert_called_once_with( + PASSWD_FILE_TEMP) + + def test_set_fail_checkpoint_does_not_exist(self): + with mock.patch("gnsi_console.os") as mock_os: + mock_os.path.isfile.return_value = False + result = self.gnsi_console_module.set([TEST_VALID_PASSWORD_CHANGE_REQEST]) + assert result[0] == 1 + assert result[1] == "Trying to update console password without creating checkpoint" + mock_os.path.isfile.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + + def test_set_fail_invalid_json(self): + with mock.patch("gnsi_console.os") as mock_os: + mock_os.path.isfile.return_value = True + result = self.gnsi_console_module.set([TEST_INVALID_PASSWORD_CHANGE_REQEST]) + assert result[0] == 1 + assert result[1] == "Failed to parse json formatted password change request: " + TEST_INVALID_PASSWORD_CHANGE_REQEST + mock_os.path.isfile.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + + def test_set_fail_key_not_present(self): + with mock.patch("gnsi_console.os") as mock_os: + mock_os.path.isfile.return_value = True + result = self.gnsi_console_module.set([TEST_RANDOM_PASSWORD_CHANGE_REQEST]) + assert result[0] == 1 + assert result[1] == "Received invalid password request: %s" % str(json.loads(TEST_RANDOM_PASSWORD_CHANGE_REQEST)) + mock_os.path.isfile.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + + def test_set_read_password_failed(self): + with mock.patch("gnsi_console.os") as mock_os: + with mock.patch("gnsi_console.GnsiConsole.read_password_file") as mock_read_password_file: + mock_os.path.isfile.return_value = True + mock_read_password_file.return_value = ([], "Read password file failed") + result = self.gnsi_console_module.set([TEST_VALID_PASSWORD_CHANGE_REQEST]) + assert result[0] == 1 + assert result[1] == "Read password file failed" + mock_os.path.isfile.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + mock_read_password_file.assert_called_once_with() + + def test_set_success(self): + with mock.patch("gnsi_console.os") as mock_os: + with mock.patch("gnsi_console.GnsiConsole.read_password_file") as mock_read_password_file: + with mock.patch("gnsi_console.GnsiConsole.get_hashed_password") as mock_get_hashed_password: + with mock.patch("gnsi_console.GnsiConsole.update_password_if_user_found") as mock_update_password_if_user_found: + with mock.patch("gnsi_console.GnsiConsole.update_password_file") as mock_update_password_file: + with mock.patch("gnsi_console.GnsiConsole.create_temp_passwd_file") as mock_create_temp_passwd_file: + mock_os.path.isfile.return_value = True + mock_read_password_file.return_value = (TEST_OLD_PASSWORD_FILE_CONTENT.copy(), "") + mock_get_hashed_password.return_value = TEST_HASHED_PASSWORD + mock_create_temp_passwd_file.return_value = (0, "") + mock_update_password_file.return_value = (0, "Successfully updated console passwords") + result = self.gnsi_console_module.set([TEST_VALID_PASSWORD_CHANGE_REQEST]) + assert result[0] == 0 + assert result[1] == "set: Successfully updated console passwords" + assert mock_get_hashed_password.call_count == 2 + assert mock_update_password_if_user_found.call_count == 2 + mock_os.path.isfile.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + mock_read_password_file.assert_called_once_with() + mock_get_hashed_password.assert_has_calls([mock.call("new_root_text_password"), + mock.call("new_second_text_password")]) + mock_update_password_if_user_found.assert_has_calls([mock.call("root", TEST_HASHED_PASSWORD, TEST_OLD_PASSWORD_FILE_CONTENT), + mock.call("second_user", TEST_HASHED_PASSWORD, TEST_OLD_PASSWORD_FILE_CONTENT)]) + mock_update_password_file.assert_called_once_with(PASSWD_FILE_TEMP) + mock_create_temp_passwd_file.assert_called_once_with(TEST_OLD_PASSWORD_FILE_CONTENT) + + def test_set_fail_create_temp_password_file_fail(self): + with mock.patch("gnsi_console.os") as mock_os: + with mock.patch("gnsi_console.GnsiConsole.read_password_file") as mock_read_password_file: + with mock.patch("gnsi_console.GnsiConsole.get_hashed_password") as mock_get_hashed_password: + with mock.patch("gnsi_console.GnsiConsole.update_password_if_user_found") as mock_update_password_if_user_found: + with mock.patch("gnsi_console.GnsiConsole.create_temp_passwd_file") as mock_create_temp_passwd_file: + mock_os.path.isfile.return_value = True + mock_read_password_file.return_value = (TEST_OLD_PASSWORD_FILE_CONTENT.copy(), "") + mock_get_hashed_password.return_value = TEST_HASHED_PASSWORD + mock_create_temp_passwd_file.return_value = (1, "Failed to create temporary password file with error: test error message") + result = self.gnsi_console_module.set([TEST_VALID_PASSWORD_CHANGE_REQEST]) + assert result[0] == 1 + assert result[1] == "Failed to create temporary password file with error: test error message" + assert mock_get_hashed_password.call_count == 2 + assert mock_update_password_if_user_found.call_count == 2 + mock_os.path.isfile.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + mock_read_password_file.assert_called_once_with() + mock_get_hashed_password.assert_has_calls([mock.call("new_root_text_password"), + mock.call("new_second_text_password")]) + mock_update_password_if_user_found.assert_has_calls([mock.call("root", TEST_HASHED_PASSWORD, TEST_OLD_PASSWORD_FILE_CONTENT), + mock.call("second_user", TEST_HASHED_PASSWORD, TEST_OLD_PASSWORD_FILE_CONTENT)]) + mock_create_temp_passwd_file.assert_called_once_with(TEST_OLD_PASSWORD_FILE_CONTENT) + + def test_set_name_and_password_keys_not_present(self): + with mock.patch("gnsi_console.os") as mock_os: + with mock.patch("gnsi_console.GnsiConsole.read_password_file") as mock_read_password_file: + with mock.patch("gnsi_console.GnsiConsole.get_hashed_password") as mock_get_hashed_password: + with mock.patch("gnsi_console.GnsiConsole.update_password_if_user_found") as mock_update_password_if_user_found: + with mock.patch("gnsi_console.GnsiConsole.update_password_file") as mock_update_password_file: + with mock.patch("gnsi_console.GnsiConsole.create_temp_passwd_file") as mock_create_temp_passwd_file: + mock_os.path.isfile.return_value = True + mock_read_password_file.return_value = (TEST_OLD_PASSWORD_FILE_CONTENT.copy(), "") + mock_get_hashed_password.return_value = TEST_HASHED_PASSWORD + mock_create_temp_passwd_file.return_value = (0, "") + mock_update_password_file.return_value = (0, "Successfully updated console passwords") + remove_name_and_password_keys = json.loads(TEST_VALID_PASSWORD_CHANGE_REQEST) + remove_name_and_password_keys["ConsolePasswords"][0].pop("name") + remove_name_and_password_keys["ConsolePasswords"][1].pop("password") + result = self.gnsi_console_module.set([json.dumps(remove_name_and_password_keys)]) + assert result[0] == 0 + assert result[1] == "set: Successfully updated console passwords" + mock_os.path.isfile.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + mock_read_password_file.assert_called_once_with() + mock_update_password_file.assert_called_once_with(PASSWD_FILE_TEMP) + mock_get_hashed_password.assert_not_called() + mock_update_password_if_user_found.assert_not_called() + + def test_set_name_hashed_password_fail_for_one_request(self): + with mock.patch("gnsi_console.os") as mock_os: + with mock.patch("gnsi_console.GnsiConsole.read_password_file") as mock_read_password_file: + with mock.patch("gnsi_console.GnsiConsole.get_hashed_password") as mock_get_hashed_password: + with mock.patch("gnsi_console.GnsiConsole.update_password_if_user_found") as mock_update_password_if_user_found: + with mock.patch("gnsi_console.GnsiConsole.update_password_file") as mock_update_password_file: + with mock.patch("gnsi_console.GnsiConsole.create_temp_passwd_file") as mock_create_temp_passwd_file: + mock_os.path.isfile.return_value = True + mock_read_password_file.return_value = (TEST_OLD_PASSWORD_FILE_CONTENT.copy(), "") + mock_get_hashed_password.side_effect = [TEST_HASHED_PASSWORD, ""] + mock_create_temp_passwd_file.return_value = (0, "") + mock_update_password_file.return_value = (0, "Successfully updated console passwords") + result = self.gnsi_console_module.set([TEST_VALID_PASSWORD_CHANGE_REQEST]) + assert result[0] == 0 + assert result[1] == "set: Successfully updated console passwords" + assert mock_get_hashed_password.call_count == 2 + mock_os.path.isfile.assert_called_once_with(PASSWD_FILE_CHECKPOINT_FILE) + mock_read_password_file.assert_called_once_with() + mock_get_hashed_password.assert_has_calls([mock.call("new_root_text_password"), + mock.call("new_second_text_password")]) + mock_update_password_if_user_found.assert_called_once_with("root", TEST_HASHED_PASSWORD, TEST_OLD_PASSWORD_FILE_CONTENT) + mock_update_password_file.assert_called_once_with(PASSWD_FILE_TEMP) + mock_create_temp_passwd_file.assert_called_once_with(TEST_OLD_PASSWORD_FILE_CONTENT) + + def test_register(self): + result = register() + assert result[0] == GnsiConsole + assert result[1] == MOD_NAME + + @classmethod + def teardown_class(cls): + print("TEARDOWN") diff --git a/tests/ssh_mgmt_test.py b/tests/ssh_mgmt_test.py new file mode 100644 index 00000000..154b8423 --- /dev/null +++ b/tests/ssh_mgmt_test.py @@ -0,0 +1,815 @@ +"""Tests for ssh_mgmt""" + +import builtins +import importlib.util +import importlib.machinery +import os +import pytest +import sys +import json + +if sys.version_info >= (3, 3): + from unittest import mock +else: + import mock + +test_path = os.path.dirname(os.path.abspath(__file__)) +sonic_host_service_path = os.path.dirname(test_path) +host_modules_path = os.path.join(sonic_host_service_path, "host_modules") +sys.path.insert(0, sonic_host_service_path) + +def load_source(modname, filename): + loader = importlib.machinery.SourceFileLoader(modname, filename) + spec = importlib.util.spec_from_file_location(modname, filename, loader=loader) + module = importlib.util.module_from_spec(spec) + # The module is always executed and not cached in sys.modules. + # Uncomment the following line to cache the module. + sys.modules[module.__name__] = module + loader.exec_module(module) + return module + +TEST_EXCEPTION_MESSAGE = "test raise exception message" +load_source("host_service", host_modules_path + "/host_service.py") +load_source("ssh_mgmt", host_modules_path + "/ssh_mgmt.py") + +from ssh_mgmt import * + + +class MockFileHandler: + + def __init__(self): + self.contents = "" + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, exception_traceback): + pass + + def write(self, content): + self.contents += content + + def close(self, content): + pass + + def get_contents(self): + return self.contents + + +class TestSshMgmt(object): + @classmethod + def setup_class(cls): + with mock.patch("ssh_mgmt.SshMgmt.__init__", return_value=None): + cls.ssh_mgmt_module = SshMgmt(MOD_NAME) + + def test_create_checkpoint(self): + # Create checkpoint succeeds. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=False)) as mock_isdir: + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=True)) as mock_exists: + with mock.patch("ssh_mgmt.os.makedirs") as mock_makedirs: + with mock.patch("ssh_mgmt.os.remove") as mock_remove: + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + result = self.ssh_mgmt_module.create_checkpoint([]) + assert result[0] == 0 + assert result[1] == "Successfully created checkpoint" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + mock_makedirs.assert_called_with( + CHECKPOINT_DIR, exist_ok=True) + mock_exists.assert_has_calls([ + mock.call( + os.path.join(CA_PUB_KEY_DIR, + CA_PUB_KEY_NAME)), + mock.call( + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join(ROOT_AUTHORIZED_USERS_DIR, + ROOT_AUTHORIZED_USERS_NAME)), + ], + any_order=True) + mock_remove.assert_called_with( + os.path.join(CHECKPOINT_DIR, COPY_TEMP_FILE)) + mock_makedirs.assert_has_calls([ + mock.call(CHECKPOINT_DIR, exist_ok=True), + mock.call(CHECKPOINT_DIR, exist_ok=True), + mock.call(CHECKPOINT_DIR, exist_ok=True), + ], + any_order=True) + mock_shutil.copyfile.assert_has_calls([ + mock.call( + os.path.join(CA_PUB_KEY_DIR, + CA_PUB_KEY_NAME), + os.path.join(CHECKPOINT_DIR, COPY_TEMP_FILE)), + mock.call( + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + ROOT_AUTHORIZED_KEYS_NAME), + os.path.join(CHECKPOINT_DIR, COPY_TEMP_FILE)), + mock.call( + os.path.join(ROOT_AUTHORIZED_USERS_DIR, + ROOT_AUTHORIZED_USERS_NAME), + os.path.join(CHECKPOINT_DIR, COPY_TEMP_FILE)), + ], + any_order=True) + mock_shutil.move.assert_has_calls([ + mock.call( + os.path.join(CHECKPOINT_DIR, + COPY_TEMP_FILE), + os.path.join(CHECKPOINT_DIR, CA_PUB_KEY_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + COPY_TEMP_FILE), + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + COPY_TEMP_FILE), + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_NAME)), + ], + any_order=True) + + # Create checkpoint succeeds when old checkpoint exists. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=True)) as mock_exists: + with mock.patch("ssh_mgmt.os.makedirs") as mock_makedirs: + with mock.patch("ssh_mgmt.os.remove") as mock_remove: + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + result = self.ssh_mgmt_module.create_checkpoint([]) + assert result[0] == 0 + assert result[1] == "Successfully created checkpoint" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + mock_makedirs.assert_called_with( + CHECKPOINT_DIR, exist_ok=True) + mock_shutil.rmtree.assert_called_with( + CHECKPOINT_DIR) + mock_exists.assert_has_calls([ + mock.call( + os.path.join(CA_PUB_KEY_DIR, + CA_PUB_KEY_NAME)), + mock.call( + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join(ROOT_AUTHORIZED_USERS_DIR, + ROOT_AUTHORIZED_USERS_NAME)), + ], + any_order=True) + mock_remove.assert_called_with( + os.path.join(CHECKPOINT_DIR, COPY_TEMP_FILE)) + mock_makedirs.assert_has_calls([ + mock.call(CHECKPOINT_DIR, exist_ok=True), + mock.call(CHECKPOINT_DIR, exist_ok=True), + mock.call(CHECKPOINT_DIR, exist_ok=True), + ], + any_order=True) + mock_shutil.copyfile.assert_has_calls([ + mock.call( + os.path.join(CA_PUB_KEY_DIR, + CA_PUB_KEY_NAME), + os.path.join(CHECKPOINT_DIR, COPY_TEMP_FILE)), + mock.call( + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + ROOT_AUTHORIZED_KEYS_NAME), + os.path.join(CHECKPOINT_DIR, COPY_TEMP_FILE)), + mock.call( + os.path.join(ROOT_AUTHORIZED_USERS_DIR, + ROOT_AUTHORIZED_USERS_NAME), + os.path.join(CHECKPOINT_DIR, COPY_TEMP_FILE)) + ], + any_order=True) + mock_shutil.move.assert_has_calls([ + mock.call( + os.path.join(CHECKPOINT_DIR, + COPY_TEMP_FILE), + os.path.join(CHECKPOINT_DIR, CA_PUB_KEY_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + COPY_TEMP_FILE), + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + COPY_TEMP_FILE), + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_NAME)) + ], + any_order=True) + + # Create checkpoint succeeds when source files do not exist. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=False)) as mock_isdir: + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=False)) as mock_exists: + with mock.patch("ssh_mgmt.os.makedirs") as mock_makedirs: + result = self.ssh_mgmt_module.create_checkpoint([]) + assert result[0] == 0 + assert result[1] == "Successfully created checkpoint" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + mock_makedirs.assert_called_with( + CHECKPOINT_DIR, exist_ok=True) + mock_exists.assert_has_calls([ + mock.call( + os.path.join(CA_PUB_KEY_DIR, + CA_PUB_KEY_NAME)), + mock.call( + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join(ROOT_AUTHORIZED_USERS_DIR, + ROOT_AUTHORIZED_USERS_NAME)), + ], + any_order=True) + + def mock_copyfile(src, dest): + raise OSError(TEST_EXCEPTION_MESSAGE) + + def mock_rmtree(dir): + raise OSError(TEST_EXCEPTION_MESSAGE) + + def mock_remove(file): + raise OSError(TEST_EXCEPTION_MESSAGE) + + # Create checkpoint fails when copy and delete operations fail. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=False)): + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=True)): + with mock.patch("ssh_mgmt.os.makedirs"): + with mock.patch("ssh_mgmt.os.remove"): + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + mock_shutil.copyfile = mock_copyfile + mock_shutil.rmtree = mock_rmtree + result = self.ssh_mgmt_module.create_checkpoint([]) + assert result[0] == 1 + assert result[1] != "Successfully created checkpoint" + + # Create checkpoint fails when old checkpoint exists and fail to delete it. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)): + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=True)): + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + mock_shutil.rmtree = mock_rmtree + result = self.ssh_mgmt_module.create_checkpoint([]) + assert result[0] == 1 + assert result[1] != "Successfully created checkpoint" + + # Create checkpoint success when temp file delete fails. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=False)): + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=True)): + with mock.patch("ssh_mgmt.os.makedirs"): + with mock.patch("ssh_mgmt.shutil"): + os.remove = mock_remove + result = self.ssh_mgmt_module.create_checkpoint([]) + assert result[0] == 0 + assert result[1] == "Successfully created checkpoint" + + def test_restore_checkpoint(self): + # Restore checkpoint fails when checkpoint does not exist. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=False)) as mock_isdir: + result = self.ssh_mgmt_module.restore_checkpoint([]) + assert result[0] == 1 + assert result[1] == "Checkpoint does not exist" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + + # Restore checkpoint succeeds. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=True)) as mock_exists: + with mock.patch("ssh_mgmt.os.remove") as mock_remove: + with mock.patch("ssh_mgmt.os.makedirs") as mock_makedirs: + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + result = self.ssh_mgmt_module.restore_checkpoint([ + ]) + assert result[0] == 0 + assert result[1] == "Successfully restored checkpoint" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + mock_exists.assert_has_calls([ + mock.call( + os.path.join(CHECKPOINT_DIR, + CA_PUB_KEY_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + CA_PUB_KEY_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_NAME)), + ], + any_order=True) + mock_remove.assert_has_calls([ + mock.call(os.path.join( + CA_PUB_KEY_DIR, COPY_TEMP_FILE)), + mock.call(os.path.join( + PERSISTENT_CA_PUB_KEY_DIRS[0], COPY_TEMP_FILE)), + mock.call(os.path.join( + ROOT_AUTHORIZED_KEYS_DIR, COPY_TEMP_FILE)), + mock.call(os.path.join( + PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS[0], + COPY_TEMP_FILE)), + mock.call(os.path.join( + ROOT_AUTHORIZED_USERS_DIR, COPY_TEMP_FILE)), + mock.call(os.path.join( + PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS[0], + COPY_TEMP_FILE)) + ], + any_order=True) + mock_makedirs.assert_has_calls([ + mock.call(CA_PUB_KEY_DIR, exist_ok=True), + mock.call( + PERSISTENT_CA_PUB_KEY_DIRS[0], exist_ok=True), + mock.call(ROOT_AUTHORIZED_KEYS_DIR, + exist_ok=True), + mock.call( + PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS[0], exist_ok=True), + mock.call(ROOT_AUTHORIZED_USERS_DIR, + exist_ok=True), + mock.call( + PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS[0], exist_ok=True), + ], + any_order=True) + mock_shutil.copyfile.assert_has_calls([ + mock.call( + os.path.join(CHECKPOINT_DIR, + CA_PUB_KEY_NAME), + os.path.join(CA_PUB_KEY_DIR, COPY_TEMP_FILE)), + mock.call( + os.path.join(CHECKPOINT_DIR, + CA_PUB_KEY_NAME), + os.path.join(PERSISTENT_CA_PUB_KEY_DIRS[0], + COPY_TEMP_FILE)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_NAME), + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + COPY_TEMP_FILE)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_NAME), + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS[0], + COPY_TEMP_FILE)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_NAME), + os.path.join(ROOT_AUTHORIZED_USERS_DIR, + COPY_TEMP_FILE)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_NAME), + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS[0], + COPY_TEMP_FILE)) + ], + any_order=True) + mock_shutil.move.assert_has_calls([ + mock.call( + os.path.join(CA_PUB_KEY_DIR, + COPY_TEMP_FILE), + os.path.join(CA_PUB_KEY_DIR, CA_PUB_KEY_NAME)), + mock.call( + os.path.join( + PERSISTENT_CA_PUB_KEY_DIRS[0], COPY_TEMP_FILE), + os.path.join(PERSISTENT_CA_PUB_KEY_DIRS[0], + CA_PUB_KEY_NAME)), + mock.call( + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + COPY_TEMP_FILE), + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS[0], + COPY_TEMP_FILE), + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS[0], + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join( + ROOT_AUTHORIZED_USERS_DIR, COPY_TEMP_FILE), + os.path.join(ROOT_AUTHORIZED_USERS_DIR, + ROOT_AUTHORIZED_USERS_NAME)), + mock.call( + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS[0], + COPY_TEMP_FILE), + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS[0], + ROOT_AUTHORIZED_USERS_NAME)) + ], + any_order=True) + mock_shutil.rmtree.assert_called_with( + CHECKPOINT_DIR) + + # Restore checkpoint succeeds when source files do not exist. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=False)) as mock_exists: + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + result = self.ssh_mgmt_module.restore_checkpoint([]) + assert result[0] == 0 + assert result[1] == "Successfully restored checkpoint" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + mock_exists.assert_has_calls([ + mock.call( + os.path.join(CHECKPOINT_DIR, + CA_PUB_KEY_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + CA_PUB_KEY_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_NAME)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_NAME)), + ], + any_order=True) + mock_shutil.rmtree.assert_called_with( + CHECKPOINT_DIR) + + def mock_copyfile(src, dest): + raise OSError(TEST_EXCEPTION_MESSAGE) + + def mock_rmtree(dir): + raise OSError(TEST_EXCEPTION_MESSAGE) + + # Restore checkpoint fails when copy and delete operations fail. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=True)): + with mock.patch("ssh_mgmt.os.makedirs") as mock_makedirs: + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + mock_shutil.copyfile = mock_copyfile + mock_shutil.rmtree = mock_rmtree + result = self.ssh_mgmt_module.restore_checkpoint([]) + assert result[0] == 1 + assert result[1] != "Successfully restored checkpoint" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + + def test_delete_checkpoint(self): + # Delete checkpoint fails when checkpoint does not exist. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=False)) as mock_isdir: + result = self.ssh_mgmt_module.delete_checkpoint([]) + assert result[0] == 1 + assert result[1] == "Checkpoint does not exist" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + + # Delete checkpoint succeeds. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + result = self.ssh_mgmt_module.delete_checkpoint([]) + assert result[0] == 0 + assert result[1] == "Successfully deleted checkpoint" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + mock_shutil.rmtree.assert_called_with( + CHECKPOINT_DIR) + + def mock_rmtree(dir): + raise OSError(TEST_EXCEPTION_MESSAGE) + + # Delete checkpoint fails when delete operation fails. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + mock_shutil.rmtree = mock_rmtree + result = self.ssh_mgmt_module.delete_checkpoint([]) + assert result[0] == 1 + assert result[1] == "Error in deleting checkpoint" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + + def test_set(self): + # Set fails without creating checkpoint. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=False)) as mock_isdir: + result = self.ssh_mgmt_module.set([""]) + assert result[0] == 1 + assert result[1] == "Update ssh config before creating checkpoint" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + + # Set fails with invalid JSON input. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + result = self.ssh_mgmt_module.set(["#$%@"]) + assert result[0] == 1 + assert result[1] == "Invalid JSON" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + + # Set succeeds. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + result = self.ssh_mgmt_module.set(["{}"]) + assert result[0] == 0 + assert result[1] == "Successfully set credentials" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + + # Set succeeds with additional input. + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + result = self.ssh_mgmt_module.set(['{"invalid key":{}}']) + assert result[0] == 0 + assert result[1] == "Successfully set credentials" + mock_isdir.assert_called_with(CHECKPOINT_DIR) + + def test_set_ca_pub_key(self): + f = MockFileHandler() + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + with mock.patch("builtins.open", mock.MagicMock(return_value=f)) as mock_open: + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=True)) as mock_exists: + with mock.patch("ssh_mgmt.os.remove") as mock_remove: + with mock.patch("ssh_mgmt.os.makedirs") as mock_makedirs: + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + content = {"SshCaPublicKey": [ + "TEST-CERT #1", "TEST-CERT #2"]} + input_data = json.dumps(content) + result = self.ssh_mgmt_module.set([input_data]) + assert result[0] == 0 + assert result[1] == "Successfully set credentials" + mock_open.assert_called_with( + os.path.join(CHECKPOINT_DIR, + CA_PUB_KEY_TEMP), + "w") + assert f.get_contents() == """TEST-CERT #1 +TEST-CERT #2 +""" + mock_exists.assert_has_calls([ + mock.call(os.path.join( + CHECKPOINT_DIR, CA_PUB_KEY_TEMP)), + ], + any_order=True) + mock_remove.assert_has_calls([ + mock.call(os.path.join( + CA_PUB_KEY_DIR, COPY_TEMP_FILE)), + mock.call(os.path.join( + PERSISTENT_CA_PUB_KEY_DIRS[0], COPY_TEMP_FILE)) + ], + any_order=True) + mock_makedirs.assert_has_calls([ + mock.call(CA_PUB_KEY_DIR, exist_ok=True), + mock.call( + PERSISTENT_CA_PUB_KEY_DIRS[0], exist_ok=True), + ], + any_order=True) + mock_shutil.copyfile.assert_has_calls([ + mock.call( + os.path.join(CHECKPOINT_DIR, + CA_PUB_KEY_TEMP), + os.path.join(CA_PUB_KEY_DIR, COPY_TEMP_FILE)), + mock.call( + os.path.join(CHECKPOINT_DIR, + CA_PUB_KEY_TEMP), + os.path.join(PERSISTENT_CA_PUB_KEY_DIRS[0], + COPY_TEMP_FILE)) + ], + any_order=True) + mock_shutil.move.assert_has_calls([ + mock.call( + os.path.join(CA_PUB_KEY_DIR, + COPY_TEMP_FILE), + os.path.join(CA_PUB_KEY_DIR, + CA_PUB_KEY_NAME)), + mock.call( + os.path.join( + PERSISTENT_CA_PUB_KEY_DIRS[0], + COPY_TEMP_FILE), + os.path.join(PERSISTENT_CA_PUB_KEY_DIRS[0], + CA_PUB_KEY_NAME)) + ], + any_order=True) + + def test_set_account_keys(self): + f = MockFileHandler() + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + with mock.patch("builtins.open", mock.MagicMock(return_value=f)) as mock_open: + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=True)) as mock_exists: + with mock.patch("ssh_mgmt.os.remove") as mock_remove: + with mock.patch("ssh_mgmt.os.makedirs") as mock_makedirs: + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + content = { + "SshAccountKeys": [{ + "account": + "root", + "keys": [{ + "key": + "Authorized-key #1", + "options": [{ + "name": "from", + "value": "*.sales.example.net,!pc.sales.example.net" + }] + }, { + "key": "Authorized-key #2", + "options": [] + }, { + "key": + "Authorized-key #3", + "options": [{ + "name": "from", + "value": "*.sales.example.net,!pc.sales.example.net" + }, { + "name": "no-port-forwarding" + }] + }] + }, { + "account": + "root", + }, { + "account": + "root", + "keys": [{ + }] + }, { + "account": + "non-root", + "keys": [{ + "key": + "Non-root account key" + }] + }] + } + input_data = json.dumps(content) + result = self.ssh_mgmt_module.set([input_data]) + assert result[0] == 0 + assert result[1] == "Successfully set credentials" + builtins.open.assert_called_with( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_TEMP), "w") + assert f.get_contents() == """from="*.sales.example.net,!pc.sales.example.net" Authorized-key #1 +Authorized-key #2 +from="*.sales.example.net,!pc.sales.example.net",no-port-forwarding Authorized-key #3 +""" + mock_exists.assert_has_calls([ + mock.call(os.path.join( + CHECKPOINT_DIR, ROOT_AUTHORIZED_KEYS_TEMP)), + ], + any_order=True) + mock_remove.assert_has_calls([ + mock.call(os.path.join( + ROOT_AUTHORIZED_KEYS_DIR, COPY_TEMP_FILE)), + mock.call(os.path.join( + PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS[0], + COPY_TEMP_FILE)) + ], + any_order=True) + mock_makedirs.assert_has_calls([ + mock.call( + ROOT_AUTHORIZED_KEYS_DIR, exist_ok=True), + mock.call( + PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS[0], exist_ok=True), + ], + any_order=True) + mock_shutil.copyfile.assert_has_calls([ + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_TEMP), + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + COPY_TEMP_FILE)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_KEYS_TEMP), + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS[0], + COPY_TEMP_FILE)) + ], + any_order=True) + mock_shutil.move.assert_has_calls([ + mock.call( + os.path.join( + ROOT_AUTHORIZED_KEYS_DIR, COPY_TEMP_FILE), + os.path.join(ROOT_AUTHORIZED_KEYS_DIR, + ROOT_AUTHORIZED_KEYS_NAME)), + mock.call( + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS[0], + COPY_TEMP_FILE), + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_KEYS_DIRS[0], + ROOT_AUTHORIZED_KEYS_NAME)) + ], + any_order=True) + + def test_set_account_users(self): + f = MockFileHandler() + with mock.patch("ssh_mgmt.os.path.isdir", mock.MagicMock(return_value=True)) as mock_isdir: + with mock.patch("builtins.open", mock.MagicMock(return_value=f)) as mock_open: + with mock.patch("ssh_mgmt.os.path.exists", mock.MagicMock(return_value=True)) as mock_exists: + with mock.patch("ssh_mgmt.os.remove") as mock_remove: + with mock.patch("ssh_mgmt.os.makedirs") as mock_makedirs: + with mock.patch("ssh_mgmt.shutil") as mock_shutil: + content = { + "SshAccountUsers": [{ + "account": + "root", + "users": [{ + "name": + "alice", + "options": [{ + "name": "from", + "value": "*.sales.example.net,!pc.sales.example.net" + }] + }, { + "name": "bob", + "options": [] + }, { + "name": + "carol", + "options": [{ + "name": "from", + "value": "*.sales.example.net,!pc.sales.example.net" + }, { + "name": "no-port-forwarding", + "value": "" + }, { + "value": "option without name" + }] + }] + }, { + "account": + "root", + }, { + "account": + "root", + "users": [{ + }] + }, { + "account": + "non-root", + "users": [{ + "name": + "non-root-user" + }] + }] + } + input_data = json.dumps(content) + result = self.ssh_mgmt_module.set([input_data]) + assert result[0] == 0 + assert result[1] == "Successfully set credentials" + builtins.open.assert_called_with( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_TEMP), "w") + assert f.get_contents() == """from="*.sales.example.net,!pc.sales.example.net" alice +bob +from="*.sales.example.net,!pc.sales.example.net",no-port-forwarding carol +""" + mock_exists.assert_has_calls([ + mock.call(os.path.join( + CHECKPOINT_DIR, ROOT_AUTHORIZED_USERS_TEMP)), + ], + any_order=True) + mock_remove.assert_has_calls([ + mock.call(os.path.join( + ROOT_AUTHORIZED_USERS_DIR, COPY_TEMP_FILE)), + mock.call(os.path.join( + PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS[0], + COPY_TEMP_FILE)) + ], + any_order=True) + mock_makedirs.assert_has_calls([ + mock.call( + ROOT_AUTHORIZED_USERS_DIR, exist_ok=True), + mock.call( + PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS[0], exist_ok=True), + ], + any_order=True) + mock_shutil.copyfile.assert_has_calls([ + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_TEMP), + os.path.join(ROOT_AUTHORIZED_USERS_DIR, + COPY_TEMP_FILE)), + mock.call( + os.path.join(CHECKPOINT_DIR, + ROOT_AUTHORIZED_USERS_TEMP), + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS[0], + COPY_TEMP_FILE)) + ], + any_order=True) + mock_shutil.move.assert_has_calls([ + mock.call( + os.path.join( + ROOT_AUTHORIZED_USERS_DIR, COPY_TEMP_FILE), + os.path.join(ROOT_AUTHORIZED_USERS_DIR, + ROOT_AUTHORIZED_USERS_NAME)), + mock.call( + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS[0], + COPY_TEMP_FILE), + os.path.join( + PERSISTENT_ROOT_AUTHORIZED_USERS_DIRS[0], + ROOT_AUTHORIZED_USERS_NAME)) + ], + any_order=True) + + def test_copy_files_failure(self): + result = self.ssh_mgmt_module._copy_files(["a", "b"], ["a"]) + assert result[0] == 1 + assert result[1] == "Length of src and dest do not match in _copy_files" + + def test_register(self): + result = register() + assert result[0] == SshMgmt + assert result[1] == MOD_NAME + + @classmethod + def teardown_class(cls): + print("TEARDOWN")