Compare commits

...

2 Commits

5 changed files with 307 additions and 151 deletions

76
core/screen.py Normal file
View File

@ -0,0 +1,76 @@
# Taken from https://gitlab.com/jvadair/pyscreen
# All credit goes to the original author, jvadair
__all__ = ['Screen', 'create', 'kill', 'exists', 'ls']
import os
import signal
import subprocess
from datetime import datetime
class ScreenNotFound(Exception):
pass
class Screen:
def __init__(self, pid, kill_on_exit=False):
"""
:param pid: The process ID of the screen
Creates a new Screen object
"""
if not exists(pid):
raise ScreenNotFound(f"No screen with pid {pid}")
self.pid = pid
self.kill_on_exit = kill_on_exit
def __del__(self): # Destroy screen process when object deleted
if self.kill_on_exit:
kill(self.pid)
def send(self, command: str, end="\r") -> None:
"""
:param command: The command to be run
:param end: Appended to the end of the command - the default value is a carriage return
"""
os.system(f'screen -S {self.pid} -X stuff {command}{end}')
def create(name, shell=os.environ['SHELL'], logfile=None, title=None) -> Screen:
command = ["screen", "-DmS", name, '-s', shell]
if logfile:
command.append('-Logfile')
command.append(logfile)
if title:
command.append('-t')
command.append(title)
process = subprocess.Popen(command)
while not process.pid: pass
return Screen(process.pid)
def kill(pid):
os.kill(pid, signal.SIGTERM)
def exists(pid: int) -> int:
command = f"screen -S {str(pid)} -Q select .".split()
pop = subprocess.Popen(command, stdout=subprocess.DEVNULL)
pop.wait()
code = pop.returncode
return False if code else True
def ls() -> dict:
out = subprocess.check_output(["screen", "-ls"]).decode()
out = out.replace('\t', '')
out = out.split('\n')
out = out[1:len(out) - 2]
out = [i.replace(")", "").split("(") for i in out]
final = {}
for item in out:
process = item[0].split('.')
pid = process[0]
final[pid] = {'name': process[1]} # final: {pid:{...}, ... }
final[pid]['time'] = datetime.strptime(item[1], '%m/%d/%Y %X %p')
return final

View File

@ -1,75 +1,95 @@
import logging
import os
import shutil
import subprocess
import time
from pathlib import Path
from typing import Union, List, Optional
import psutil
from core import utils, config
from core import utils, screen
from core.vars import logger, PYTHON_EXEC
from ui import choices
def find_correct_pid(parent_pid: int) -> int:
processes: list[psutil.Process] = [psutil.Process(parent_pid)]
create_time = processes[0].create_time()
time.sleep(0.5) # Wait for child processes to spawn
for i in range(1, 10):
if psutil.pid_exists(processes[0].pid + i):
child_process = psutil.Process(processes[0].pid + i)
if child_process.create_time() - create_time < 1:
processes.append(psutil.Process(processes[0].pid + i))
else:
time.sleep(0.5 / i)
return processes[-1].pid
class Stack:
def __init__(self, name: str, id: str, port: int, url: str):
"""
Initialize the Stack instance.
:param name: The name of the stack
:param id: A unique identifier for the stack
:param port: The port the stack service uses
:param url: The URL associated with the stack
"""
self.name = name
self.id = id
self.path = os.path.join(os.path.expanduser("~"), ".ai-suite-rocm", id)
self.path = Path.home() / ".ai-suite-rocm" / id
self.url = url
self.port = port
self.pid_file = self.path / f".pid"
self.pid = self.read_pid()
self.pid = config.get(f"{self.name}-pid")
def read_pid(self) -> Optional[int]:
"""
Read the PID from the PID file, if it exists.
def install(self):
:return: The PID as an integer, or None if the file does not exist or an error occurs.
"""
if self.pid_file.exists():
return int(self.pid_file.read_text())
else:
return None
def write_pid(self, pid: int) -> None:
"""
Write the PID to the PID file.
:param pid: Process ID to write
"""
with self.pid_file.open('w') as f:
f.write(str(pid))
def remove_pid_file(self) -> None:
"""Remove the PID file if it exists."""
if self.pid_file.exists():
self.pid_file.unlink()
def install(self) -> None:
"""Install the stack, creating the virtual environment and performing initial setup."""
if self.is_installed():
self.update()
else:
self.check_for_broken_install()
self.create_dir('')
self.create_venv()
self._install()
self.create_file('.installed', 'true')
logger.info(f"Installed {self.name}")
def _install(self):
def _install(self) -> None:
"""Additional installation steps specific to the stack (override as needed)."""
pass
def is_installed(self):
def is_installed(self) -> bool:
"""
Check if the stack is installed by verifying the existence of the '.installed' file.
:return: True if the stack is installed, False otherwise.
"""
return self.file_exists('.installed')
def check_for_broken_install(self):
if not self.is_installed():
if os.path.exists(self.path):
if len(os.listdir(self.path)) > 0:
logger.warning("Found files from a previous/borked/crashed installation, cleaning up...")
self.bash(f"rm -rf {self.path}")
self.create_dir('')
else:
self.create_dir('')
def check_for_broken_install(self) -> None:
"""Check for a broken installation and clean up any leftover files."""
if not self.is_installed() and self.path.exists():
logger.warning("Found files from a previous/broken/crashed installation, cleaning up...")
self.remove_dir('')
def update(self, folder: str = 'webui'):
def update(self, folder: str = 'webui') -> None:
"""Update the stack by pulling the latest changes from the repository."""
if self.is_installed():
status = self.status()
if status:
was_running = self.status()
if was_running:
self.stop()
logger.info(f"Updating {self.name}")
@ -78,61 +98,69 @@ class Stack:
self._update()
utils.create_symlinks(symlinks)
if status:
if was_running:
self.start()
else:
logger.warning(f"Could not update {self.name} as {self.name} is not installed")
def _update(self):
def _update(self) -> None:
"""Additional update steps specific to the stack (override as needed)."""
pass
def uninstall(self):
def uninstall(self) -> None:
"""Uninstall the stack by stopping it and removing its files."""
logger.info(f"Uninstalling {self.name}")
if self.status():
self.stop()
self.bash(f"rm -rf {self.path}")
self.remove_pid_file()
def start(self):
def start(self) -> None:
"""Start the stack service."""
if self.status():
logger.warning(f"{self.name} is already running")
return
if self.is_installed():
self._start()
else:
logger.error(f"{self.name} is not installed")
def _start(self):
def _start(self) -> None:
"""Additional start steps specific to the stack (override as needed)."""
pass
def stop(self):
def stop(self) -> None:
"""Stop the stack service by terminating the associated process."""
if self.status():
logger.debug(f"Killing {self.name} with PID: {self.pid}")
psutil.Process(self.pid).kill()
logger.debug(f"Stopping {self.name} with PID: {self.pid}")
try:
proc = psutil.Process(self.pid)
proc.terminate() # Graceful shutdown
proc.wait(timeout=5)
except (psutil.NoSuchProcess, psutil.TimeoutExpired):
logger.warning(f"{self.name} did not terminate gracefully, forcing kill")
psutil.Process(self.pid).kill()
self.remove_pid_file()
else:
logger.warning(f"{self.name} is not running")
self.set_pid(None)
def set_pid(self, pid):
self.pid = pid
if pid is not None:
config.put(f"{self.name}-pid", pid)
else:
config.remove(f"{self.name}-pid")
def restart(self):
def restart(self) -> None:
"""Restart the stack service."""
self.stop()
self.start()
def status(self) -> bool:
if self.pid is None:
return False
"""
Check if the stack service is running.
return psutil.pid_exists(self.pid)
:return: True if the service is running, False otherwise.
"""
return self.pid is not None and psutil.pid_exists(self.pid)
# Python/Bash utils
def create_venv(self):
venv_path = os.path.join(self.path, 'venv')
def create_venv(self) -> None:
"""Create a Python virtual environment for the stack."""
venv_path = self.path / 'venv'
if not self.has_venv():
logger.info(f"Creating venv for {self.name}")
self.bash(f"{PYTHON_EXEC} -m venv {venv_path} --system-site-packages")
@ -141,9 +169,16 @@ class Stack:
logger.debug(f"Venv already exists for {self.name}")
def has_venv(self) -> bool:
return self.dir_exists('venv')
"""
Check if the virtual environment exists.
def pip_install(self, package: str | list, no_deps: bool = False, env=[], args=[]):
:return: True if the virtual environment exists, False otherwise.
"""
return (self.path / 'venv').exists()
def pip_install(self, package: Union[str, List[str]], no_deps: bool = False, env: List[str] = [],
args: List[str] = []) -> None:
"""Install a Python package or list of packages using pip."""
if no_deps:
args.append("--no-deps")
@ -155,115 +190,125 @@ class Stack:
logger.info(f"Installing {package}")
self.pip(f"install -U {package}", env=env, args=args)
def install_requirements(self, filename: str = 'requirements.txt', env=[]):
def install_requirements(self, filename: str = 'requirements.txt', env: List[str] = []) -> None:
"""Install requirements from a given file."""
logger.info(f"Installing requirements for {self.name} ({filename})")
self.pip(f"install -r {filename}", env=env)
def pip(self, cmd: str, env=[], args=[], current_dir: str = None):
def pip(self, cmd: str, env: List[str] = [], args: List[str] = [], current_dir: Optional[Path] = None) -> None:
"""Run pip with a given command."""
self.python(f"-m pip {cmd}", env=env, args=args, current_dir=current_dir)
def python(self, cmd: str, env=[], args=[], current_dir: str = None, daemon: bool = False):
self.bash(f"{' '.join(env)} {self.path}/venv/bin/python {cmd} {' '.join(args)}", current_dir, daemon)
def python(self, cmd: str, env: List[str] = [], args: List[str] = [], current_dir: Optional[Path] = None,
daemon: bool = False) -> None:
"""Run a Python command inside the stack's virtual environment."""
self.bash(f"{' '.join(env)} {self.path / 'venv' / 'bin' / 'python'} {cmd} {' '.join(args)}", current_dir,
daemon)
def bash(self, cmd: str, current_dir: str = None, daemon: bool = False):
cmd = f"cd {self.path if current_dir is None else os.path.join(self.path, current_dir)} && {cmd}"
def bash(self, cmd: str, current_dir: Optional[Path] = None, daemon: bool = False) -> None:
"""Run a bash command, optionally as a daemon."""
full_cmd = f"cd {current_dir or self.path} && {cmd}"
if daemon:
if self.status():
choice = choices.already_running.ask()
if choice is True:
self.stop()
self._start()
return
else:
# TODO: attach to subprocess / redirect logs?
return
else:
logger.debug(f"Running command as daemon: {cmd}")
cmd = f"{cmd} &"
process = subprocess.Popen(cmd, shell=True, preexec_fn=os.setpgrp,
stdout=config.open_file(f"{self.id}-stdout"),
stderr=config.open_file(f"{self.id}-stderr"))
self.set_pid(find_correct_pid(process.pid))
logger.debug(f"Running command as daemon: {full_cmd}")
# process = subprocess.Popen(full_cmd, shell=True, preexec_fn=os.setpgrp,
# stdout=config.open_file(f"{self.id}-stdout"),
# stderr=config.open_file(f"{self.id}-stderr"))
screen_session = screen.create(name=self.id)
screen_session.send(f"'{full_cmd}'")
self.write_pid(screen_session.pid)
return
else:
logger.debug(f"Running command: {cmd}")
logger.debug(f"Running command: {full_cmd}")
if logger.level == logging.DEBUG:
process = subprocess.Popen(cmd, shell=True)
process.wait()
if process.returncode != 0:
raise Exception(f"Failed to run command: {cmd}")
else:
process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = process.communicate()
process = subprocess.Popen(full_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = process.communicate()
if process.returncode != 0:
logger.fatal(f"Failed to run command: {cmd}")
logger.fatal(f"Error: {err.decode('utf-8')}")
logger.fatal(f"Output: {out.decode('utf-8')}")
raise Exception(f"Failed to run command: {cmd}")
if process.returncode != 0:
logger.fatal(f"Failed to run command: {full_cmd}")
logger.fatal(f"Error: {err.decode('utf-8')}")
logger.fatal(f"Output: {out.decode('utf-8')}")
raise Exception(f"Failed to run command: {full_cmd}")
# Git utils
def git_clone(self, url: str, branch: str = None, dest: str = None):
def git_clone(self, url: str, branch: Optional[str] = None, dest: Optional[Path] = None) -> None:
"""Clone a git repository."""
logger.info(f"Cloning {url}")
self.bash(f"git clone {f"-b {branch}" if branch is not None else ''} {url} {'' if dest is None else dest}")
self.bash(f"git clone {f'-b {branch}' if branch else ''} {url} {dest or ''}")
def git_pull(self, repo_folder: str, force: bool = False):
self.bash(f"git reset --hard HEAD {'&& git clean -f -d' if force else ''} && git pull", repo_folder)
def git_pull(self, repo_folder: str, force: bool = False) -> None:
"""Pull changes from a git repository."""
self.bash(f"git reset --hard HEAD {'&& git clean -f -d' if force else ''} && git pull", Path(repo_folder))
# Prebuilt utils
def install_from_prebuilt(self, name):
for prebuilt in utils.get_prebuilts():
if prebuilt['name'].split("-")[0] == name:
self.pip(f"install {prebuilt['browser_download_url']}")
return
# File utils
def create_file(self, name, content):
with open(os.path.join(self.path, name), 'w') as f:
f.write(content)
def create_file(self, name: str, content: str) -> None:
"""Create a file with the given content."""
(self.path / name).write_text(content)
def create_dir(self, name):
if name == '':
logger.info(f"Creating directory for {self.name}")
def create_dir(self, name: str) -> None:
"""Create a directory."""
dir_path = self.path / name
logger.debug(f"Creating directory {dir_path}")
dir_path.mkdir(parents=True, exist_ok=True)
logger.debug(f"Creating directory {name}")
os.makedirs(os.path.join(self.path, name), exist_ok=True)
def remove_file(self, name):
def remove_file(self, name: str) -> None:
"""Remove a file."""
logger.debug(f"Removing file {name}")
os.remove(os.path.join(self.path, name))
def remove_dir(self, name):
logger.debug(f"Removing directory {name}")
shutil.rmtree(os.path.join(self.path, name))
def remove_dir(self, name: str) -> None:
"""Remove a directory."""
logger.debug(f"Removing directory {name or self.path}")
if not name:
shutil.rmtree(self.path)
else:
shutil.rmtree(os.path.join(self.path, name))
def move_file_or_dir(self, src, dest):
def move_file_or_dir(self, src: str, dest: str) -> None:
"""Move a file or directory."""
logger.debug(f"Moving file/dir {src} to {dest}")
os.rename(os.path.join(self.path, src), os.path.join(self.path, dest))
def move_all_files_in_dir(self, src, dest):
def move_all_files_in_dir(self, src: str, dest: str) -> None:
"""Move all files in a directory to another directory"""
logger.debug(f"Moving all files in directory {src} to {dest}")
for file in os.listdir(os.path.join(self.path, src)):
os.rename(os.path.join(self.path, src, file), os.path.join(self.path, dest, file))
def file_exists(self, name):
return os.path.exists(os.path.join(self.path, name))
def file_exists(self, name: str) -> bool:
"""Check if a file exists."""
return (self.path / name).exists()
def dir_exists(self, name):
return os.path.exists(os.path.join(self.path, name))
def dir_exists(self, name: str) -> bool:
"""Check if a directory exists."""
return (self.path / name).exists()
def remove_line_in_file(self, contains: str | list, file: str):
logger.debug(f"Removing lines containing {contains} in {file}")
"""Remove lines containing a specific string from a file."""
target_file = self.path / file
logger.debug(f"Removing lines containing {contains} in {target_file}")
if isinstance(contains, list):
for c in contains:
self.bash(f"sed -i '/{c}/d' {file}")
self.bash(f"sed -i '/{c}/d' {target_file}")
else:
self.bash(f"sed -i '/{contains}/d' {file}")
self.bash(f"sed -i '/{contains}/d' {target_file}")
def replace_line_in_file(self, match: str, replace: str, file: str):
logger.debug(f"Replacing lines containing {match} with {replace} in {file}")
self.bash(f"sed -i 's/{match}/{replace}/g' {file}")
"""Replace lines containing a specific string in a file."""
target_file = self.path / file
logger.debug(f"Replacing lines containing {match} with {replace} in {target_file}")
self.bash(f"sed -i 's/{match}/{replace}/g' {target_file}")

View File

@ -4,6 +4,7 @@ import os
import shutil
import subprocess
from pathlib import Path
from typing import List, Dict, Tuple, Union
from urllib import request, error
from core.stack import Stack
@ -11,7 +12,15 @@ from core.vars import ROCM_VERSION, logger
def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-local",
release_tag: str = f"prebuilt-whl-{ROCM_VERSION}") -> list:
release_tag: str = f"prebuilt-whl-{ROCM_VERSION}") -> List[dict]:
"""
Fetch prebuilt assets from a GitHub release using the GitHub API.
:param repo_owner: GitHub repository owner
:param repo_name: GitHub repository name
:param release_tag: Release tag for fetching assets
:return: List of assets (dictionaries) from the GitHub release
"""
api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}"
try:
@ -21,7 +30,6 @@ def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-l
return []
release_data = json.load(response)
assets = release_data.get('assets', [])
if not assets:
logger.error("No assets found in release data")
@ -31,34 +39,48 @@ def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-l
except error.URLError as e:
logger.error(f"Error fetching release data: {e}")
return []
def check_for_build_essentials():
def check_for_build_essentials() -> None:
"""
Check if build essentials like `build-essential` and `python3.10-dev` are installed.
Raises a warning if they are missing.
"""
logger.debug("Checking for build essentials...")
debian = os.path.exists('/etc/debian_version')
fedora = os.path.exists('/etc/fedora-release')
debian = Path('/etc/debian_version').exists()
fedora = Path('/etc/fedora-release').exists()
if debian:
# TODO: check if these work for debian users
check_gcc = run_command("dpkg -l | grep build-essential &>/dev/null", exit_on_error=False)[2] == 0
check_python = run_command("dpkg -l | grep python3.10-dev &>/dev/null", exit_on_error=False)[2] == 0
if not check_gcc or not check_python:
raise UserWarning(
"The packages build-essential and python3.10-dev are required for this script to run. Please install them. See the README for more information.")
"The packages build-essential and python3.10-dev are required for this script to run. Please install them. See the README for more information."
)
elif fedora:
check_gcc = run_command("rpm -q gcc &>/dev/null", exit_on_error=False)[2] == 0
check_python = run_command("rpm -q python3.10-devel &>/dev/null", exit_on_error=False)[2] == 0
if not check_gcc or not check_python:
raise UserWarning(
"The package python3.10-devel and the Development Tools group are required for this script to run. Please install them. See the README for more information.")
"The package python3.10-devel and the Development Tools group are required for this script to run. Please install them. See the README for more information."
)
else:
logger.warning(
"Unsupported OS detected. Please ensure you have the following packages installed or their equivalent: build-essential, python3.10-dev")
"Unsupported OS detected. Please ensure you have the following packages installed or their equivalent: build-essential, python3.10-dev"
)
def run_command(command: str, exit_on_error: bool = True):
def run_command(command: str, exit_on_error: bool = True) -> Tuple[bytes, bytes, int]:
"""
Run a shell command and return the output, error, and return code.
:param command: The shell command to run
:param exit_on_error: Whether to raise an exception on error
:return: A tuple containing stdout, stderr, and return code
"""
logger.debug(f"Running command: {command}")
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = process.communicate()
@ -73,25 +95,44 @@ def run_command(command: str, exit_on_error: bool = True):
def load_service_from_string(service: str) -> Stack:
"""
Dynamically load a Stack class based on the service string.
:param service: Name of the service to load
:return: An instance of the corresponding Stack class
"""
logger.debug(f"Loading service from string: {service}")
service_name = service.replace("_", " ").title().replace(" ", "")
module = importlib.import_module(f"services.{service}")
met = getattr(module, service_name)
return met()
stack_class = getattr(module, service_name)
return stack_class()
def find_symlink_in_folder(folder: str):
def find_symlink_in_folder(folder: Union[str, Path]) -> Dict[Path, Path]:
"""
Find all symbolic links in the given folder and map them to their resolved paths.
:param folder: The folder to search for symlinks
:return: A dictionary mapping symlink paths to their resolved target paths
"""
folder = Path(folder)
symlinks = {}
for file in Path(folder).rglob("webui/**"):
for file in folder.rglob("webui/**"):
if file.is_symlink():
symlinks[file] = file.resolve()
return symlinks
def create_symlinks(symlinks: dict[Path, Path]):
def create_symlinks(symlinks: Dict[Path, Path]) -> None:
"""
Recreate symlinks from a dictionary mapping target paths to link paths.
:param symlinks: Dictionary of symlinks and their resolved paths
"""
for target, link in symlinks.items():
logger.debug(f"(re)Creating symlink: {link} -> {target}")

View File

@ -16,11 +16,3 @@ python3.10 setup.py bdist_wheel --universal
git clone --recurse-submodules https://github.com/abetlen/llama-cpp-python.git /tmp/llama-cpp-python
cd /tmp/llama-cpp-python
CMAKE_ARGS="-D GGML_HIPBLAS=on -D AMDGPU_TARGETS=${GPU_TARGETS}" FORCE_CMAKE=1 python3.10 -m build --wheel
# ROCM xformers
## Clone repo and install python requirements
pip3 install ninja
git clone --depth 1 https://github.com/facebookresearch/xformers.git /tmp/xformers
cd /tmp/xformers
python3.10 setup.py bdist_wheel --universal

View File

@ -1,3 +1,5 @@
from pathlib import Path
from core.stack import Stack
@ -12,14 +14,14 @@ class StableDiffusionForge(Stack):
def _install(self):
# Install the webui
self.git_clone(url=self.url, dest="webui")
self.git_clone(url=self.url, dest=Path(self.path / "webui"))
self.python("launch.py --skip-torch-cuda-test --exit", current_dir="webui")
self.python("launch.py --skip-torch-cuda-test --exit", current_dir=Path(self.path / "webui"))
# Add NF4 support for Flux
self.install_from_prebuilt("bitsandbytes")
def _start(self):
args = ["--listen", "--enable-insecure-extension-access", "--port", str(self.port)]
self.python(f"launch.py", args=args, current_dir="webui",
self.python(f"launch.py", args=args, current_dir=Path(self.path / "webui"),
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True)