improve ssh client class

This commit is contained in:
Mathieu Broillet 2024-10-19 12:02:45 +02:00
parent 59ff04775c
commit 67df62ff91
Signed by: mathieu
GPG Key ID: A08E484FE95074C1

View File

@ -8,71 +8,88 @@ from custom_components.easy_computer_manager.computer import CommandOutput
class SSHClient: class SSHClient:
def __init__(self, host, username, password, port): def __init__(self, host: str, username: str, password: Optional[str] = None, port: int = 22):
self.host = host self.host = host
self.username = username self.username = username
self._password = password self._password = password
self.port = port self.port = port
self._connection = None self._connection: Optional[paramiko.SSHClient] = None
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
self.disconnect()
async def connect(self, retried: bool = False, computer: Optional['Computer'] = None) -> None: async def connect(self, retried: bool = False, computer: Optional['Computer'] = None) -> None:
"""Open an SSH connection using Paramiko asynchronously.""" """Open an SSH connection using Paramiko asynchronously."""
self.disconnect() if self.is_connection_alive():
LOGGER.debug(f"Connection to {self.host} is already active.")
return
self.disconnect() # Ensure any previous connection is closed
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
client = paramiko.SSHClient()
# Set missing host key policy to automatically accept unknown host keys
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try: try:
# Create the SSH client
client = paramiko.SSHClient()
# Set missing host key policy to automatically accept unknown host keys
# client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
# Offload the blocking connect call to a thread # Offload the blocking connect call to a thread
await loop.run_in_executor(None, self._blocking_connect, client) await loop.run_in_executor(None, self._blocking_connect, client)
self._connection = client self._connection = client
LOGGER.debug(f"Connected to {self.host}")
except (OSError, paramiko.SSHException) as exc: except (OSError, paramiko.SSHException) as exc:
if retried: LOGGER.debug(f"Failed to connect to {self.host}: {exc}")
await self.connect(retried=True) if not retried:
else: LOGGER.debug(f"Retrying connection to {self.host}...")
LOGGER.debug(f"Failed to connect to {self.host}: {exc}") await self.connect(retried=True) # Retry only once
finally: finally:
if computer is not None: if computer is not None and hasattr(computer, "initialized"):
if hasattr(computer, "initialized"): computer.initialized = True
computer.initialized = True
def disconnect(self) -> None: def disconnect(self) -> None:
"""Close the SSH connection.""" """Close the SSH connection."""
if self._connection is not None: if self._connection:
self._connection.close() self._connection.close()
self._connection = None LOGGER.debug(f"Disconnected from {self.host}")
self._connection = None
def _blocking_connect(self, client): def _blocking_connect(self, client: paramiko.SSHClient):
"""Perform the blocking SSH connection using Paramiko.""" """Perform the blocking SSH connection using Paramiko."""
client.connect( client.connect(
self.host, hostname=self.host,
username=self.username, username=self.username,
password=self._password, password=self._password,
port=self.port port=self.port,
look_for_keys=False, # Set this to True if using private keys
allow_agent=False
) )
async def execute_command(self, command: str) -> CommandOutput: async def execute_command(self, command: str) -> CommandOutput:
"""Execute a command on the SSH server asynchronously.""" """Execute a command on the SSH server asynchronously."""
try: if not self.is_connection_alive():
stdin, stdout, stderr = self._connection.exec_command(command) LOGGER.debug(f"Connection to {self.host} is not alive. Reconnecting...")
exit_status = stdout.channel.recv_exit_status() await self.connect()
try:
# Offload command execution to avoid blocking
loop = asyncio.get_running_loop()
stdin, stdout, stderr = await loop.run_in_executor(None, self._connection.exec_command, command)
exit_status = stdout.channel.recv_exit_status()
return CommandOutput(command, exit_status, stdout.read().decode(), stderr.read().decode()) return CommandOutput(command, exit_status, stdout.read().decode(), stderr.read().decode())
except (paramiko.SSHException, EOFError) as exc: except (paramiko.SSHException, EOFError) as exc:
LOGGER.debug(f"Failed to execute command on {self.host}: {exc}") LOGGER.error(f"Failed to execute command on {self.host}: {exc}")
return CommandOutput(command, -1, "", "") return CommandOutput(command, -1, "", "")
def is_connection_alive(self) -> bool: def is_connection_alive(self) -> bool:
"""Check if the connection is still alive asynchronously.""" """Check if the SSH connection is still alive."""
# use the code below if is_active() returns True
if self._connection is None: if self._connection is None:
return False return False
@ -83,5 +100,5 @@ class SSHClient:
self._connection.exec_command('ls', timeout=1) self._connection.exec_command('ls', timeout=1)
return True return True
except Exception: except Exception as e:
return False return False