massive refactor, added terminal interface, daemon system (still not working)

This commit is contained in:
Mathieu Broillet 2024-08-28 22:20:36 +02:00
parent 4dfd9faa39
commit a6728d3d1c
Signed by: mathieu
GPG Key ID: A08E484FE95074C1
16 changed files with 328 additions and 232 deletions

56
core/config.py Normal file
View File

@ -0,0 +1,56 @@
import json
import os
from core.vars import logger
data = {}
file = os.path.join(os.path.expanduser("~"), ".config", "ai-suite-rocm", "config.json")
def create():
if not os.path.exists(file):
os.makedirs(os.path.dirname(file), exist_ok=True)
with open(file, "w") as f:
f.write("{}")
logger.info(f"Created config file at {file}")
def read():
global data
with open(file, "r") as f:
data = json.load(f)
def write():
global data
with open(file, "w") as f:
json.dump(data, f)
def get(key: str):
global data
return data.get(key)
def set(key: str, value):
global data
data[key] = value
write()
def has(key: str):
global data
return key in data
def remove(key: str):
global data
data.pop(key)
write()
def clear():
global data
data = {}
write()

View File

@ -1,25 +1,25 @@
import logging
import os
import shutil
import subprocess
import psutil
import main
import utils
from main import PYTHON_EXEC, logger, get_config
from core import utils, config
from core.vars import logger, PYTHON_EXEC
class Stack:
def __init__(self, name: str, path: str, port: int, url: str):
def __init__(self, name: str, id: str, port: int, url: str):
self.name = name
self.path = os.path.join(os.path.expanduser("~"), ".ai-suite-rocm", path)
self.id = id
self.path = os.path.join(os.path.expanduser("~"), ".ai-suite-rocm", id)
self.url = url
self.port = port
self.process = None
def install(self):
self.create_file('.installed', 'true')
logger.info(f"Installed {self.name}")
@ -110,18 +110,20 @@ class Stack:
if daemon:
# Check if previous run process is saved
if get_config().has(f"{self.name}-pid"):
if config.has(f"{self.name}-pid"):
# Check if PID still running
if psutil.pid_exists(main.config.get(f"{self.name}-pid")):
if psutil.pid_exists(config.get(f"{self.name}-pid")):
choice = input(f"{self.name} is already running, do you want to restart it? (y/n): ")
if choice.lower() == 'y':
pid = main.config.get(f"{self.name}-pid")
pid = config.get(f"{self.name}-pid")
logger.debug(f"Killing previous daemon with PID: {pid}")
psutil.Process(pid).kill()
else:
# TODO: attach to subprocess?
logger.info("Continuing without restarting...")
return
else:
logger.warning(
@ -131,8 +133,8 @@ class Stack:
logger.debug(f"Starting {self.name} as daemon with command: {cmd}")
cmd = f"{cmd} &"
process = subprocess.Popen(cmd, shell=True)
get_config().set(f"{self.name}-pid", process.pid + 1)
process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
config.set(f"{self.name}-pid", process.pid + 1)
return
else:
logger.debug(f"Running command: {cmd}")
@ -185,7 +187,7 @@ class Stack:
def remove_dir(self, name):
logger.debug(f"Removing directory {name}")
os.rmdir(os.path.join(self.path, name))
shutil.rmtree(os.path.join(self.path, name))
def move_file_or_dir(self, src, dest):
logger.debug(f"Moving file/dir {src} to {dest}")

83
core/utils.py Normal file
View File

@ -0,0 +1,83 @@
import importlib
import json
import os
import subprocess
import urllib
from core.stack import Stack
from core.vars import ROCM_VERSION, logger
def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-local",
release_tag: str = f"prebuilt-whl-{ROCM_VERSION}") -> list:
api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}"
try:
with urllib.request.urlopen(api_url) as response:
if response.status != 200:
logger.error(f"Failed to fetch data: HTTP Status {response.status}")
return []
release_data = json.load(response)
assets = release_data.get('assets', [])
if not assets:
logger.error("No assets found in release data")
return []
return assets
except urllib.error.URLError as e:
logger.error(f"Error fetching release data: {e}")
def check_for_build_essentials():
logger.debug("Checking for build essentials...")
debian = os.path.exists('/etc/debian_version')
fedora = os.path.exists('/etc/fedora-release')
if debian:
# TODO: check if these work for debian users
check_gcc = run_command("dpkg -l | grep build-essential &>/dev/null", exit_on_error=False)[2] == 0
check_python = run_command("dpkg -l | grep python3.10-dev &>/dev/null", exit_on_error=False)[2] == 0
if not check_gcc or not check_python:
raise UserWarning(
"The packages build-essential and python3.10-dev are required for this script to run. Please install them. See the README for more information.")
elif fedora:
check_gcc = run_command("rpm -q gcc &>/dev/null", exit_on_error=False)[2] == 0
check_python = run_command("rpm -q python3.10-devel &>/dev/null", exit_on_error=False)[2] == 0
if not check_gcc or not check_python:
raise UserWarning(
"The package python3.10-devel and the Development Tools group are required for this script to run. Please install them. See the README for more information.")
else:
logger.warning(
"Unsupported OS detected. Please ensure you have the following packages installed or their equivalent: build-essential, python3.10-dev")
def run_command(command: str, exit_on_error: bool = True):
logger.debug(f"Running command: {command}")
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = process.communicate()
if process.returncode != 0:
logger.fatal(f"Failed to run command: {command}")
raise Exception(f"Failed to run command: {command}")
return out, err, process.returncode
def load_service_from_string(service: str) -> Stack:
logger.debug(f"Loading service from string: {service}")
try:
service_name = service.replace("_", " ").title().replace(" ", "")
module = importlib.import_module(f"services.{service}")
met = getattr(module, service_name)
return met()
except ModuleNotFoundError as e:
logger.error(f"Failed to load service: {e}")
return None

19
core/vars.py Normal file
View File

@ -0,0 +1,19 @@
import logging
import os
PYTHON_EXEC = 'python3.10'
PATH = os.path.dirname(os.path.abspath(__file__))
ROCM_VERSION = "6.1.2"
logger = logging.getLogger('ai-suite-rocm')
services = [
'background_removal_dis',
'comfy_ui',
'stable_diffusion_webui',
'stable_diffusion_forge',
'text_generation_webui',
'xtts_webui'
]
loaded_services = {}

135
main.py
View File

@ -1,65 +1,14 @@
import json
import logging
import os
import subprocess
import sys
import ui
PYTHON_EXEC = 'python3.10'
PATH = os.path.dirname(os.path.abspath(__file__))
ROCM_VERSION = "6.1.2"
logger = logging.getLogger('ai-suite-rocm')
config = None
class Config:
data = {}
def __init__(self):
self.file = os.path.join(os.path.expanduser("~"), ".config", "ai-suite-rocm", "config.json")
self.create()
self.read()
def create(self):
if not os.path.exists(self.file):
os.makedirs(os.path.dirname(self.file), exist_ok=True)
with open(self.file, "w") as f:
f.write("{}")
logger.info(f"Created config file at {self.file}")
def read(self):
with open(self.file, "r") as f:
self.data = json.load(f)
def write(self):
with open(self.file, "w") as f:
json.dump(self.data, f)
def get(self, key: str):
return self.data.get(key)
def set(self, key: str, value):
self.data[key] = value
self.write()
def has(self, key: str):
return key in self.data
def remove(self, key: str):
self.data.pop(key)
self.write()
def clear(self):
self.data = {}
self.write()
from core import config
from core.utils import check_for_build_essentials, load_service_from_string
from core.vars import logger, services, loaded_services
from ui.choices import update_choices
from ui.interface import run_interactive_cmd_ui
def setup_logger(level: logger.level = logging.INFO):
global logger
if not logger.hasHandlers():
logger.setLevel(level)
handler = logging.StreamHandler(sys.stdout)
@ -68,78 +17,24 @@ def setup_logger(level: logger.level = logging.INFO):
def setup_config():
global config
config = Config()
config.create()
config.read()
def get_config():
global config
return config
def load_services():
for service in services:
loaded_services[service] = load_service_from_string(service)
def run_command(command: str, exit_on_error: bool = True):
logger.debug(f"Running command: {command}")
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = process.communicate()
if process.returncode != 0:
logger.fatal(f"Failed to run command: {command}")
raise Exception(f"Failed to run command: {command}")
return out, err, process.returncode
def check_for_build_essentials():
logger.debug("Checking for build essentials...")
debian = os.path.exists('/etc/debian_version')
fedora = os.path.exists('/etc/fedora-release')
if debian:
# TODO: check if these work for debian users
check_gcc = run_command("dpkg -l | grep build-essential &>/dev/null", exit_on_error=False)[2] == 0
check_python = run_command("dpkg -l | grep python3.10-dev &>/dev/null", exit_on_error=False)[2] == 0
if not check_gcc or not check_python:
raise UserWarning(
"The packages build-essential and python3.10-dev are required for this script to run. Please install them. See the README for more information.")
elif fedora:
check_gcc = run_command("rpm -q gcc &>/dev/null", exit_on_error=False)[2] == 0
check_python = run_command("rpm -q python3.10-devel &>/dev/null", exit_on_error=False)[2] == 0
if not check_gcc or not check_python:
raise UserWarning(
"The package python3.10-devel and the Development Tools group are required for this script to run. Please install them. See the README for more information.")
else:
logger.warning(
"Unsupported OS detected. Please ensure you have the following packages installed or their equivalent: build-essential, python3.10-dev")
def run_interactive_cmd_ui():
while True:
choice = ui.start.ask()
match choice:
case "Start services":
services = ui.start_services.ask()
for service in services:
logger.info(f"Starting service: {service}")
pass
pass
case "Stop services":
pass
case "Install/update services":
pass
case "Uninstall services":
pass
case "Exit":
print("Exiting...")
exit(0)
if __name__ == '__main__':
setup_logger(logging.DEBUG)
logger.info("Starting AI Suite for ROCM")
setup_config()
logger.info("Starting AI Suite for ROCM")
check_for_build_essentials()
load_services()
update_choices()
run_interactive_cmd_ui()

View File

@ -1,7 +0,0 @@
from services.background_removal_dis import BGRemovalDIS
from services.comfyui import ComfyUI
from services.services import Stack
from services.stablediffusion_forge import StableDiffusionForge
from services.stablediffusion_webui import StableDiffusionWebUI
from services.textgen import TextGeneration
from services.xtts import XTTS

View File

@ -1,11 +1,11 @@
from services import Stack
from core.stack import Stack
class BGRemovalDIS(Stack):
class BackgroundRemovalDis(Stack):
def __init__(self):
super().__init__(
'BGRemovalDIS',
'bg-remove-dis-rocm',
'Background Removal (DIS)',
'background_removal_dis',
5005,
'https://huggingface.co/spaces/ECCV2022/dis-background-removal'
)
@ -31,4 +31,4 @@ class BGRemovalDIS(Stack):
def _launch(self):
args = ["--port", str(self.port)]
self.python(f"app.py {' '.join(args)}", current_dir="webui",
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"])
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True)

View File

@ -1,11 +1,11 @@
from services import Stack
from core.stack import Stack
class ComfyUI(Stack):
class ComfyUi(Stack):
def __init__(self):
super().__init__(
'ComfyUI',
'comfyui-rocm',
'comfy_ui',
5004,
'https://github.com/comfyanonymous/ComfyUI.git'
)
@ -31,4 +31,4 @@ class ComfyUI(Stack):
def _launch(self):
args = ["--port", str(self.port)]
self.python(f"main.py {' '.join(args)}", current_dir="webui",
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"])
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True)

View File

@ -1,11 +1,11 @@
from services import Stack
from core.stack import Stack
class StableDiffusionForge(Stack):
def __init__(self):
super().__init__(
'StableDiffusion Forge WebUI',
'stablediffusion-forge-rocm',
'stable_diffusion_forge',
5003,
'https://github.com/lllyasviel/stable-diffusion-webui-forge'
)
@ -24,4 +24,4 @@ class StableDiffusionForge(Stack):
def _launch(self):
args = ["--listen", "--enable-insecure-extension-access", "--port", str(self.port)]
self.python(f"launch.py {' '.join(args)}", current_dir="webui",
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"])
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True)

View File

@ -1,11 +1,11 @@
from services import Stack
from core.stack import Stack
class StableDiffusionWebUI(Stack):
class StableDiffusionWebui(Stack):
def __init__(self):
super().__init__(
'StableDiffusion WebUI',
'stablediffusion-webui-rocm',
'stable_diffusion_webui',
5002,
'https://github.com/AUTOMATIC1111/stable-diffusion-webui'
)
@ -24,4 +24,4 @@ class StableDiffusionWebUI(Stack):
def _launch(self):
args = ["--listen", "--enable-insecure-extension-access", "--port", str(self.port)]
self.python(f"launch.py {' '.join(args)}", current_dir="webui",
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"])
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True)

View File

@ -1,11 +1,11 @@
from services import Stack
from core.stack import Stack
class TextGeneration(Stack):
class TextGenerationWebui(Stack):
def __init__(self):
super().__init__(
'Text Generation',
'text-generation-rocm',
'text_generation_webui',
5000,
'https://github.com/oobabooga/text-generation-webui/'
)

View File

@ -1,11 +1,11 @@
from services import Stack
from core.stack import Stack
class XTTS(Stack):
class XttsWebui(Stack):
def __init__(self):
super().__init__(
'XTTS WebUI',
'xtts-rocm',
'xtts_webui',
5001,
'https://github.com/daswer123/xtts-webui'
)
@ -35,4 +35,4 @@ class XTTS(Stack):
def _launch(self):
args = ["--host", "0.0.0.0", "--port", str(self.port)]
self.python(f"server.py {' '.join(args)}", current_dir="webui",
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"])
env=["TORCH_BLAS_PREFER_HIPBLASLT=0"], daemon=True)

43
ui.py
View File

@ -1,43 +0,0 @@
import questionary
from questionary import Choice
from services import BGRemovalDIS, ComfyUI, StableDiffusionWebUI, StableDiffusionForge, TextGeneration, XTTS
services = {
"Background Removal (DIS)": BGRemovalDIS,
"ComfyUI": ComfyUI,
"StableDiffusion (AUTOMATIC1111)": StableDiffusionWebUI,
"StableDiffusion Forge": StableDiffusionForge,
"TextGeneration (oobabooga)": TextGeneration,
"XTTS": XTTS
}
start = questionary.select(
"Choose an option:",
choices=[
Choice("Start services"),
Choice("Stop services"),
Choice("Install/update services"),
Choice("Uninstall services"),
Choice("Exit")
])
start_services = questionary.checkbox(
"Select services to start:",
choices=[Choice(service) for service in services.keys()]
)
stop_services = questionary.checkbox(
"Select services to stop:",
choices=[Choice(service) for service in services.keys()]
)
install_service = questionary.checkbox(
"Select service to install/update:",
choices=[Choice(service) for service in services.keys()]
)
uninstall_service = questionary.checkbox(
"Select service to uninstall:",
choices=[Choice(service) for service in services.keys()]
)

53
ui/choices.py Normal file
View File

@ -0,0 +1,53 @@
import questionary
from questionary import Choice
from core.vars import loaded_services
start = None
start_service = None
stop_service = None
install_service = None
uninstall_service = None
are_you_sure = None
any_key = None
def update_choices():
global start, start_service, stop_service, install_service, uninstall_service, are_you_sure, any_key
start = questionary.select(
"Choose an option:",
choices=[
Choice("Start service"),
Choice("Stop service"),
Choice("Install/update service"),
Choice("Uninstall service"),
Choice("exit")
])
_services_choices = [Choice(service.name, value=service.id) for service in loaded_services.values()]
_services_choices.append(Choice("go back", value="back"))
start_service = questionary.select(
"Select service to start:",
choices=_services_choices
)
stop_service = questionary.select(
"Select service to stop:",
choices=_services_choices
)
install_service = questionary.select(
"Select service to install/update:",
choices=_services_choices
)
uninstall_service = questionary.select(
"Select service to uninstall:",
choices=_services_choices
)
are_you_sure = questionary.confirm("Are you sure?")
any_key = questionary.text("Press any key to continue")

65
ui/interface.py Normal file
View File

@ -0,0 +1,65 @@
import os
import questionary
from core.vars import logger, loaded_services
from ui import choices
def clear_terminal():
os.system('cls' if os.name == 'nt' else 'clear')
def handle_services(action, service):
clear_terminal()
if service == "back":
return
service = loaded_services[service]
if action == "start":
logger.info(f"Starting service: {service.name}")
service.start()
elif action == "stop":
logger.info(f"Stopping service: {service.name}")
service.stop()
elif action == "update":
confirmation = choices.are_you_sure.ask()
if confirmation:
logger.info(f"Installing/updating service: {service.name}")
service.update()
elif action == "uninstall":
confirmation = choices.are_you_sure.ask()
if confirmation:
type_confirmation = questionary.text(f"Please type {service.id} to confirm uninstallation:")
if type_confirmation.ask() == service.id:
logger.info(f"Uninstalling service: {service.name}")
service.uninstall()
choices.any_key.ask()
def run_interactive_cmd_ui():
while True:
clear_terminal()
choice = choices.start.ask()
if choice == "Start service":
service = choices.start_service.ask()
handle_services("start", service)
elif choice == "Stop service":
service = choices.stop_service.ask()
handle_services("stop", service)
elif choice == "Install/update service":
service = choices.install_service.ask()
handle_services("update", service)
elif choice == "Uninstall service":
service = choices.uninstall_service.ask()
handle_services("uninstall", service)
elif choice == "Exit":
print("Exiting...")
exit(0)

View File

@ -1,27 +0,0 @@
import json
import urllib
from main import ROCM_VERSION, logger
def get_prebuilts(repo_owner: str = "M4TH1EU", repo_name: str = "ai-suite-rocm-local",
release_tag: str = f"prebuilt-whl-{ROCM_VERSION}") -> list:
api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}"
try:
with urllib.request.urlopen(api_url) as response:
if response.status != 200:
logger.error(f"Failed to fetch data: HTTP Status {response.status}")
return []
release_data = json.load(response)
assets = release_data.get('assets', [])
if not assets:
logger.error("No assets found in release data")
return []
return assets
except urllib.error.URLError as e:
logger.error(f"Error fetching release data: {e}")