diff --git a/custom_components/easy_computer_manager/computer.py b/custom_components/easy_computer_manager/computer.py index f8a5039..c8bdaac 100644 --- a/custom_components/easy_computer_manager/computer.py +++ b/custom_components/easy_computer_manager/computer.py @@ -1,15 +1,16 @@ import subprocess as sp -import paramiko +import fabric2 import wakeonlan +from fabric2 import Connection from custom_components.easy_computer_manager import const, _LOGGER class OSType: - WINDOWS = "Windows" - LINUX = "Linux" - MACOS = "MacOS" + WINDOWS = "windows" + LINUX = "linux" + MACOS = "macos" class Computer: @@ -30,18 +31,31 @@ class Computer: self._audio_config = None self._bluetooth_devices = None + self._connection = None + self.setup() - async def _open_ssh_connection(self): + def _open_ssh_connection(self) -> Connection: """Open an SSH connection.""" - client = paramiko.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect(self.host, port=self._port, username=self._username, password=self._password) + conf = fabric2.Config() + conf.run.hide = True + conf.run.warn = True + conf.warn = True + conf.sudo.password = self._password + conf.password = self._password + + client = Connection( + host=self.host, user=self._username, port=self._port, connect_timeout=3, + connect_kwargs={"password": self._password}, + config=conf + ) + + self._connection = client return client - def _close_ssh_connection(self, client): + def _close_ssh_connection(self, client: Connection) -> None: """Close the SSH connection.""" - if client: + if client.is_connected: client.close() def setup(self): @@ -49,8 +63,7 @@ class Computer: client = self._open_ssh_connection() # TODO: run commands here - self._operating_system = OSType.LINUX if self.run_manually( - "uname").return_code == 0 else OSType.WINDOWS # TODO: improve this + self._operating_system = OSType.LINUX # if self.run_manually("uname").return_code == 0 else OSType.WINDOWS # TODO: improve this self._operating_system_version = self.run_action("operating_system_version").output self._windows_entry_grub = self.run_action("get_windows_entry_grub").output self._monitors_config = {} @@ -134,49 +147,53 @@ class Computer: def exit_steam_big_picture(self) -> None: pass - def run_action(self, command: str, params=None) -> {}: + def run_action(self, action: str, params=None) -> dict: """Run a command via SSH. Opens a new connection for each command.""" if params is None: params = {} - if command not in const.COMMANDS: - _LOGGER.error(f"Invalid command: {command}") - return - command_template = const.COMMANDS[command] + if action not in const.COMMANDS: + _LOGGER.error(f"Invalid action: {action}") + return {} + + command_template = const.COMMANDS[action] # Check if the command has the required parameters if "params" in command_template: - if sorted(command_template.params) != sorted(params.keys()): + if sorted(command_template["params"]) != sorted(params.keys()): raise ValueError("Invalid parameters") # Check if the command is available for the operating system match self._operating_system: case OSType.WINDOWS: - command = command_template[OSType.WINDOWS] + commands = command_template[OSType.WINDOWS] case OSType.LINUX: - command = command_template[OSType.LINUX] + commands = command_template[OSType.LINUX] case _: raise ValueError("Invalid operating system") - # Replace the parameters in the command - for param in params: - command = command.replace(f"%{param}%", params[param]) + for command in commands: + # Replace the parameters in the command + for param, value in params.items(): + command = command.replace(f"%{param}%", value) + + result = self.run_manually(command) + + if result['return_code'] == 0: + _LOGGER.debug(f"Command successful: {command}") + return result + else: + _LOGGER.debug(f"Command failed: {command}") + + return {} + + def run_manually(self, command: str) -> dict: + """Run a command manually (not from predefined commands).""" # Open SSH connection, execute command, and close connection client = self._open_ssh_connection() - stdin, stdout, stderr = client.exec_command(command) - print(stdout.read().decode()) # Print the command output for debugging + result = client.run(command) self._close_ssh_connection(client) - return {"output": stdout.read().decode(), "error": stderr.read().decode(), - "return_code": stdout.channel.recv_exit_status()} - - def run_manually(self, command: str) -> {}: - """Run a command manually (not from predefined commands).""" - client = self._open_ssh_connection() - stdin, stdout, stderr = client.exec_command(command) - print(stdout.read().decode()) # Print the command output for debugging - self._close_ssh_connection(client) - - return {"output": stdout.read().decode(), "error": stderr.read().decode(), - "return_code": stdout.channel.recv_exit_status()} + return {"output": result.stdout, "error": result.stderr, + "return_code": result.return_code}