improve ssh client class
This commit is contained in:
parent
59ff04775c
commit
67df62ff91
@ -8,71 +8,88 @@ from custom_components.easy_computer_manager.computer import CommandOutput
|
|||||||
|
|
||||||
|
|
||||||
class SSHClient:
|
class SSHClient:
|
||||||
def __init__(self, host, username, password, port):
|
def __init__(self, host: str, username: str, password: Optional[str] = None, port: int = 22):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.username = username
|
self.username = username
|
||||||
self._password = password
|
self._password = password
|
||||||
self.port = port
|
self.port = port
|
||||||
self._connection = None
|
self._connection: Optional[paramiko.SSHClient] = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
await self.connect()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.disconnect()
|
||||||
|
|
||||||
async def connect(self, retried: bool = False, computer: Optional['Computer'] = None) -> None:
|
async def connect(self, retried: bool = False, computer: Optional['Computer'] = None) -> None:
|
||||||
"""Open an SSH connection using Paramiko asynchronously."""
|
"""Open an SSH connection using Paramiko asynchronously."""
|
||||||
self.disconnect()
|
if self.is_connection_alive():
|
||||||
|
LOGGER.debug(f"Connection to {self.host} is already active.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.disconnect() # Ensure any previous connection is closed
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
try:
|
|
||||||
# Create the SSH client
|
|
||||||
client = paramiko.SSHClient()
|
client = paramiko.SSHClient()
|
||||||
|
|
||||||
# Set missing host key policy to automatically accept unknown host keys
|
# Set missing host key policy to automatically accept unknown host keys
|
||||||
# client.load_system_host_keys()
|
|
||||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||||
|
|
||||||
|
try:
|
||||||
# Offload the blocking connect call to a thread
|
# Offload the blocking connect call to a thread
|
||||||
await loop.run_in_executor(None, self._blocking_connect, client)
|
await loop.run_in_executor(None, self._blocking_connect, client)
|
||||||
self._connection = client
|
self._connection = client
|
||||||
|
LOGGER.debug(f"Connected to {self.host}")
|
||||||
|
|
||||||
except (OSError, paramiko.SSHException) as exc:
|
except (OSError, paramiko.SSHException) as exc:
|
||||||
if retried:
|
|
||||||
await self.connect(retried=True)
|
|
||||||
else:
|
|
||||||
LOGGER.debug(f"Failed to connect to {self.host}: {exc}")
|
LOGGER.debug(f"Failed to connect to {self.host}: {exc}")
|
||||||
|
if not retried:
|
||||||
|
LOGGER.debug(f"Retrying connection to {self.host}...")
|
||||||
|
await self.connect(retried=True) # Retry only once
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if computer is not None:
|
if computer is not None and hasattr(computer, "initialized"):
|
||||||
if hasattr(computer, "initialized"):
|
|
||||||
computer.initialized = True
|
computer.initialized = True
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
def disconnect(self) -> None:
|
||||||
"""Close the SSH connection."""
|
"""Close the SSH connection."""
|
||||||
if self._connection is not None:
|
if self._connection:
|
||||||
self._connection.close()
|
self._connection.close()
|
||||||
|
LOGGER.debug(f"Disconnected from {self.host}")
|
||||||
self._connection = None
|
self._connection = None
|
||||||
|
|
||||||
def _blocking_connect(self, client):
|
def _blocking_connect(self, client: paramiko.SSHClient):
|
||||||
"""Perform the blocking SSH connection using Paramiko."""
|
"""Perform the blocking SSH connection using Paramiko."""
|
||||||
client.connect(
|
client.connect(
|
||||||
self.host,
|
hostname=self.host,
|
||||||
username=self.username,
|
username=self.username,
|
||||||
password=self._password,
|
password=self._password,
|
||||||
port=self.port
|
port=self.port,
|
||||||
|
look_for_keys=False, # Set this to True if using private keys
|
||||||
|
allow_agent=False
|
||||||
)
|
)
|
||||||
|
|
||||||
async def execute_command(self, command: str) -> CommandOutput:
|
async def execute_command(self, command: str) -> CommandOutput:
|
||||||
"""Execute a command on the SSH server asynchronously."""
|
"""Execute a command on the SSH server asynchronously."""
|
||||||
try:
|
if not self.is_connection_alive():
|
||||||
stdin, stdout, stderr = self._connection.exec_command(command)
|
LOGGER.debug(f"Connection to {self.host} is not alive. Reconnecting...")
|
||||||
exit_status = stdout.channel.recv_exit_status()
|
await self.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Offload command execution to avoid blocking
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
stdin, stdout, stderr = await loop.run_in_executor(None, self._connection.exec_command, command)
|
||||||
|
|
||||||
|
exit_status = stdout.channel.recv_exit_status()
|
||||||
return CommandOutput(command, exit_status, stdout.read().decode(), stderr.read().decode())
|
return CommandOutput(command, exit_status, stdout.read().decode(), stderr.read().decode())
|
||||||
|
|
||||||
except (paramiko.SSHException, EOFError) as exc:
|
except (paramiko.SSHException, EOFError) as exc:
|
||||||
LOGGER.debug(f"Failed to execute command on {self.host}: {exc}")
|
LOGGER.error(f"Failed to execute command on {self.host}: {exc}")
|
||||||
return CommandOutput(command, -1, "", "")
|
return CommandOutput(command, -1, "", "")
|
||||||
|
|
||||||
def is_connection_alive(self) -> bool:
|
def is_connection_alive(self) -> bool:
|
||||||
"""Check if the connection is still alive asynchronously."""
|
"""Check if the SSH connection is still alive."""
|
||||||
# use the code below if is_active() returns True
|
|
||||||
if self._connection is None:
|
if self._connection is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -83,5 +100,5 @@ class SSHClient:
|
|||||||
self._connection.exec_command('ls', timeout=1)
|
self._connection.exec_command('ls', timeout=1)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception:
|
except Exception as e:
|
||||||
return False
|
return False
|
||||||
|
Loading…
Reference in New Issue
Block a user