improve stack and utils, use screen for daemon processes

This commit is contained in:
Mathieu Broillet 2024-10-05 14:20:45 +02:00
parent 15569eedca
commit 80e92d3cd3
Signed by: mathieu
GPG Key ID: A08E484FE95074C1
4 changed files with 307 additions and 143 deletions

76
core/screen.py Normal file
View File

@ -0,0 +1,76 @@
# Taken from https://gitlab.com/jvadair/pyscreen
# All credit goes to the original author, jvadair
__all__ = ['Screen', 'create', 'kill', 'exists', 'ls']
import os
import signal
import subprocess
from datetime import datetime
class ScreenNotFound(Exception):
pass
class Screen:
def __init__(self, pid, kill_on_exit=False):
"""
:param pid: The process ID of the screen
Creates a new Screen object
"""
if not exists(pid):
raise ScreenNotFound(f"No screen with pid {pid}")
self.pid = pid
self.kill_on_exit = kill_on_exit
def __del__(self): # Destroy screen process when object deleted
if self.kill_on_exit:
kill(self.pid)
def send(self, command: str, end="\r") -> None:
"""
:param command: The command to be run
:param end: Appended to the end of the command - the default value is a carriage return
"""
os.system(f'screen -S {self.pid} -X stuff {command}{end}')
def create(name, shell=os.environ['SHELL'], logfile=None, title=None) -> Screen:
command = ["screen", "-DmS", name, '-s', shell]
if logfile:
command.append('-Logfile')
command.append(logfile)
if title:
command.append('-t')
command.append(title)
process = subprocess.Popen(command)
while not process.pid: pass
return Screen(process.pid)
def kill(pid):
os.kill(pid, signal.SIGTERM)
def exists(pid: int) -> int:
command = f"screen -S {str(pid)} -Q select .".split()
pop = subprocess.Popen(command, stdout=subprocess.DEVNULL)
pop.wait()
code = pop.returncode
return False if code else True
def ls() -> dict:
out = subprocess.check_output(["screen", "-ls"]).decode()
out = out.replace('\t', '')
out = out.split('\n')
out = out[1:len(out) - 2]
out = [i.replace(")", "").split("(") for i in out]
final = {}
for item in out:
process = item[0].split('.')
pid = process[0]
final[pid] = {'name': process[1]} # final: {pid:{...}, ... }
final[pid]['time'] = datetime.strptime(item[1], '%m/%d/%Y %X %p')
return final

View File

@ -1,75 +1,95 @@
import logging
import os import os
import shutil import shutil
import subprocess import subprocess
import time from pathlib import Path
from typing import Union, List, Optional
import psutil import psutil
from core import utils, config from core import utils, screen
from core.vars import logger, PYTHON_EXEC from core.vars import logger, PYTHON_EXEC
from ui import choices from ui import choices
def find_correct_pid(parent_pid: int) -> int:
processes: list[psutil.Process] = [psutil.Process(parent_pid)]
create_time = processes[0].create_time()
time.sleep(0.5) # Wait for child processes to spawn
for i in range(1, 10):
if psutil.pid_exists(processes[0].pid + i):
child_process = psutil.Process(processes[0].pid + i)
if child_process.create_time() - create_time < 1:
processes.append(psutil.Process(processes[0].pid + i))
else:
time.sleep(0.5 / i)
return processes[-1].pid
class Stack: class Stack:
def __init__(self, name: str, id: str, port: int, url: str): def __init__(self, name: str, id: str, port: int, url: str):
"""
Initialize the Stack instance.
:param name: The name of the stack
:param id: A unique identifier for the stack
:param port: The port the stack service uses
:param url: The URL associated with the stack
"""
self.name = name self.name = name
self.id = id self.id = id
self.path = os.path.join(os.path.expanduser("~"), ".ai-suite-rocm", id) self.path = Path.home() / ".ai-suite-rocm" / id
self.url = url self.url = url
self.port = port self.port = port
self.pid_file = self.path / f".pid"
self.pid = self.read_pid()
self.pid = config.get(f"{self.name}-pid") def read_pid(self) -> Optional[int]:
"""
Read the PID from the PID file, if it exists.
def install(self): :return: The PID as an integer, or None if the file does not exist or an error occurs.
"""
if self.pid_file.exists():
return int(self.pid_file.read_text())
else:
return None
def write_pid(self, pid: int) -> None:
"""
Write the PID to the PID file.
:param pid: Process ID to write
"""
with self.pid_file.open('w') as f:
f.write(str(pid))
def remove_pid_file(self) -> None:
"""Remove the PID file if it exists."""
if self.pid_file.exists():
self.pid_file.unlink()
def install(self) -> None:
"""Install the stack, creating the virtual environment and performing initial setup."""
if self.is_installed(): if self.is_installed():
self.update() self.update()
else: else:
self.check_for_broken_install() self.check_for_broken_install()
self.create_dir('')
self.create_venv() self.create_venv()
self._install() self._install()
self.create_file('.installed', 'true') self.create_file('.installed', 'true')
logger.info(f"Installed {self.name}") logger.info(f"Installed {self.name}")
def _install(self): def _install(self) -> None:
"""Additional installation steps specific to the stack (override as needed)."""
pass pass
def is_installed(self): def is_installed(self) -> bool:
"""
Check if the stack is installed by verifying the existence of the '.installed' file.
:return: True if the stack is installed, False otherwise.
"""
return self.file_exists('.installed') return self.file_exists('.installed')
def check_for_broken_install(self): def check_for_broken_install(self) -> None:
if not self.is_installed(): """Check for a broken installation and clean up any leftover files."""
if os.path.exists(self.path): if not self.is_installed() and self.path.exists():
if len(os.listdir(self.path)) > 0: logger.warning("Found files from a previous/broken/crashed installation, cleaning up...")
logger.warning("Found files from a previous/borked/crashed installation, cleaning up...") self.remove_dir('')
self.bash(f"rm -rf {self.path}")
self.create_dir('')
else:
self.create_dir('')
def update(self, folder: str = 'webui'): def update(self, folder: str = 'webui') -> None:
"""Update the stack by pulling the latest changes from the repository."""
if self.is_installed(): if self.is_installed():
status = self.status() was_running = self.status()
if status: if was_running:
self.stop() self.stop()
logger.info(f"Updating {self.name}") logger.info(f"Updating {self.name}")
@ -78,61 +98,69 @@ class Stack:
self._update() self._update()
utils.create_symlinks(symlinks) utils.create_symlinks(symlinks)
if status: if was_running:
self.start() self.start()
else: else:
logger.warning(f"Could not update {self.name} as {self.name} is not installed") logger.warning(f"Could not update {self.name} as {self.name} is not installed")
def _update(self): def _update(self) -> None:
"""Additional update steps specific to the stack (override as needed)."""
pass pass
def uninstall(self): def uninstall(self) -> None:
"""Uninstall the stack by stopping it and removing its files."""
logger.info(f"Uninstalling {self.name}") logger.info(f"Uninstalling {self.name}")
if self.status(): if self.status():
self.stop() self.stop()
self.bash(f"rm -rf {self.path}") self.bash(f"rm -rf {self.path}")
self.remove_pid_file()
def start(self): def start(self) -> None:
"""Start the stack service."""
if self.status(): if self.status():
logger.warning(f"{self.name} is already running") logger.warning(f"{self.name} is already running")
return
if self.is_installed(): if self.is_installed():
self._start() self._start()
else: else:
logger.error(f"{self.name} is not installed") logger.error(f"{self.name} is not installed")
def _start(self): def _start(self) -> None:
"""Additional start steps specific to the stack (override as needed)."""
pass pass
def stop(self): def stop(self) -> None:
"""Stop the stack service by terminating the associated process."""
if self.status(): if self.status():
logger.debug(f"Killing {self.name} with PID: {self.pid}") logger.debug(f"Stopping {self.name} with PID: {self.pid}")
try:
proc = psutil.Process(self.pid)
proc.terminate() # Graceful shutdown
proc.wait(timeout=5)
except (psutil.NoSuchProcess, psutil.TimeoutExpired):
logger.warning(f"{self.name} did not terminate gracefully, forcing kill")
psutil.Process(self.pid).kill() psutil.Process(self.pid).kill()
self.remove_pid_file()
else: else:
logger.warning(f"{self.name} is not running") logger.warning(f"{self.name} is not running")
self.set_pid(None) def restart(self) -> None:
"""Restart the stack service."""
def set_pid(self, pid):
self.pid = pid
if pid is not None:
config.put(f"{self.name}-pid", pid)
else:
config.remove(f"{self.name}-pid")
def restart(self):
self.stop() self.stop()
self.start() self.start()
def status(self) -> bool: def status(self) -> bool:
if self.pid is None: """
return False Check if the stack service is running.
return psutil.pid_exists(self.pid) :return: True if the service is running, False otherwise.
"""
return self.pid is not None and psutil.pid_exists(self.pid)
# Python/Bash utils def create_venv(self) -> None:
def create_venv(self): """Create a Python virtual environment for the stack."""
venv_path = os.path.join(self.path, 'venv') venv_path = self.path / 'venv'
if not self.has_venv(): if not self.has_venv():
logger.info(f"Creating venv for {self.name}") logger.info(f"Creating venv for {self.name}")
self.bash(f"{PYTHON_EXEC} -m venv {venv_path} --system-site-packages") self.bash(f"{PYTHON_EXEC} -m venv {venv_path} --system-site-packages")
@ -141,9 +169,16 @@ class Stack:
logger.debug(f"Venv already exists for {self.name}") logger.debug(f"Venv already exists for {self.name}")
def has_venv(self) -> bool: def has_venv(self) -> bool:
return self.dir_exists('venv') """
Check if the virtual environment exists.
def pip_install(self, package: str | list, no_deps: bool = False, env=[], args=[]): :return: True if the virtual environment exists, False otherwise.
"""
return (self.path / 'venv').exists()
def pip_install(self, package: Union[str, List[str]], no_deps: bool = False, env: List[str] = [],
args: List[str] = []) -> None:
"""Install a Python package or list of packages using pip."""
if no_deps: if no_deps:
args.append("--no-deps") args.append("--no-deps")
@ -155,115 +190,125 @@ class Stack:
logger.info(f"Installing {package}") logger.info(f"Installing {package}")
self.pip(f"install -U {package}", env=env, args=args) self.pip(f"install -U {package}", env=env, args=args)
def install_requirements(self, filename: str = 'requirements.txt', env=[]): def install_requirements(self, filename: str = 'requirements.txt', env: List[str] = []) -> None:
"""Install requirements from a given file."""
logger.info(f"Installing requirements for {self.name} ({filename})") logger.info(f"Installing requirements for {self.name} ({filename})")
self.pip(f"install -r {filename}", env=env) self.pip(f"install -r {filename}", env=env)
def pip(self, cmd: str, env=[], args=[], current_dir: str = None): def pip(self, cmd: str, env: List[str] = [], args: List[str] = [], current_dir: Optional[Path] = None) -> None:
"""Run pip with a given command."""
self.python(f"-m pip {cmd}", env=env, args=args, current_dir=current_dir) self.python(f"-m pip {cmd}", env=env, args=args, current_dir=current_dir)
def python(self, cmd: str, env=[], args=[], current_dir: str = None, daemon: bool = False): def python(self, cmd: str, env: List[str] = [], args: List[str] = [], current_dir: Optional[Path] = None,
self.bash(f"{' '.join(env)} {self.path}/venv/bin/python {cmd} {' '.join(args)}", current_dir, daemon) daemon: bool = False) -> None:
"""Run a Python command inside the stack's virtual environment."""
self.bash(f"{' '.join(env)} {self.path / 'venv' / 'bin' / 'python'} {cmd} {' '.join(args)}", current_dir,
daemon)
def bash(self, cmd: str, current_dir: str = None, daemon: bool = False): def bash(self, cmd: str, current_dir: Optional[Path] = None, daemon: bool = False) -> None:
cmd = f"cd {self.path if current_dir is None else os.path.join(self.path, current_dir)} && {cmd}" """Run a bash command, optionally as a daemon."""
full_cmd = f"cd {current_dir or self.path} && {cmd}"
if daemon: if daemon:
if self.status(): if self.status():
choice = choices.already_running.ask() choice = choices.already_running.ask()
if choice is True: if choice is True:
self.stop() self.stop()
self._start() self._start()
return return
else: else:
# TODO: attach to subprocess / redirect logs?
return return
else: else:
logger.debug(f"Running command as daemon: {cmd}") logger.debug(f"Running command as daemon: {full_cmd}")
cmd = f"{cmd} &"
process = subprocess.Popen(cmd, shell=True, preexec_fn=os.setpgrp,
stdout=config.open_file(f"{self.id}-stdout"),
stderr=config.open_file(f"{self.id}-stderr"))
self.set_pid(find_correct_pid(process.pid))
return
else:
logger.debug(f"Running command: {cmd}")
if logger.level == logging.DEBUG: # process = subprocess.Popen(full_cmd, shell=True, preexec_fn=os.setpgrp,
process = subprocess.Popen(cmd, shell=True) # stdout=config.open_file(f"{self.id}-stdout"),
process.wait() # stderr=config.open_file(f"{self.id}-stderr"))
if process.returncode != 0: screen_session = screen.create(name=self.id)
raise Exception(f"Failed to run command: {cmd}") screen_session.send(f"'{full_cmd}'")
self.write_pid(screen_session.pid)
return
else: else:
process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) logger.debug(f"Running command: {full_cmd}")
process = subprocess.Popen(full_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = process.communicate() out, err = process.communicate()
if process.returncode != 0: if process.returncode != 0:
logger.fatal(f"Failed to run command: {cmd}") logger.fatal(f"Failed to run command: {full_cmd}")
logger.fatal(f"Error: {err.decode('utf-8')}") logger.fatal(f"Error: {err.decode('utf-8')}")
logger.fatal(f"Output: {out.decode('utf-8')}") logger.fatal(f"Output: {out.decode('utf-8')}")
raise Exception(f"Failed to run command: {cmd}") raise Exception(f"Failed to run command: {full_cmd}")
# Git utils def git_clone(self, url: str, branch: Optional[str] = None, dest: Optional[Path] = None) -> None:
def git_clone(self, url: str, branch: str = None, dest: str = None): """Clone a git repository."""
logger.info(f"Cloning {url}") logger.info(f"Cloning {url}")
self.bash(f"git clone {f"-b {branch}" if branch is not None else ''} {url} {'' if dest is None else dest}") self.bash(f"git clone {f'-b {branch}' if branch else ''} {url} {dest or ''}")
def git_pull(self, repo_folder: str, force: bool = False): def git_pull(self, repo_folder: str, force: bool = False) -> None:
self.bash(f"git reset --hard HEAD {'&& git clean -f -d' if force else ''} && git pull", repo_folder) """Pull changes from a git repository."""
self.bash(f"git reset --hard HEAD {'&& git clean -f -d' if force else ''} && git pull", Path(repo_folder))
# Prebuilt utils
def install_from_prebuilt(self, name): def install_from_prebuilt(self, name):
for prebuilt in utils.get_prebuilts(): for prebuilt in utils.get_prebuilts():
if prebuilt['name'].split("-")[0] == name: if prebuilt['name'].split("-")[0] == name:
self.pip(f"install {prebuilt['browser_download_url']}") self.pip(f"install {prebuilt['browser_download_url']}")
return return
# File utils def create_file(self, name: str, content: str) -> None:
def create_file(self, name, content): """Create a file with the given content."""
with open(os.path.join(self.path, name), 'w') as f: (self.path / name).write_text(content)
f.write(content)
def create_dir(self, name): def create_dir(self, name: str) -> None:
if name == '': """Create a directory."""
logger.info(f"Creating directory for {self.name}") dir_path = self.path / name
logger.debug(f"Creating directory {dir_path}")
dir_path.mkdir(parents=True, exist_ok=True)
logger.debug(f"Creating directory {name}") def remove_file(self, name: str) -> None:
os.makedirs(os.path.join(self.path, name), exist_ok=True) """Remove a file."""
def remove_file(self, name):
logger.debug(f"Removing file {name}") logger.debug(f"Removing file {name}")
os.remove(os.path.join(self.path, name)) os.remove(os.path.join(self.path, name))
def remove_dir(self, name): def remove_dir(self, name: str) -> None:
logger.debug(f"Removing directory {name}") """Remove a directory."""
logger.debug(f"Removing directory {name or self.path}")
if not name:
shutil.rmtree(self.path)
else:
shutil.rmtree(os.path.join(self.path, name)) shutil.rmtree(os.path.join(self.path, name))
def move_file_or_dir(self, src, dest): def move_file_or_dir(self, src: str, dest: str) -> None:
"""Move a file or directory."""
logger.debug(f"Moving file/dir {src} to {dest}") logger.debug(f"Moving file/dir {src} to {dest}")
os.rename(os.path.join(self.path, src), os.path.join(self.path, dest)) os.rename(os.path.join(self.path, src), os.path.join(self.path, dest))
def move_all_files_in_dir(self, src, dest): def move_all_files_in_dir(self, src: str, dest: str) -> None:
"""Move all files in a directory to another directory"""
logger.debug(f"Moving all files in directory {src} to {dest}") logger.debug(f"Moving all files in directory {src} to {dest}")
for file in os.listdir(os.path.join(self.path, src)): for file in os.listdir(os.path.join(self.path, src)):
os.rename(os.path.join(self.path, src, file), os.path.join(self.path, dest, file)) os.rename(os.path.join(self.path, src, file), os.path.join(self.path, dest, file))
def file_exists(self, name): def file_exists(self, name: str) -> bool:
return os.path.exists(os.path.join(self.path, name)) """Check if a file exists."""
return (self.path / name).exists()
def dir_exists(self, name): def dir_exists(self, name: str) -> bool:
return os.path.exists(os.path.join(self.path, name)) """Check if a directory exists."""
return (self.path / name).exists()
def remove_line_in_file(self, contains: str | list, file: str): def remove_line_in_file(self, contains: str | list, file: str):
logger.debug(f"Removing lines containing {contains} in {file}") """Remove lines containing a specific string from a file."""
target_file = self.path / file
logger.debug(f"Removing lines containing {contains} in {target_file}")
if isinstance(contains, list): if isinstance(contains, list):
for c in contains: for c in contains:
self.bash(f"sed -i '/{c}/d' {file}") self.bash(f"sed -i '/{c}/d' {target_file}")
else: else:
self.bash(f"sed -i '/{contains}/d' {file}") self.bash(f"sed -i '/{contains}/d' {target_file}")
def replace_line_in_file(self, match: str, replace: str, file: str): def replace_line_in_file(self, match: str, replace: str, file: str):
logger.debug(f"Replacing lines containing {match} with {replace} in {file}") """Replace lines containing a specific string in a file."""
self.bash(f"sed -i 's/{match}/{replace}/g' {file}") target_file = self.path / file
logger.debug(f"Replacing lines containing {match} with {replace} in {target_file}")
self.bash(f"sed -i 's/{match}/{replace}/g' {target_file}")

View File

@ -4,6 +4,7 @@ import os
import shutil import shutil
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from typing import List, Dict, Tuple, Union
from urllib import request, error from urllib import request, error
from core.stack import Stack from core.stack import Stack
@ -11,7 +12,15 @@ from core.vars import ROCM_VERSION, logger
def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-local", def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-local",
release_tag: str = f"prebuilt-whl-{ROCM_VERSION}") -> list: release_tag: str = f"prebuilt-whl-{ROCM_VERSION}") -> List[dict]:
"""
Fetch prebuilt assets from a GitHub release using the GitHub API.
:param repo_owner: GitHub repository owner
:param repo_name: GitHub repository name
:param release_tag: Release tag for fetching assets
:return: List of assets (dictionaries) from the GitHub release
"""
api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}" api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}"
try: try:
@ -21,7 +30,6 @@ def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-l
return [] return []
release_data = json.load(response) release_data = json.load(response)
assets = release_data.get('assets', []) assets = release_data.get('assets', [])
if not assets: if not assets:
logger.error("No assets found in release data") logger.error("No assets found in release data")
@ -31,34 +39,48 @@ def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-l
except error.URLError as e: except error.URLError as e:
logger.error(f"Error fetching release data: {e}") logger.error(f"Error fetching release data: {e}")
return []
def check_for_build_essentials(): def check_for_build_essentials() -> None:
"""
Check if build essentials like `build-essential` and `python3.10-dev` are installed.
Raises a warning if they are missing.
"""
logger.debug("Checking for build essentials...") logger.debug("Checking for build essentials...")
debian = os.path.exists('/etc/debian_version') debian = Path('/etc/debian_version').exists()
fedora = os.path.exists('/etc/fedora-release') fedora = Path('/etc/fedora-release').exists()
if debian: if debian:
# TODO: check if these work for debian users
check_gcc = run_command("dpkg -l | grep build-essential &>/dev/null", exit_on_error=False)[2] == 0 check_gcc = run_command("dpkg -l | grep build-essential &>/dev/null", exit_on_error=False)[2] == 0
check_python = run_command("dpkg -l | grep python3.10-dev &>/dev/null", exit_on_error=False)[2] == 0 check_python = run_command("dpkg -l | grep python3.10-dev &>/dev/null", exit_on_error=False)[2] == 0
if not check_gcc or not check_python: if not check_gcc or not check_python:
raise UserWarning( raise UserWarning(
"The packages build-essential and python3.10-dev are required for this script to run. Please install them. See the README for more information.") "The packages build-essential and python3.10-dev are required for this script to run. Please install them. See the README for more information."
)
elif fedora: elif fedora:
check_gcc = run_command("rpm -q gcc &>/dev/null", exit_on_error=False)[2] == 0 check_gcc = run_command("rpm -q gcc &>/dev/null", exit_on_error=False)[2] == 0
check_python = run_command("rpm -q python3.10-devel &>/dev/null", exit_on_error=False)[2] == 0 check_python = run_command("rpm -q python3.10-devel &>/dev/null", exit_on_error=False)[2] == 0
if not check_gcc or not check_python: if not check_gcc or not check_python:
raise UserWarning( raise UserWarning(
"The package python3.10-devel and the Development Tools group are required for this script to run. Please install them. See the README for more information.") "The package python3.10-devel and the Development Tools group are required for this script to run. Please install them. See the README for more information."
)
else: else:
logger.warning( logger.warning(
"Unsupported OS detected. Please ensure you have the following packages installed or their equivalent: build-essential, python3.10-dev") "Unsupported OS detected. Please ensure you have the following packages installed or their equivalent: build-essential, python3.10-dev"
)
def run_command(command: str, exit_on_error: bool = True): def run_command(command: str, exit_on_error: bool = True) -> Tuple[bytes, bytes, int]:
"""
Run a shell command and return the output, error, and return code.
:param command: The shell command to run
:param exit_on_error: Whether to raise an exception on error
:return: A tuple containing stdout, stderr, and return code
"""
logger.debug(f"Running command: {command}") logger.debug(f"Running command: {command}")
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = process.communicate() out, err = process.communicate()
@ -73,25 +95,44 @@ def run_command(command: str, exit_on_error: bool = True):
def load_service_from_string(service: str) -> Stack: def load_service_from_string(service: str) -> Stack:
"""
Dynamically load a Stack class based on the service string.
:param service: Name of the service to load
:return: An instance of the corresponding Stack class
"""
logger.debug(f"Loading service from string: {service}") logger.debug(f"Loading service from string: {service}")
service_name = service.replace("_", " ").title().replace(" ", "") service_name = service.replace("_", " ").title().replace(" ", "")
module = importlib.import_module(f"services.{service}") module = importlib.import_module(f"services.{service}")
met = getattr(module, service_name) stack_class = getattr(module, service_name)
return met()
return stack_class()
def find_symlink_in_folder(folder: str): def find_symlink_in_folder(folder: Union[str, Path]) -> Dict[Path, Path]:
"""
Find all symbolic links in the given folder and map them to their resolved paths.
:param folder: The folder to search for symlinks
:return: A dictionary mapping symlink paths to their resolved target paths
"""
folder = Path(folder)
symlinks = {} symlinks = {}
for file in Path(folder).rglob("webui/**"):
for file in folder.rglob("webui/**"):
if file.is_symlink(): if file.is_symlink():
symlinks[file] = file.resolve() symlinks[file] = file.resolve()
return symlinks return symlinks
def create_symlinks(symlinks: dict[Path, Path]): def create_symlinks(symlinks: Dict[Path, Path]) -> None:
"""
Recreate symlinks from a dictionary mapping target paths to link paths.
:param symlinks: Dictionary of symlinks and their resolved paths
"""
for target, link in symlinks.items(): for target, link in symlinks.items():
logger.debug(f"(re)Creating symlink: {link} -> {target}") logger.debug(f"(re)Creating symlink: {link} -> {target}")

View File

@ -1,3 +1,5 @@
from pathlib import Path
from core.stack import Stack from core.stack import Stack
@ -12,14 +14,14 @@ class StableDiffusionForge(Stack):
def _install(self): def _install(self):
# Install the webui # Install the webui
self.git_clone(url=self.url, dest="webui") self.git_clone(url=self.url, dest=Path(self.path / "webui"))
self.python("launch.py --skip-torch-cuda-test --exit", current_dir="webui") self.python("launch.py --skip-torch-cuda-test --exit", current_dir=Path(self.path / "webui"))
# Add NF4 support for Flux # Add NF4 support for Flux
self.install_from_prebuilt("bitsandbytes") self.install_from_prebuilt("bitsandbytes")
def _start(self): def _start(self):
args = ["--listen", "--enable-insecure-extension-access", "--port", str(self.port)] args = ["--listen", "--enable-insecure-extension-access", "--port", str(self.port)]
self.python(f"launch.py", args=args, current_dir="webui", self.python(f"launch.py", args=args, current_dir=Path(self.path / "webui"),
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True) env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True)