From 67df62ff919175d5e75e9395c902a0d9c8e15a9b Mon Sep 17 00:00:00 2001 From: Mathieu Broillet Date: Sat, 19 Oct 2024 12:02:45 +0200 Subject: [PATCH] improve ssh client class --- .../computer/ssh_client.py | 75 ++++++++++++------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/custom_components/easy_computer_manager/computer/ssh_client.py b/custom_components/easy_computer_manager/computer/ssh_client.py index 9b15a2c..2e91fcd 100644 --- a/custom_components/easy_computer_manager/computer/ssh_client.py +++ b/custom_components/easy_computer_manager/computer/ssh_client.py @@ -8,71 +8,88 @@ from custom_components.easy_computer_manager.computer import CommandOutput class SSHClient: - def __init__(self, host, username, password, port): + def __init__(self, host: str, username: str, password: Optional[str] = None, port: int = 22): self.host = host self.username = username self._password = password self.port = port - self._connection = None + self._connection: Optional[paramiko.SSHClient] = None + + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + self.disconnect() async def connect(self, retried: bool = False, computer: Optional['Computer'] = None) -> None: """Open an SSH connection using Paramiko asynchronously.""" - self.disconnect() + if self.is_connection_alive(): + LOGGER.debug(f"Connection to {self.host} is already active.") + return + + self.disconnect() # Ensure any previous connection is closed loop = asyncio.get_running_loop() + client = paramiko.SSHClient() + + # Set missing host key policy to automatically accept unknown host keys + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) try: - # Create the SSH client - client = paramiko.SSHClient() - - # Set missing host key policy to automatically accept unknown host keys - # client.load_system_host_keys() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - # Offload the blocking connect call to a thread await loop.run_in_executor(None, self._blocking_connect, client) self._connection = client + LOGGER.debug(f"Connected to {self.host}") except (OSError, paramiko.SSHException) as exc: - if retried: - await self.connect(retried=True) - else: - LOGGER.debug(f"Failed to connect to {self.host}: {exc}") + LOGGER.debug(f"Failed to connect to {self.host}: {exc}") + if not retried: + LOGGER.debug(f"Retrying connection to {self.host}...") + await self.connect(retried=True) # Retry only once + finally: - if computer is not None: - if hasattr(computer, "initialized"): - computer.initialized = True + if computer is not None and hasattr(computer, "initialized"): + computer.initialized = True def disconnect(self) -> None: """Close the SSH connection.""" - if self._connection is not None: + if self._connection: self._connection.close() - self._connection = None + LOGGER.debug(f"Disconnected from {self.host}") + self._connection = None - def _blocking_connect(self, client): + def _blocking_connect(self, client: paramiko.SSHClient): """Perform the blocking SSH connection using Paramiko.""" client.connect( - self.host, + hostname=self.host, username=self.username, password=self._password, - port=self.port + port=self.port, + look_for_keys=False, # Set this to True if using private keys + allow_agent=False ) async def execute_command(self, command: str) -> CommandOutput: """Execute a command on the SSH server asynchronously.""" - try: - stdin, stdout, stderr = self._connection.exec_command(command) - exit_status = stdout.channel.recv_exit_status() + if not self.is_connection_alive(): + LOGGER.debug(f"Connection to {self.host} is not alive. Reconnecting...") + await self.connect() + try: + # Offload command execution to avoid blocking + loop = asyncio.get_running_loop() + stdin, stdout, stderr = await loop.run_in_executor(None, self._connection.exec_command, command) + + exit_status = stdout.channel.recv_exit_status() return CommandOutput(command, exit_status, stdout.read().decode(), stderr.read().decode()) except (paramiko.SSHException, EOFError) as exc: - LOGGER.debug(f"Failed to execute command on {self.host}: {exc}") + LOGGER.error(f"Failed to execute command on {self.host}: {exc}") return CommandOutput(command, -1, "", "") def is_connection_alive(self) -> bool: - """Check if the connection is still alive asynchronously.""" - # use the code below if is_active() returns True + """Check if the SSH connection is still alive.""" if self._connection is None: return False @@ -83,5 +100,5 @@ class SSHClient: self._connection.exec_command('ls', timeout=1) return True - except Exception: + except Exception as e: return False