diff --git a/core/config.py b/core/config.py new file mode 100644 index 0000000..c56e8e8 --- /dev/null +++ b/core/config.py @@ -0,0 +1,56 @@ +import json +import os + +from core.vars import logger + +data = {} +file = os.path.join(os.path.expanduser("~"), ".config", "ai-suite-rocm", "config.json") + + +def create(): + if not os.path.exists(file): + os.makedirs(os.path.dirname(file), exist_ok=True) + with open(file, "w") as f: + f.write("{}") + + logger.info(f"Created config file at {file}") + + +def read(): + global data + with open(file, "r") as f: + data = json.load(f) + + +def write(): + global data + with open(file, "w") as f: + json.dump(data, f) + + +def get(key: str): + global data + return data.get(key) + + +def set(key: str, value): + global data + data[key] = value + write() + + +def has(key: str): + global data + return key in data + + +def remove(key: str): + global data + data.pop(key) + write() + + +def clear(): + global data + data = {} + write() diff --git a/services/services.py b/core/stack.py similarity index 91% rename from services/services.py rename to core/stack.py index 1a864ee..3d9cb89 100644 --- a/services/services.py +++ b/core/stack.py @@ -1,25 +1,25 @@ import logging import os +import shutil import subprocess import psutil -import main -import utils -from main import PYTHON_EXEC, logger, get_config +from core import utils, config +from core.vars import logger, PYTHON_EXEC class Stack: - def __init__(self, name: str, path: str, port: int, url: str): + def __init__(self, name: str, id: str, port: int, url: str): self.name = name - self.path = os.path.join(os.path.expanduser("~"), ".ai-suite-rocm", path) + self.id = id + self.path = os.path.join(os.path.expanduser("~"), ".ai-suite-rocm", id) self.url = url self.port = port self.process = None - def install(self): self.create_file('.installed', 'true') logger.info(f"Installed {self.name}") @@ -110,18 +110,20 @@ class Stack: if daemon: # Check if previous run process is saved - if get_config().has(f"{self.name}-pid"): + if config.has(f"{self.name}-pid"): # Check if PID still running - if psutil.pid_exists(main.config.get(f"{self.name}-pid")): + if psutil.pid_exists(config.get(f"{self.name}-pid")): choice = input(f"{self.name} is already running, do you want to restart it? (y/n): ") if choice.lower() == 'y': - pid = main.config.get(f"{self.name}-pid") + pid = config.get(f"{self.name}-pid") logger.debug(f"Killing previous daemon with PID: {pid}") psutil.Process(pid).kill() else: # TODO: attach to subprocess? + logger.info("Continuing without restarting...") + return else: logger.warning( @@ -131,8 +133,8 @@ class Stack: logger.debug(f"Starting {self.name} as daemon with command: {cmd}") cmd = f"{cmd} &" - process = subprocess.Popen(cmd, shell=True) - get_config().set(f"{self.name}-pid", process.pid + 1) + process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + config.set(f"{self.name}-pid", process.pid + 1) return else: logger.debug(f"Running command: {cmd}") @@ -185,7 +187,7 @@ class Stack: def remove_dir(self, name): logger.debug(f"Removing directory {name}") - os.rmdir(os.path.join(self.path, name)) + shutil.rmtree(os.path.join(self.path, name)) def move_file_or_dir(self, src, dest): logger.debug(f"Moving file/dir {src} to {dest}") diff --git a/core/utils.py b/core/utils.py new file mode 100644 index 0000000..daf939a --- /dev/null +++ b/core/utils.py @@ -0,0 +1,83 @@ +import importlib +import json +import os +import subprocess +import urllib + +from core.stack import Stack +from core.vars import ROCM_VERSION, logger + + +def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-local", + release_tag: str = f"prebuilt-whl-{ROCM_VERSION}") -> list: + api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}" + + try: + with urllib.request.urlopen(api_url) as response: + if response.status != 200: + logger.error(f"Failed to fetch data: HTTP Status {response.status}") + return [] + + release_data = json.load(response) + + assets = release_data.get('assets', []) + if not assets: + logger.error("No assets found in release data") + return [] + + return assets + + except urllib.error.URLError as e: + logger.error(f"Error fetching release data: {e}") + + +def check_for_build_essentials(): + logger.debug("Checking for build essentials...") + debian = os.path.exists('/etc/debian_version') + fedora = os.path.exists('/etc/fedora-release') + + if debian: + # TODO: check if these work for debian users + check_gcc = run_command("dpkg -l | grep build-essential &>/dev/null", exit_on_error=False)[2] == 0 + check_python = run_command("dpkg -l | grep python3.10-dev &>/dev/null", exit_on_error=False)[2] == 0 + + if not check_gcc or not check_python: + raise UserWarning( + "The packages build-essential and python3.10-dev are required for this script to run. Please install them. See the README for more information.") + elif fedora: + check_gcc = run_command("rpm -q gcc &>/dev/null", exit_on_error=False)[2] == 0 + check_python = run_command("rpm -q python3.10-devel &>/dev/null", exit_on_error=False)[2] == 0 + + if not check_gcc or not check_python: + raise UserWarning( + "The package python3.10-devel and the Development Tools group are required for this script to run. Please install them. See the README for more information.") + else: + logger.warning( + "Unsupported OS detected. Please ensure you have the following packages installed or their equivalent: build-essential, python3.10-dev") + + +def run_command(command: str, exit_on_error: bool = True): + logger.debug(f"Running command: {command}") + process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = process.communicate() + + if process.returncode != 0: + logger.fatal(f"Failed to run command: {command}") + raise Exception(f"Failed to run command: {command}") + + return out, err, process.returncode + + +def load_service_from_string(service: str) -> Stack: + logger.debug(f"Loading service from string: {service}") + + try: + service_name = service.replace("_", " ").title().replace(" ", "") + + module = importlib.import_module(f"services.{service}") + met = getattr(module, service_name) + return met() + + except ModuleNotFoundError as e: + logger.error(f"Failed to load service: {e}") + return None diff --git a/core/vars.py b/core/vars.py new file mode 100644 index 0000000..01fd4d3 --- /dev/null +++ b/core/vars.py @@ -0,0 +1,19 @@ +import logging +import os + +PYTHON_EXEC = 'python3.10' +PATH = os.path.dirname(os.path.abspath(__file__)) +ROCM_VERSION = "6.1.2" + +logger = logging.getLogger('ai-suite-rocm') + +services = [ + 'background_removal_dis', + 'comfy_ui', + 'stable_diffusion_webui', + 'stable_diffusion_forge', + 'text_generation_webui', + 'xtts_webui' +] + +loaded_services = {} diff --git a/main.py b/main.py index d3ae904..94e55a7 100644 --- a/main.py +++ b/main.py @@ -1,65 +1,14 @@ -import json import logging -import os -import subprocess import sys -import ui - -PYTHON_EXEC = 'python3.10' -PATH = os.path.dirname(os.path.abspath(__file__)) -ROCM_VERSION = "6.1.2" - -logger = logging.getLogger('ai-suite-rocm') -config = None - - -class Config: - data = {} - - def __init__(self): - self.file = os.path.join(os.path.expanduser("~"), ".config", "ai-suite-rocm", "config.json") - - self.create() - self.read() - - def create(self): - if not os.path.exists(self.file): - os.makedirs(os.path.dirname(self.file), exist_ok=True) - with open(self.file, "w") as f: - f.write("{}") - - logger.info(f"Created config file at {self.file}") - - def read(self): - with open(self.file, "r") as f: - self.data = json.load(f) - - def write(self): - with open(self.file, "w") as f: - json.dump(self.data, f) - - def get(self, key: str): - return self.data.get(key) - - def set(self, key: str, value): - self.data[key] = value - self.write() - - def has(self, key: str): - return key in self.data - - def remove(self, key: str): - self.data.pop(key) - self.write() - - def clear(self): - self.data = {} - self.write() +from core import config +from core.utils import check_for_build_essentials, load_service_from_string +from core.vars import logger, services, loaded_services +from ui.choices import update_choices +from ui.interface import run_interactive_cmd_ui def setup_logger(level: logger.level = logging.INFO): - global logger if not logger.hasHandlers(): logger.setLevel(level) handler = logging.StreamHandler(sys.stdout) @@ -68,78 +17,24 @@ def setup_logger(level: logger.level = logging.INFO): def setup_config(): - global config - config = Config() + config.create() + config.read() -def get_config(): - global config - return config +def load_services(): + for service in services: + loaded_services[service] = load_service_from_string(service) -def run_command(command: str, exit_on_error: bool = True): - logger.debug(f"Running command: {command}") - process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - out, err = process.communicate() - - if process.returncode != 0: - logger.fatal(f"Failed to run command: {command}") - raise Exception(f"Failed to run command: {command}") - - return out, err, process.returncode - - -def check_for_build_essentials(): - logger.debug("Checking for build essentials...") - debian = os.path.exists('/etc/debian_version') - fedora = os.path.exists('/etc/fedora-release') - - if debian: - # TODO: check if these work for debian users - check_gcc = run_command("dpkg -l | grep build-essential &>/dev/null", exit_on_error=False)[2] == 0 - check_python = run_command("dpkg -l | grep python3.10-dev &>/dev/null", exit_on_error=False)[2] == 0 - - if not check_gcc or not check_python: - raise UserWarning( - "The packages build-essential and python3.10-dev are required for this script to run. Please install them. See the README for more information.") - elif fedora: - check_gcc = run_command("rpm -q gcc &>/dev/null", exit_on_error=False)[2] == 0 - check_python = run_command("rpm -q python3.10-devel &>/dev/null", exit_on_error=False)[2] == 0 - - if not check_gcc or not check_python: - raise UserWarning( - "The package python3.10-devel and the Development Tools group are required for this script to run. Please install them. See the README for more information.") - else: - logger.warning( - "Unsupported OS detected. Please ensure you have the following packages installed or their equivalent: build-essential, python3.10-dev") - - -def run_interactive_cmd_ui(): - while True: - choice = ui.start.ask() - - match choice: - case "Start services": - services = ui.start_services.ask() - for service in services: - logger.info(f"Starting service: {service}") - pass - pass - case "Stop services": - pass - case "Install/update services": - pass - case "Uninstall services": - pass - case "Exit": - print("Exiting...") - exit(0) - if __name__ == '__main__': setup_logger(logging.DEBUG) + logger.info("Starting AI Suite for ROCM") + setup_config() - logger.info("Starting AI Suite for ROCM") check_for_build_essentials() + load_services() + + update_choices() run_interactive_cmd_ui() diff --git a/services/__init__.py b/services/__init__.py deleted file mode 100644 index 01e3e8a..0000000 --- a/services/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from services.background_removal_dis import BGRemovalDIS -from services.comfyui import ComfyUI -from services.services import Stack -from services.stablediffusion_forge import StableDiffusionForge -from services.stablediffusion_webui import StableDiffusionWebUI -from services.textgen import TextGeneration -from services.xtts import XTTS diff --git a/services/background_removal_dis.py b/services/background_removal_dis.py index 21e6666..c45636e 100644 --- a/services/background_removal_dis.py +++ b/services/background_removal_dis.py @@ -1,11 +1,11 @@ -from services import Stack +from core.stack import Stack -class BGRemovalDIS(Stack): +class BackgroundRemovalDis(Stack): def __init__(self): super().__init__( - 'BGRemovalDIS', - 'bg-remove-dis-rocm', + 'Background Removal (DIS)', + 'background_removal_dis', 5005, 'https://huggingface.co/spaces/ECCV2022/dis-background-removal' ) @@ -31,4 +31,4 @@ class BGRemovalDIS(Stack): def _launch(self): args = ["--port", str(self.port)] self.python(f"app.py {' '.join(args)}", current_dir="webui", - env=["TORCH_BLAS_PREFER_HIPBLASLT=0"]) + env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True) diff --git a/services/comfyui.py b/services/comfy_ui.py similarity index 86% rename from services/comfyui.py rename to services/comfy_ui.py index a8b694c..18fb988 100644 --- a/services/comfyui.py +++ b/services/comfy_ui.py @@ -1,11 +1,11 @@ -from services import Stack +from core.stack import Stack -class ComfyUI(Stack): +class ComfyUi(Stack): def __init__(self): super().__init__( 'ComfyUI', - 'comfyui-rocm', + 'comfy_ui', 5004, 'https://github.com/comfyanonymous/ComfyUI.git' ) @@ -31,4 +31,4 @@ class ComfyUI(Stack): def _launch(self): args = ["--port", str(self.port)] self.python(f"main.py {' '.join(args)}", current_dir="webui", - env=["TORCH_BLAS_PREFER_HIPBLASLT=0"]) + env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True) diff --git a/services/stablediffusion_forge.py b/services/stable_diffusion_forge.py similarity index 83% rename from services/stablediffusion_forge.py rename to services/stable_diffusion_forge.py index 9176a78..b8e2d96 100644 --- a/services/stablediffusion_forge.py +++ b/services/stable_diffusion_forge.py @@ -1,11 +1,11 @@ -from services import Stack +from core.stack import Stack class StableDiffusionForge(Stack): def __init__(self): super().__init__( 'StableDiffusion Forge WebUI', - 'stablediffusion-forge-rocm', + 'stable_diffusion_forge', 5003, 'https://github.com/lllyasviel/stable-diffusion-webui-forge' ) @@ -24,4 +24,4 @@ class StableDiffusionForge(Stack): def _launch(self): args = ["--listen", "--enable-insecure-extension-access", "--port", str(self.port)] self.python(f"launch.py {' '.join(args)}", current_dir="webui", - env=["TORCH_BLAS_PREFER_HIPBLASLT=0"]) + env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True) diff --git a/services/stablediffusion_webui.py b/services/stable_diffusion_webui.py similarity index 79% rename from services/stablediffusion_webui.py rename to services/stable_diffusion_webui.py index 67e235f..da84989 100644 --- a/services/stablediffusion_webui.py +++ b/services/stable_diffusion_webui.py @@ -1,11 +1,11 @@ -from services import Stack +from core.stack import Stack -class StableDiffusionWebUI(Stack): +class StableDiffusionWebui(Stack): def __init__(self): super().__init__( 'StableDiffusion WebUI', - 'stablediffusion-webui-rocm', + 'stable_diffusion_webui', 5002, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui' ) @@ -24,4 +24,4 @@ class StableDiffusionWebUI(Stack): def _launch(self): args = ["--listen", "--enable-insecure-extension-access", "--port", str(self.port)] self.python(f"launch.py {' '.join(args)}", current_dir="webui", - env=["TORCH_BLAS_PREFER_HIPBLASLT=0"]) + env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True) diff --git a/services/textgen.py b/services/text_generation_webui.py similarity index 96% rename from services/textgen.py rename to services/text_generation_webui.py index 7bfd3b2..b0c7d12 100644 --- a/services/textgen.py +++ b/services/text_generation_webui.py @@ -1,11 +1,11 @@ -from services import Stack +from core.stack import Stack -class TextGeneration(Stack): +class TextGenerationWebui(Stack): def __init__(self): super().__init__( 'Text Generation', - 'text-generation-rocm', + 'text_generation_webui', 5000, 'https://github.com/oobabooga/text-generation-webui/' ) diff --git a/services/xtts.py b/services/xtts_webui.py similarity index 91% rename from services/xtts.py rename to services/xtts_webui.py index 0121deb..8ef2480 100644 --- a/services/xtts.py +++ b/services/xtts_webui.py @@ -1,11 +1,11 @@ -from services import Stack +from core.stack import Stack -class XTTS(Stack): +class XttsWebui(Stack): def __init__(self): super().__init__( 'XTTS WebUI', - 'xtts-rocm', + 'xtts_webui', 5001, 'https://github.com/daswer123/xtts-webui' ) @@ -35,4 +35,4 @@ class XTTS(Stack): def _launch(self): args = ["--host", "0.0.0.0", "--port", str(self.port)] self.python(f"server.py {' '.join(args)}", current_dir="webui", - env=["TORCH_BLAS_PREFER_HIPBLASLT=0"]) + env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True) diff --git a/ui.py b/ui.py deleted file mode 100644 index b500e97..0000000 --- a/ui.py +++ /dev/null @@ -1,43 +0,0 @@ -import questionary -from questionary import Choice - -from services import BGRemovalDIS, ComfyUI, StableDiffusionWebUI, StableDiffusionForge, TextGeneration, XTTS - -services = { - "Background Removal (DIS)": BGRemovalDIS, - "ComfyUI": ComfyUI, - "StableDiffusion (AUTOMATIC1111)": StableDiffusionWebUI, - "StableDiffusion Forge": StableDiffusionForge, - "TextGeneration (oobabooga)": TextGeneration, - "XTTS": XTTS -} - -start = questionary.select( - "Choose an option:", - choices=[ - Choice("Start services"), - Choice("Stop services"), - Choice("Install/update services"), - Choice("Uninstall services"), - Choice("Exit") - ]) - -start_services = questionary.checkbox( - "Select services to start:", - choices=[Choice(service) for service in services.keys()] -) - -stop_services = questionary.checkbox( - "Select services to stop:", - choices=[Choice(service) for service in services.keys()] -) - -install_service = questionary.checkbox( - "Select service to install/update:", - choices=[Choice(service) for service in services.keys()] -) - -uninstall_service = questionary.checkbox( - "Select service to uninstall:", - choices=[Choice(service) for service in services.keys()] -) diff --git a/ui/choices.py b/ui/choices.py new file mode 100644 index 0000000..97b5263 --- /dev/null +++ b/ui/choices.py @@ -0,0 +1,53 @@ +import questionary +from questionary import Choice + +from core.vars import loaded_services + +start = None +start_service = None +stop_service = None +install_service = None +uninstall_service = None +are_you_sure = None +any_key = None + + +def update_choices(): + global start, start_service, stop_service, install_service, uninstall_service, are_you_sure, any_key + + start = questionary.select( + "Choose an option:", + choices=[ + Choice("Start service"), + Choice("Stop service"), + Choice("Install/update service"), + Choice("Uninstall service"), + Choice("exit") + ]) + + _services_choices = [Choice(service.name, value=service.id) for service in loaded_services.values()] + _services_choices.append(Choice("go back", value="back")) + + start_service = questionary.select( + "Select service to start:", + choices=_services_choices + ) + + stop_service = questionary.select( + "Select service to stop:", + choices=_services_choices + ) + + install_service = questionary.select( + "Select service to install/update:", + choices=_services_choices + ) + + uninstall_service = questionary.select( + "Select service to uninstall:", + choices=_services_choices + ) + + are_you_sure = questionary.confirm("Are you sure?") + + any_key = questionary.text("Press any key to continue") diff --git a/ui/interface.py b/ui/interface.py new file mode 100644 index 0000000..c7c753d --- /dev/null +++ b/ui/interface.py @@ -0,0 +1,65 @@ +import os + +import questionary + +from core.vars import logger, loaded_services +from ui import choices + + +def clear_terminal(): + os.system('cls' if os.name == 'nt' else 'clear') + + +def handle_services(action, service): + clear_terminal() + + if service == "back": + return + + service = loaded_services[service] + if action == "start": + logger.info(f"Starting service: {service.name}") + service.start() + elif action == "stop": + logger.info(f"Stopping service: {service.name}") + service.stop() + elif action == "update": + confirmation = choices.are_you_sure.ask() + if confirmation: + logger.info(f"Installing/updating service: {service.name}") + service.update() + elif action == "uninstall": + confirmation = choices.are_you_sure.ask() + if confirmation: + type_confirmation = questionary.text(f"Please type {service.id} to confirm uninstallation:") + if type_confirmation.ask() == service.id: + logger.info(f"Uninstalling service: {service.name}") + service.uninstall() + + choices.any_key.ask() + + +def run_interactive_cmd_ui(): + while True: + clear_terminal() + choice = choices.start.ask() + + if choice == "Start service": + service = choices.start_service.ask() + handle_services("start", service) + + elif choice == "Stop service": + service = choices.stop_service.ask() + handle_services("stop", service) + + elif choice == "Install/update service": + service = choices.install_service.ask() + handle_services("update", service) + + elif choice == "Uninstall service": + service = choices.uninstall_service.ask() + handle_services("uninstall", service) + + elif choice == "Exit": + print("Exiting...") + exit(0) diff --git a/utils.py b/utils.py deleted file mode 100644 index ad8fab0..0000000 --- a/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -import json -import urllib - -from main import ROCM_VERSION, logger - - -def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-local", - release_tag: str = f"prebuilt-whl-{ROCM_VERSION}") -> list: - api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}" - - try: - with urllib.request.urlopen(api_url) as response: - if response.status != 200: - logger.error(f"Failed to fetch data: HTTP Status {response.status}") - return [] - - release_data = json.load(response) - - assets = release_data.get('assets', []) - if not assets: - logger.error("No assets found in release data") - return [] - - return assets - - except urllib.error.URLError as e: - logger.error(f"Error fetching release data: {e}")