diff --git a/core/screen.py b/core/screen.py new file mode 100644 index 0000000..f41b297 --- /dev/null +++ b/core/screen.py @@ -0,0 +1,76 @@ +# Taken from https://gitlab.com/jvadair/pyscreen +# All credit goes to the original author, jvadair + +__all__ = ['Screen', 'create', 'kill', 'exists', 'ls'] + +import os +import signal +import subprocess +from datetime import datetime + + +class ScreenNotFound(Exception): + pass + + +class Screen: + def __init__(self, pid, kill_on_exit=False): + """ + :param pid: The process ID of the screen + Creates a new Screen object + """ + if not exists(pid): + raise ScreenNotFound(f"No screen with pid {pid}") + self.pid = pid + self.kill_on_exit = kill_on_exit + + def __del__(self): # Destroy screen process when object deleted + if self.kill_on_exit: + kill(self.pid) + + def send(self, command: str, end="\r") -> None: + """ + :param command: The command to be run + :param end: Appended to the end of the command - the default value is a carriage return + """ + os.system(f'screen -S {self.pid} -X stuff {command}{end}') + + +def create(name, shell=os.environ['SHELL'], logfile=None, title=None) -> Screen: + command = ["screen", "-DmS", name, '-s', shell] + if logfile: + command.append('-Logfile') + command.append(logfile) + if title: + command.append('-t') + command.append(title) + process = subprocess.Popen(command) + while not process.pid: pass + return Screen(process.pid) + + +def kill(pid): + os.kill(pid, signal.SIGTERM) + + +def exists(pid: int) -> int: + command = f"screen -S {str(pid)} -Q select .".split() + pop = subprocess.Popen(command, stdout=subprocess.DEVNULL) + pop.wait() + code = pop.returncode + return False if code else True + + +def ls() -> dict: + out = subprocess.check_output(["screen", "-ls"]).decode() + out = out.replace('\t', '') + out = out.split('\n') + out = out[1:len(out) - 2] + out = [i.replace(")", "").split("(") for i in out] + final = {} + for item in out: + process = item[0].split('.') + pid = process[0] + final[pid] = {'name': process[1]} # final: {pid:{...}, ... } + final[pid]['time'] = datetime.strptime(item[1], '%m/%d/%Y %X %p') + return final diff --git a/core/stack.py b/core/stack.py index 12c192d..8aaebef 100644 --- a/core/stack.py +++ b/core/stack.py @@ -1,75 +1,95 @@ -import logging import os import shutil import subprocess -import time +from pathlib import Path +from typing import Union, List, Optional import psutil -from core import utils, config +from core import utils, screen from core.vars import logger, PYTHON_EXEC from ui import choices -def find_correct_pid(parent_pid: int) -> int: - processes: list[psutil.Process] = [psutil.Process(parent_pid)] - create_time = processes[0].create_time() - - time.sleep(0.5) # Wait for child processes to spawn - - for i in range(1, 10): - if psutil.pid_exists(processes[0].pid + i): - child_process = psutil.Process(processes[0].pid + i) - if child_process.create_time() - create_time < 1: - processes.append(psutil.Process(processes[0].pid + i)) - else: - time.sleep(0.5 / i) - - return processes[-1].pid - - class Stack: def __init__(self, name: str, id: str, port: int, url: str): + """ + Initialize the Stack instance. + + :param name: The name of the stack + :param id: A unique identifier for the stack + :param port: The port the stack service uses + :param url: The URL associated with the stack + """ self.name = name self.id = id - self.path = os.path.join(os.path.expanduser("~"), ".ai-suite-rocm", id) - + self.path = Path.home() / ".ai-suite-rocm" / id self.url = url self.port = port + self.pid_file = self.path / f".pid" + self.pid = self.read_pid() - self.pid = config.get(f"{self.name}-pid") + def read_pid(self) -> Optional[int]: + """ + Read the PID from the PID file, if it exists. - def install(self): + :return: The PID as an integer, or None if the file does not exist or an error occurs. + """ + if self.pid_file.exists(): + return int(self.pid_file.read_text()) + else: + return None + + def write_pid(self, pid: int) -> None: + """ + Write the PID to the PID file. + + :param pid: Process ID to write + """ + with self.pid_file.open('w') as f: + f.write(str(pid)) + + def remove_pid_file(self) -> None: + """Remove the PID file if it exists.""" + if self.pid_file.exists(): + self.pid_file.unlink() + + def install(self) -> None: + """Install the stack, creating the virtual environment and performing initial setup.""" if self.is_installed(): self.update() else: self.check_for_broken_install() + self.create_dir('') self.create_venv() self._install() self.create_file('.installed', 'true') logger.info(f"Installed {self.name}") - def _install(self): + def _install(self) -> None: + """Additional installation steps specific to the stack (override as needed).""" pass - def is_installed(self): + def is_installed(self) -> bool: + """ + Check if the stack is installed by verifying the existence of the '.installed' file. + + :return: True if the stack is installed, False otherwise. + """ return self.file_exists('.installed') - def check_for_broken_install(self): - if not self.is_installed(): - if os.path.exists(self.path): - if len(os.listdir(self.path)) > 0: - logger.warning("Found files from a previous/borked/crashed installation, cleaning up...") - self.bash(f"rm -rf {self.path}") - self.create_dir('') - else: - self.create_dir('') + def check_for_broken_install(self) -> None: + """Check for a broken installation and clean up any leftover files.""" + if not self.is_installed() and self.path.exists(): + logger.warning("Found files from a previous/broken/crashed installation, cleaning up...") + self.remove_dir('') - def update(self, folder: str = 'webui'): + def update(self, folder: str = 'webui') -> None: + """Update the stack by pulling the latest changes from the repository.""" if self.is_installed(): - status = self.status() - if status: + was_running = self.status() + if was_running: self.stop() logger.info(f"Updating {self.name}") @@ -78,61 +98,69 @@ class Stack: self._update() utils.create_symlinks(symlinks) - if status: + if was_running: self.start() else: logger.warning(f"Could not update {self.name} as {self.name} is not installed") - def _update(self): + def _update(self) -> None: + """Additional update steps specific to the stack (override as needed).""" pass - def uninstall(self): + def uninstall(self) -> None: + """Uninstall the stack by stopping it and removing its files.""" logger.info(f"Uninstalling {self.name}") if self.status(): self.stop() self.bash(f"rm -rf {self.path}") + self.remove_pid_file() - def start(self): + def start(self) -> None: + """Start the stack service.""" if self.status(): logger.warning(f"{self.name} is already running") + return if self.is_installed(): self._start() else: logger.error(f"{self.name} is not installed") - def _start(self): + def _start(self) -> None: + """Additional start steps specific to the stack (override as needed).""" pass - def stop(self): + def stop(self) -> None: + """Stop the stack service by terminating the associated process.""" if self.status(): - logger.debug(f"Killing {self.name} with PID: {self.pid}") - psutil.Process(self.pid).kill() + logger.debug(f"Stopping {self.name} with PID: {self.pid}") + try: + proc = psutil.Process(self.pid) + proc.terminate() # Graceful shutdown + proc.wait(timeout=5) + except (psutil.NoSuchProcess, psutil.TimeoutExpired): + logger.warning(f"{self.name} did not terminate gracefully, forcing kill") + psutil.Process(self.pid).kill() + self.remove_pid_file() else: logger.warning(f"{self.name} is not running") - self.set_pid(None) - - def set_pid(self, pid): - self.pid = pid - if pid is not None: - config.put(f"{self.name}-pid", pid) - else: - config.remove(f"{self.name}-pid") - - def restart(self): + def restart(self) -> None: + """Restart the stack service.""" self.stop() self.start() def status(self) -> bool: - if self.pid is None: - return False + """ + Check if the stack service is running. - return psutil.pid_exists(self.pid) + :return: True if the service is running, False otherwise. + """ + return self.pid is not None and psutil.pid_exists(self.pid) - # Python/Bash utils - def create_venv(self): - venv_path = os.path.join(self.path, 'venv') + def create_venv(self) -> None: + """Create a Python virtual environment for the stack.""" + venv_path = self.path / 'venv' if not self.has_venv(): logger.info(f"Creating venv for {self.name}") self.bash(f"{PYTHON_EXEC} -m venv {venv_path} --system-site-packages") @@ -141,9 +169,16 @@ class Stack: logger.debug(f"Venv already exists for {self.name}") def has_venv(self) -> bool: - return self.dir_exists('venv') + """ + Check if the virtual environment exists. - def pip_install(self, package: str | list, no_deps: bool = False, env=[], args=[]): + :return: True if the virtual environment exists, False otherwise. + """ + return (self.path / 'venv').exists() + + def pip_install(self, package: Union[str, List[str]], no_deps: bool = False, env: List[str] = [], + args: List[str] = []) -> None: + """Install a Python package or list of packages using pip.""" if no_deps: args.append("--no-deps") @@ -155,115 +190,125 @@ class Stack: logger.info(f"Installing {package}") self.pip(f"install -U {package}", env=env, args=args) - def install_requirements(self, filename: str = 'requirements.txt', env=[]): + def install_requirements(self, filename: str = 'requirements.txt', env: List[str] = []) -> None: + """Install requirements from a given file.""" logger.info(f"Installing requirements for {self.name} ({filename})") self.pip(f"install -r {filename}", env=env) - def pip(self, cmd: str, env=[], args=[], current_dir: str = None): + def pip(self, cmd: str, env: List[str] = [], args: List[str] = [], current_dir: Optional[Path] = None) -> None: + """Run pip with a given command.""" self.python(f"-m pip {cmd}", env=env, args=args, current_dir=current_dir) - def python(self, cmd: str, env=[], args=[], current_dir: str = None, daemon: bool = False): - self.bash(f"{' '.join(env)} {self.path}/venv/bin/python {cmd} {' '.join(args)}", current_dir, daemon) + def python(self, cmd: str, env: List[str] = [], args: List[str] = [], current_dir: Optional[Path] = None, + daemon: bool = False) -> None: + """Run a Python command inside the stack's virtual environment.""" + self.bash(f"{' '.join(env)} {self.path / 'venv' / 'bin' / 'python'} {cmd} {' '.join(args)}", current_dir, + daemon) - def bash(self, cmd: str, current_dir: str = None, daemon: bool = False): - cmd = f"cd {self.path if current_dir is None else os.path.join(self.path, current_dir)} && {cmd}" + def bash(self, cmd: str, current_dir: Optional[Path] = None, daemon: bool = False) -> None: + """Run a bash command, optionally as a daemon.""" + full_cmd = f"cd {current_dir or self.path} && {cmd}" if daemon: if self.status(): choice = choices.already_running.ask() - if choice is True: self.stop() self._start() return else: - # TODO: attach to subprocess / redirect logs? return else: - logger.debug(f"Running command as daemon: {cmd}") - cmd = f"{cmd} &" - process = subprocess.Popen(cmd, shell=True, preexec_fn=os.setpgrp, - stdout=config.open_file(f"{self.id}-stdout"), - stderr=config.open_file(f"{self.id}-stderr")) - self.set_pid(find_correct_pid(process.pid)) + logger.debug(f"Running command as daemon: {full_cmd}") + + # process = subprocess.Popen(full_cmd, shell=True, preexec_fn=os.setpgrp, + # stdout=config.open_file(f"{self.id}-stdout"), + # stderr=config.open_file(f"{self.id}-stderr")) + screen_session = screen.create(name=self.id) + screen_session.send(f"'{full_cmd}'") + self.write_pid(screen_session.pid) return else: - logger.debug(f"Running command: {cmd}") + logger.debug(f"Running command: {full_cmd}") - if logger.level == logging.DEBUG: - process = subprocess.Popen(cmd, shell=True) - process.wait() - if process.returncode != 0: - raise Exception(f"Failed to run command: {cmd}") - else: - process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - out, err = process.communicate() + process = subprocess.Popen(full_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = process.communicate() - if process.returncode != 0: - logger.fatal(f"Failed to run command: {cmd}") - logger.fatal(f"Error: {err.decode('utf-8')}") - logger.fatal(f"Output: {out.decode('utf-8')}") - raise Exception(f"Failed to run command: {cmd}") + if process.returncode != 0: + logger.fatal(f"Failed to run command: {full_cmd}") + logger.fatal(f"Error: {err.decode('utf-8')}") + logger.fatal(f"Output: {out.decode('utf-8')}") + raise Exception(f"Failed to run command: {full_cmd}") - # Git utils - def git_clone(self, url: str, branch: str = None, dest: str = None): + def git_clone(self, url: str, branch: Optional[str] = None, dest: Optional[Path] = None) -> None: + """Clone a git repository.""" logger.info(f"Cloning {url}") - self.bash(f"git clone {f"-b {branch}" if branch is not None else ''} {url} {'' if dest is None else dest}") + self.bash(f"git clone {f'-b {branch}' if branch else ''} {url} {dest or ''}") - def git_pull(self, repo_folder: str, force: bool = False): - self.bash(f"git reset --hard HEAD {'&& git clean -f -d' if force else ''} && git pull", repo_folder) + def git_pull(self, repo_folder: str, force: bool = False) -> None: + """Pull changes from a git repository.""" + self.bash(f"git reset --hard HEAD {'&& git clean -f -d' if force else ''} && git pull", Path(repo_folder)) - # Prebuilt utils def install_from_prebuilt(self, name): for prebuilt in utils.get_prebuilts(): if prebuilt['name'].split("-")[0] == name: self.pip(f"install {prebuilt['browser_download_url']}") return - # File utils - def create_file(self, name, content): - with open(os.path.join(self.path, name), 'w') as f: - f.write(content) + def create_file(self, name: str, content: str) -> None: + """Create a file with the given content.""" + (self.path / name).write_text(content) - def create_dir(self, name): - if name == '': - logger.info(f"Creating directory for {self.name}") + def create_dir(self, name: str) -> None: + """Create a directory.""" + dir_path = self.path / name + logger.debug(f"Creating directory {dir_path}") + dir_path.mkdir(parents=True, exist_ok=True) - logger.debug(f"Creating directory {name}") - os.makedirs(os.path.join(self.path, name), exist_ok=True) - - def remove_file(self, name): + def remove_file(self, name: str) -> None: + """Remove a file.""" logger.debug(f"Removing file {name}") os.remove(os.path.join(self.path, name)) - def remove_dir(self, name): - logger.debug(f"Removing directory {name}") - shutil.rmtree(os.path.join(self.path, name)) + def remove_dir(self, name: str) -> None: + """Remove a directory.""" + logger.debug(f"Removing directory {name or self.path}") + if not name: + shutil.rmtree(self.path) + else: + shutil.rmtree(os.path.join(self.path, name)) - def move_file_or_dir(self, src, dest): + def move_file_or_dir(self, src: str, dest: str) -> None: + """Move a file or directory.""" logger.debug(f"Moving file/dir {src} to {dest}") os.rename(os.path.join(self.path, src), os.path.join(self.path, dest)) - def move_all_files_in_dir(self, src, dest): + def move_all_files_in_dir(self, src: str, dest: str) -> None: + """Move all files in a directory to another directory""" logger.debug(f"Moving all files in directory {src} to {dest}") for file in os.listdir(os.path.join(self.path, src)): os.rename(os.path.join(self.path, src, file), os.path.join(self.path, dest, file)) - def file_exists(self, name): - return os.path.exists(os.path.join(self.path, name)) + def file_exists(self, name: str) -> bool: + """Check if a file exists.""" + return (self.path / name).exists() - def dir_exists(self, name): - return os.path.exists(os.path.join(self.path, name)) + def dir_exists(self, name: str) -> bool: + """Check if a directory exists.""" + return (self.path / name).exists() def remove_line_in_file(self, contains: str | list, file: str): - logger.debug(f"Removing lines containing {contains} in {file}") - + """Remove lines containing a specific string from a file.""" + target_file = self.path / file + logger.debug(f"Removing lines containing {contains} in {target_file}") if isinstance(contains, list): for c in contains: - self.bash(f"sed -i '/{c}/d' {file}") + self.bash(f"sed -i '/{c}/d' {target_file}") else: - self.bash(f"sed -i '/{contains}/d' {file}") + self.bash(f"sed -i '/{contains}/d' {target_file}") def replace_line_in_file(self, match: str, replace: str, file: str): - logger.debug(f"Replacing lines containing {match} with {replace} in {file}") - self.bash(f"sed -i 's/{match}/{replace}/g' {file}") + """Replace lines containing a specific string in a file.""" + target_file = self.path / file + logger.debug(f"Replacing lines containing {match} with {replace} in {target_file}") + self.bash(f"sed -i 's/{match}/{replace}/g' {target_file}") diff --git a/core/utils.py b/core/utils.py index 606aa4f..8151d6a 100644 --- a/core/utils.py +++ b/core/utils.py @@ -4,6 +4,7 @@ import os import shutil import subprocess from pathlib import Path +from typing import List, Dict, Tuple, Union from urllib import request, error from core.stack import Stack @@ -11,7 +12,15 @@ from core.vars import ROCM_VERSION, logger def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-local", - release_tag: str = f"prebuilt-whl-{ROCM_VERSION}") -> list: + release_tag: str = f"prebuilt-whl-{ROCM_VERSION}") -> List[dict]: + """ + Fetch prebuilt assets from a GitHub release using the GitHub API. + + :param repo_owner: GitHub repository owner + :param repo_name: GitHub repository name + :param release_tag: Release tag for fetching assets + :return: List of assets (dictionaries) from the GitHub release + """ api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}" try: @@ -21,7 +30,6 @@ def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-l return [] release_data = json.load(response) - assets = release_data.get('assets', []) if not assets: logger.error("No assets found in release data") @@ -31,34 +39,48 @@ def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-l except error.URLError as e: logger.error(f"Error fetching release data: {e}") + return [] -def check_for_build_essentials(): +def check_for_build_essentials() -> None: + """ + Check if build essentials like `build-essential` and `python3.10-dev` are installed. + Raises a warning if they are missing. + """ logger.debug("Checking for build essentials...") - debian = os.path.exists('/etc/debian_version') - fedora = os.path.exists('/etc/fedora-release') + debian = Path('/etc/debian_version').exists() + fedora = Path('/etc/fedora-release').exists() if debian: - # TODO: check if these work for debian users check_gcc = run_command("dpkg -l | grep build-essential &>/dev/null", exit_on_error=False)[2] == 0 check_python = run_command("dpkg -l | grep python3.10-dev &>/dev/null", exit_on_error=False)[2] == 0 if not check_gcc or not check_python: raise UserWarning( - "The packages build-essential and python3.10-dev are required for this script to run. Please install them. See the README for more information.") + "The packages build-essential and python3.10-dev are required for this script to run. Please install them. See the README for more information." + ) elif fedora: check_gcc = run_command("rpm -q gcc &>/dev/null", exit_on_error=False)[2] == 0 check_python = run_command("rpm -q python3.10-devel &>/dev/null", exit_on_error=False)[2] == 0 if not check_gcc or not check_python: raise UserWarning( - "The package python3.10-devel and the Development Tools group are required for this script to run. Please install them. See the README for more information.") + "The package python3.10-devel and the Development Tools group are required for this script to run. Please install them. See the README for more information." + ) else: logger.warning( - "Unsupported OS detected. Please ensure you have the following packages installed or their equivalent: build-essential, python3.10-dev") + "Unsupported OS detected. Please ensure you have the following packages installed or their equivalent: build-essential, python3.10-dev" + ) -def run_command(command: str, exit_on_error: bool = True): +def run_command(command: str, exit_on_error: bool = True) -> Tuple[bytes, bytes, int]: + """ + Run a shell command and return the output, error, and return code. + + :param command: The shell command to run + :param exit_on_error: Whether to raise an exception on error + :return: A tuple containing stdout, stderr, and return code + """ logger.debug(f"Running command: {command}") process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) out, err = process.communicate() @@ -73,25 +95,44 @@ def run_command(command: str, exit_on_error: bool = True): def load_service_from_string(service: str) -> Stack: + """ + Dynamically load a Stack class based on the service string. + + :param service: Name of the service to load + :return: An instance of the corresponding Stack class + """ logger.debug(f"Loading service from string: {service}") service_name = service.replace("_", " ").title().replace(" ", "") - module = importlib.import_module(f"services.{service}") - met = getattr(module, service_name) - return met() + stack_class = getattr(module, service_name) + + return stack_class() -def find_symlink_in_folder(folder: str): +def find_symlink_in_folder(folder: Union[str, Path]) -> Dict[Path, Path]: + """ + Find all symbolic links in the given folder and map them to their resolved paths. + + :param folder: The folder to search for symlinks + :return: A dictionary mapping symlink paths to their resolved target paths + """ + folder = Path(folder) symlinks = {} - for file in Path(folder).rglob("webui/**"): + + for file in folder.rglob("webui/**"): if file.is_symlink(): symlinks[file] = file.resolve() return symlinks -def create_symlinks(symlinks: dict[Path, Path]): +def create_symlinks(symlinks: Dict[Path, Path]) -> None: + """ + Recreate symlinks from a dictionary mapping target paths to link paths. + + :param symlinks: Dictionary of symlinks and their resolved paths + """ for target, link in symlinks.items(): logger.debug(f"(re)Creating symlink: {link} -> {target}") diff --git a/services/stable_diffusion_forge.py b/services/stable_diffusion_forge.py index d8dde91..111a5a2 100644 --- a/services/stable_diffusion_forge.py +++ b/services/stable_diffusion_forge.py @@ -1,3 +1,5 @@ +from pathlib import Path + from core.stack import Stack @@ -12,14 +14,14 @@ class StableDiffusionForge(Stack): def _install(self): # Install the webui - self.git_clone(url=self.url, dest="webui") + self.git_clone(url=self.url, dest=Path(self.path / "webui")) - self.python("launch.py --skip-torch-cuda-test --exit", current_dir="webui") + self.python("launch.py --skip-torch-cuda-test --exit", current_dir=Path(self.path / "webui")) # Add NF4 support for Flux self.install_from_prebuilt("bitsandbytes") def _start(self): args = ["--listen", "--enable-insecure-extension-access", "--port", str(self.port)] - self.python(f"launch.py", args=args, current_dir="webui", + self.python(f"launch.py", args=args, current_dir=Path(self.path / "webui"), env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True)