diff --git a/.gitignore b/.gitignore index 15a5c92..987619a 100644 --- a/.gitignore +++ b/.gitignore @@ -157,3 +157,4 @@ cython_debug/ *code-workspace settings.user.ini +.aider* diff --git a/requirements.txt b/requirements.txt index b93730e..826cb8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -replicate==0.33.0 +replicate==0.34.1 python-dotenv==1.0.1 token_count==0.2.1 loguru==0.7.2 -nicegui==2.1.0 +nicegui==2.3.0 httpx==0.27.2 dynaconf==3.2.6 toml==0.10.2 diff --git a/src/config.py b/src/config.py deleted file mode 100644 index ea20b8f..0000000 --- a/src/config.py +++ /dev/null @@ -1,80 +0,0 @@ -import configparser -import os -from typing import Any, Type - -from loguru import logger - -DOCKERIZED = os.environ.get("DOCKER_CONTAINER", "False").lower() == "true" -CONFIG_DIR = "/app/settings" if DOCKERIZED else "." -DEFAULT_CONFIG_FILE = os.path.join(CONFIG_DIR, "settings.ini") -USER_CONFIG_FILE = os.path.join(CONFIG_DIR, "settings.user.ini") - -logger.info( - f"Configuration files: DEFAULT={DEFAULT_CONFIG_FILE}, USER={USER_CONFIG_FILE}" -) - -config = configparser.ConfigParser() -config.read([DEFAULT_CONFIG_FILE, USER_CONFIG_FILE]) -logger.info("Configuration files loaded") - - -def get_api_key(): - api_key = os.environ.get("REPLICATE_API_KEY") or config.get( - "secrets", "REPLICATE_API_KEY", fallback=None - ) - if api_key: - logger.info("API key retrieved successfully") - else: - logger.warning("No API key found") - return api_key - - -def get_setting( - section: str, key: str, fallback: Any = None, value_type: Type[Any] = str -) -> Any: - logger.info( - f"Attempting to get setting: section={section}, key={key}, fallback={fallback}, value_type={value_type}" - ) - try: - value = config.get(section, key) - logger.debug(f"Raw value retrieved: {value}") - if value_type == int: - result = int(value) - elif value_type == float: - result = float(value) - elif value_type == bool: - result = value.lower() in ("true", "yes", "1", "on") - else: - result = value - logger.info(f"Setting retrieved successfully: {result}") - return result - except (configparser.NoSectionError, configparser.NoOptionError) as e: - logger.warning(f"Setting not found: {str(e)}. Using fallback value: {fallback}") - return fallback - except ValueError as e: - logger.error( - f"Error converting setting value: {str(e)}. Using fallback value: {fallback}" - ) - return fallback - - -def set_setting(section, key, value): - logger.info(f"Setting value: section={section}, key={key}, value={value}") - if not config.has_section(section): - logger.info(f"Creating new section: {section}") - config.add_section(section) - config.set(section, key, str(value)) - logger.info("Value set successfully") - - -def save_settings(): - logger.info(f"Saving settings to {USER_CONFIG_FILE}") - try: - with open(USER_CONFIG_FILE, "w") as configfile: - config.write(configfile) - logger.info("Settings saved successfully") - except IOError as e: - logger.error(f"Error saving settings: {str(e)}") - - -logger.info("Config module initialized") diff --git a/src/gui.py b/src/gui.py deleted file mode 100644 index 5e84df3..0000000 --- a/src/gui.py +++ /dev/null @@ -1,828 +0,0 @@ -import asyncio -import json -import os -import urllib.parse -import zipfile -from datetime import datetime -from pathlib import Path - -import httpx -from config import get_api_key, get_setting, save_settings, set_setting -from loguru import logger -from nicegui import ui - -DOCKERIZED = os.environ.get("DOCKER_CONTAINER", False) - - -class Lightbox: - def __init__(self): - logger.debug("Initializing Lightbox") - with ui.dialog().props("maximized").classes("bg-black") as self.dialog: - self.dialog.on_key = self._handle_key - self.large_image = ui.image().props("no-spinner fit=scale-down") - self.image_list = [] - logger.debug("Lightbox initialized") - - def add_image( - self, - thumb_url: str, - orig_url: str, - thumb_classes: str = "w-32 h-32 object-cover", - ) -> ui.button: - logger.debug(f"Adding image to Lightbox: {orig_url}") - self.image_list.append(orig_url) - button = ui.button(on_click=lambda: self._open(orig_url)).props( - "flat dense square" - ) - with button: - ui.image(thumb_url).classes(thumb_classes) - logger.debug("Image added to Lightbox") - return button - - def _handle_key(self, e) -> None: - logger.debug(f"Handling key press in Lightbox: {e.key}") - if not e.action.keydown: - return - if e.key.escape: - logger.debug("Closing Lightbox dialog") - self.dialog.close() - image_index = self.image_list.index(self.large_image.source) - if e.key.arrow_left and image_index > 0: - logger.debug("Displaying previous image") - self._open(self.image_list[image_index - 1]) - if e.key.arrow_right and image_index < len(self.image_list) - 1: - logger.debug("Displaying next image") - self._open(self.image_list[image_index + 1]) - - def _open(self, url: str) -> None: - logger.debug(f"Opening image in Lightbox: {url}") - self.large_image.set_source(url) - self.dialog.open() - - -class ImageGeneratorGUI: - def __init__(self, image_generator): - logger.info("Initializing ImageGeneratorGUI") - self.image_generator = image_generator - self.api_key = get_api_key() or os.environ.get("REPLICATE_API_KEY", "") - self.last_generated_images = [] - self.setup_custom_styles() - self._attributes = [ - "prompt", - "flux_model", - "aspect_ratio", - "num_outputs", - "lora_scale", - "num_inference_steps", - "guidance_scale", - "output_format", - "output_quality", - "disable_safety_checker", - "width", - "height", - "seed", - "output_folder", - "replicate_model", - ] - - self.user_added_models = {} - self.prompt = get_setting("default", "prompt", "", str) - - self.flux_model = get_setting("default", "flux_model", "dev", str) - self.aspect_ratio = get_setting("default", "aspect_ratio", "1:1", str) - self.num_outputs = get_setting("default", "num_outputs", "1", int) - self.lora_scale = get_setting("default", "lora_scale", "1", float) - self.num_inference_steps = get_setting( - "default", "num_inference_steps", "28", int - ) - self.guidance_scale = get_setting("default", "guidance_scale", "3.5", float) - self.output_format = get_setting("default", "output_format", "png") - self.output_quality = get_setting("default", "output_quality", "80", int) - self.disable_safety_checker = get_setting( - "default", "disable_safety_checker", True, bool - ) - - self.width = get_setting("default", "width", "1024", int) - self.height = get_setting("default", "height", "1024", int) - self.seed = get_setting("default", "seed", "-1", int) - - self.output_folder = ( - "/app/output" - if DOCKERIZED - else get_setting("default", "output_folder", "/Downloads", str) - ) - models_json = get_setting("default", "models", '{"user_added": []}', str) - models = json.loads(models_json) - self.user_added_models = { - model: model for model in models.get("user_added", []) - } - self.model_options = list(self.user_added_models.keys()) - self.replicate_model = get_setting("default", "replicate_model", "", str) - - logger.info("ImageGeneratorGUI initialized") - - def setup_custom_styles(self): - logger.debug("Setting up custom styles") - ui.add_head_html(""" - - - """) - - def setup_ui(self): - logger.info("Setting up UI") - ui.dark_mode() - self.check_api_key() - - with ui.grid().classes( - "w-full h-screen md:h-full grid-cols-1 md:grid-cols-2 gap-2 md:gap-5 p-4 md:p-6 dark:bg-[#1f2328] bg-[#ffffff] md:auto-rows-min" - ): - with ui.card().classes( - "col-span-full modern-card dark:bg-[#25292e] bg-[#818b981f] flex-nowrap h-min" - ): - self.setup_top_panel() - - with ui.card().classes( - "col-span-full modern-card dark:bg-[#25292e] bg-[#818b981f]" - ): - self.setup_prompt_panel() - - with ui.card().classes( - "row-span-2 overflow-auto modern-card dark:bg-[#25292e] bg-[#818b981f]" - ): - self.setup_left_panel() - - with ui.card().classes( - "row-span-2 overflow-auto modern-card dark:bg-[#25292e] bg-[#818b981f]" - ): - self.setup_right_panel() - - logger.info("UI setup completed") - - def setup_top_panel(self): - logger.debug("Setting up top panel") - with ui.row().classes("w-full items-center"): - ui.label("Lumberjack - Replicate API Interface").classes( - "text-2xl/loose font-bold" - ) - dark_mode = ui.dark_mode(True) - ui.switch().bind_value(dark_mode).classes().props( - "dense checked-icon=dark_mode unchecked-icon=light_mode color=blue-7" - ) - ui.button( - icon="settings_suggest", - on_click=self.open_settings_popup, - color="#0969da", - ).classes("absolute-right mr-6 mt-3 mb-3") - - def setup_left_panel(self): - logger.debug("Setting up left panel") - with ui.row().classes("w-full flex-row flex-nowrap"): - self.replicate_model_select = ( - ui.select( - options=self.model_options, - label="Replicate Model", - value=self.replicate_model, - on_change=lambda e: asyncio.create_task( - self.update_replicate_model(e.value) - ), - ) - .classes("width-5/6 overflow-auto custom-select") - .tooltip("Select or manage Replicate models") - .props("filled") - ) - ui.button(icon="token", color="#0969da").classes("ml-2 mt-1.2").on( - "click", self.open_user_model_popup - ).props("size=1.3rem") - - self.flux_model_select = ( - ui.select( - ["dev", "schnell"], - label="Flux Model", - value=get_setting("default", "flux_model", "dev"), - ) - .classes("w-full text-gray-200") - .tooltip( - "Which model to run inferences with. The dev model needs around 28 steps but the schnell model only needs around 4 steps." - ) - .bind_value(self, "flux_model") - .props("filled") - ) - - with ui.row().classes("w-full flex-nowrap md:flex-wrap"): - self.aspect_ratio_select = ( - ui.select( - [ - "1:1", - "16:9", - "21:9", - "3:2", - "2:3", - "4:5", - "5:4", - "3:4", - "4:3", - "9:16", - "9:21", - "custom", - ], - label="Aspect Ratio", - value=get_setting("default", "aspect_ratio", "1:1"), - ) - .classes("w-1/2 md:w-full text-gray-200") - .bind_value(self, "aspect_ratio") - .tooltip( - "Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)" - ) - .props("filled") - ) - self.aspect_ratio_select.on("change", self.toggle_custom_dimensions) - - with ui.column().classes("w-full").bind_visibility_from( - self.aspect_ratio_select, "value", value="custom" - ): - self.width_input = ( - ui.number( - "Width", - value=get_setting("default", "width", 1024, int), - min=256, - max=1440, - ) - .classes("w-full") - .bind_value(self, "width") - .tooltip( - "Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)" - ) - ) - self.height_input = ( - ui.number( - "Height", - value=get_setting("default", "height", 1024, int), - min=256, - max=1440, - ) - .classes("w-full") - .bind_value(self, "height") - .tooltip( - "Height of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)" - ) - ) - - self.num_outputs_input = ( - ui.number( - "Num Outputs", - value=get_setting("default", "num_outputs", 1, int), - min=1, - max=4, - ) - .classes("w-1/2 md:w-full") - .bind_value(self, "num_outputs") - .tooltip("Number of images to output.") - .props("filled") - ) - - with ui.row().classes("w-full flex-nowrap md:flex-wrap"): - self.lora_scale_input = ( - ui.number( - "LoRA Scale", - value=float(get_setting("default", "lora_scale", 1)), - min=-1, - max=2, - step=0.1, - ) - .classes("w-1/2 md:w-full") - .tooltip( - "Determines how strongly the LoRA should be applied. Sane results between 0 and 1." - ) - .props("filled") - .bind_value(self, "lora_scale") - ) - self.num_inference_steps_input = ( - ui.number( - "Num Inference Steps", - value=get_setting("default", "num_inference_steps", 28, int), - min=1, - max=50, - ) - .classes("w-1/2 md:w-full") - .tooltip("Number of Inference Steps") - .bind_value(self, "num_inference_steps") - .props("filled") - ) - - with ui.row().classes("w-full flex-nowrap md:flex-wrap"): - self.guidance_scale_input = ( - ui.number( - "Guidance Scale", - value=float(get_setting("default", "guidance_scale", 3.5)), - min=0, - max=10, - step=0.1, - precision=2, - ) - .classes("w-1/2 md:w-full") - .tooltip("Guidance Scale for the diffusion process") - .bind_value(self, "guidance_scale") - .props("filled") - ) - self.seed_input = ( - ui.number( - "Seed", - value=get_setting("default", "seed", -1, int), - min=-2147483648, - max=2147483647, - ) - .classes("w-1/2 md:w-full") - .bind_value(self, "seed") - .props("filled") - ) - - with ui.row().classes("w-full flex-nowrap"): - self.output_format_select = ( - ui.select( - ["webp", "jpg", "png"], - label="Output Format", - value=get_setting("default", "output_format", "webp"), - ) - .classes("w-full") - .tooltip("Format of the output images") - .bind_value(self, "output_format") - .props("filled") - ) - - self.output_quality_input = ( - ui.number( - "Output Quality", - value=get_setting("default", "output_quality", 80, int), - min=0, - max=100, - ) - .classes("w-full") - .tooltip( - "Quality when saving the output images, from 0 to 100. 100 is best quality, 0 is lowest quality. Not relevant for .png outputs" - ) - .bind_value(self, "output_quality") - .props("filled") - ) - - with ui.row().classes("w-full flex-nowrap"): - self.disable_safety_checker_switch = ( - ui.switch( - "Disable Safety Checker", - value=get_setting( - "default", "disable_safety_checker", fallback="False" - ).lower() - == "true", - ) - .classes("w-1/2") - .tooltip("Disable safety checker for generated images.") - .bind_value(self, "disable_safety_checker") - .props("filled color=blue-8") - ) - self.reset_button = ui.button( - "Reset Parameters", on_click=self.reset_to_default, color="#cf222e" - ).classes("w-1/2 text-white font-bold py-2 px-4 rounded") - - def setup_right_panel(self): - logger.debug("Setting up right panel") - with ui.row().classes("w-full flex-nowrap"): - ui.label("Output").classes("text-center ml-4 mt-3 w-full").style( - "font-size: 230%; font-weight: bold; text-align: left;" - ) - ui.button( - "Download Images", on_click=self.download_zip, color="#0969da" - ).classes("modern-button text-white font-bold py-2 px-4 rounded") - ui.separator() - with ui.row().classes("w-full flex-nowrap"): - self.gallery_container = ui.column().classes( - "w-full mt-4 grid grid-cols-2 gap-4" - ) - self.lightbox = Lightbox() - - def setup_prompt_panel(self): - logger.debug("Setting up prompt panel") - with ui.row().classes("w-full flex-row flex-nowrap"): - self.prompt_input = ( - ui.textarea("Prompt", value=self.prompt) - .classes("w-full text-gray-200 shadow-lg") - .bind_value(self, "prompt") - .props("clearable filled autofocus") - ) - self.generate_button = ( - ui.button(icon="bolt", on_click=self.start_generation, color="#0969da") - .classes("ml-2 font-bold rounded modern-button h-full") - .props("size=1.5rem") - .style("animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;") - ) - self.progress = ( - ui.linear_progress(show_value=False, size="20px") - .classes("w-full") - .props("indeterminate") - ) - self.progress.visible = False - - async def open_settings_popup(self): - logger.debug("Opening settings popup") - with ui.dialog() as dialog, ui.card().classes( - "w-2/3 modern-card dark:bg-[#25292e] bg-[#818b981f]" - ): - ui.label("Settings").classes("text-2xl font-bold") - api_key_input = ui.input( - label="API Key", - placeholder="Enter Replicate API Key...", - password=True, - value=self.api_key, - ).classes("w-full mb-4") - - async def save_settings(): - logger.debug("Saving settings") - new_api_key = api_key_input.value - if new_api_key != self.api_key: - self.api_key = new_api_key - set_setting("secrets", "REPLICATE_API_KEY", new_api_key) - await self.save_settings() - os.environ["REPLICATE_API_KEY"] = new_api_key - self.image_generator.set_api_key(new_api_key) - logger.info("API key saved") - - dialog.close() - ui.notify("Settings saved successfully", type="positive") - - if not DOCKERIZED: - self.folder_input = ui.input( - label="Output Folder", value=self.output_folder - ).classes("w-full mb-4") - self.folder_input.on("change", self.update_folder_path) - ui.button( - "Save Settings", on_click=save_settings, color="#818b981f" - ).classes("mt-4") - dialog.open() - - async def save_api_key(self): - logger.debug("Saving API key") - set_setting("secrets", "REPLICATE_API_KEY", self.api_key) - save_settings() - os.environ["REPLICATE_API_KEY"] = self.api_key - self.image_generator.set_api_key(self.api_key) - - @ui.refreshable - def model_list(self): - logger.debug("Refreshing model list") - for model in self.user_added_models: - with ui.row().classes("w-full justify-between items-center"): - ui.label(model) - ui.button( - icon="delete", - on_click=lambda m=model: self.confirm_delete_model(m), - color="#818b981f", - ).props("flat round color=red") - - async def open_user_model_popup(self): - logger.debug("Opening user model popup") - - async def add_model(): - await self.add_user_model(new_model_input.value) - - with ui.dialog() as dialog, ui.card(): - ui.label("Manage Replicate Models").classes("text-xl font-bold mb-4") - new_model_input = ui.input(label="Add New Model").classes("w-full mb-4") - ui.button("Add Model", on_click=add_model, color="#818b981f") - - ui.label("Current Models:").classes("mt-4 mb-2") - self.model_list() - - ui.button("Close", on_click=dialog.close, color="#818b981f").classes("mt-4") - dialog.open() - - async def add_user_model(self, new_model): - logger.debug(f"Adding user model: {new_model}") - if new_model and new_model not in self.user_added_models: - try: - latest_v = await asyncio.to_thread( - self.image_generator.get_model_version, new_model - ) - self.user_added_models[new_model] = latest_v - self.model_options = list(self.user_added_models.values()) - self.replicate_model_select.options = self.model_options - self.replicate_model_select.value = latest_v - await self.update_replicate_model(latest_v) - models_json = json.dumps( - {"user_added": list(self.user_added_models.values())} - ) - set_setting("default", "models", models_json) - save_settings() - ui.notify(f"Model '{latest_v}' added successfully", type="positive") - self.model_list.refresh() - logger.info(f"User model added: {latest_v}") - except Exception as e: - logger.error(f"Error adding model: {str(e)}") - ui.notify(f"Error adding model: {str(e)}", type="negative") - else: - logger.warning(f"Invalid model name or model already exists: {new_model}") - ui.notify("Invalid model name or model already exists", type="negative") - - async def confirm_delete_model(self, model): - logger.debug(f"Confirming deletion of model: {model}") - with ui.dialog() as confirm_dialog, ui.card(): - ui.label(f"Are you sure you want to delete the model '{model}'?").classes( - "mb-4" - ) - with ui.row(): - ui.button( - "Yes", - on_click=lambda: self.delete_user_model(model, confirm_dialog), - color="1f883d", - ).classes("mr-2") - ui.button("No", on_click=confirm_dialog.close, color="cf222e") - confirm_dialog.open() - - async def delete_user_model(self, model, confirm_dialog): - logger.debug(f"Deleting user model: {model}") - if model in self.user_added_models: - del self.user_added_models[model] - self.model_options = list(self.user_added_models.keys()) - self.replicate_model_select.options = self.model_options - if self.replicate_model_select.value == model: - self.replicate_model_select.value = None - await self.update_replicate_model(None) - models_json = json.dumps( - {"user_added": list(self.user_added_models.keys())} - ) - set_setting("default", "models", models_json) - save_settings() - ui.notify(f"Model '{model}' deleted successfully", type="positive") - confirm_dialog.close() - self.model_list.refresh() - logger.info(f"User model deleted: {model}") - else: - logger.warning(f"Cannot delete model, not found: {model}") - ui.notify("Cannot delete this model", type="negative") - - async def update_replicate_model(self, new_model): - logger.debug(f"Updating Replicate model to: {new_model}") - if new_model: - await asyncio.to_thread(self.image_generator.set_model, new_model) - self.replicate_model = new_model - await self.save_settings() - logger.info(f"Replicate model updated to: {new_model}") - self.generate_button.enable() - else: - logger.warning("No Replicate model selected") - self.generate_button.disable() - - async def update_folder_path(self, e): - logger.debug("Updating folder path") - if hasattr(e, "value"): - new_path = e.value - elif hasattr(e, "sender") and hasattr(e.sender, "value"): - new_path = e.sender.value - elif hasattr(e, "args") and e.args: - new_path = e.args[0] - else: - new_path = None - - if new_path is None: - logger.error("Failed to extract new path from event object") - ui.notify("Error updating folder path", type="negative") - return - - if os.path.isdir(new_path): - self.output_folder = new_path - set_setting("default", "output_folder", new_path) - save_settings() - logger.info(f"Output folder set to: {self.output_folder}") - ui.notify( - f"Output folder updated to: {self.output_folder}", type="positive" - ) - else: - logger.warning(f"Invalid folder path: {new_path}") - ui.notify( - "Invalid folder path. Please enter a valid directory.", type="negative" - ) - if hasattr(self, "folder_input"): - self.folder_input.value = self.output_folder - - async def toggle_custom_dimensions(self, e): - logger.debug(f"Toggling custom dimensions: {e.value}") - if e.value == "custom": - self.width_input.enable() - self.height_input.enable() - else: - self.width_input.disable() - self.height_input.disable() - await self.save_settings() - logger.info(f"Custom dimensions toggled: {e.value}") - - def check_api_key(self): - logger.debug("Checking API key") - if not self.api_key: - logger.warning("No Replicate API Key found.") - ui.notify( - "No Replicate API Key found. Please set it in the settings before generating images.", - type="warning", - close_button="OK", - timeout=10000, # 10 seconds - position="top", - ) - - async def reset_to_default(self): - logger.debug("Resetting parameters to default values") - for attr in self._attributes: - if attr not in ["models", "replicate_model"]: - value = get_setting("default", attr) - if value is not None: - if attr in [ - "num_outputs", - "num_inference_steps", - "width", - "height", - "seed", - "output_quality", - ]: - value = int(value) - elif attr in ["lora_scale", "guidance_scale"]: - value = float(value) - elif attr == "disable_safety_checker": - value = value.lower() == "true" - - setattr(self, attr, value) - if hasattr(self, f"{attr}_input"): - getattr(self, f"{attr}_input").value = value - elif hasattr(self, f"{attr}_select"): - getattr(self, f"{attr}_select").value = value - elif hasattr(self, f"{attr}_switch"): - getattr(self, f"{attr}_switch").value = value - - await self.save_settings() - ui.notify("Parameters reset to default values", type="info") - logger.info("Parameters reset to default values") - - async def start_generation(self): - logger.debug("Starting image generation") - if not self.api_key: - ui.notify( - "Please set your Replicate API Key in the settings.", type="negative" - ) - logger.error("Cannot start generation: No API key set.") - return - if not self.replicate_model_select.value: - ui.notify( - "Please select a Replicate model before generating images.", - type="negative", - ) - logger.warning( - "Attempted to generate images without selecting a Replicate model" - ) - return - - await asyncio.to_thread( - self.image_generator.set_model, self.replicate_model_select.value - ) - - await self.save_settings() - params = { - "prompt": self.prompt_input.value, - "flux_model": self.flux_model, - "aspect_ratio": self.aspect_ratio, - "num_outputs": self.num_outputs, - "lora_scale": self.lora_scale, - "num_inference_steps": self.num_inference_steps, - "guidance_scale": self.guidance_scale, - "output_format": self.output_format, - "output_quality": self.output_quality, - "disable_safety_checker": self.disable_safety_checker, - } - - if self.aspect_ratio == "custom": - params["width"] = self.width - params["height"] = self.height - - if self.seed != -1: - params["seed"] = self.seed - - self.generate_button.disable() - self.progress.visible = True - ui.notify("Generating images...", type="info") - logger.info(f"Generating images with params: {json.dumps(params, indent=2)}") - - try: - output = await asyncio.to_thread( - self.image_generator.generate_images, params - ) - await self.download_and_display_images(output) - logger.success(f"Images generated successfully: {output}") - except Exception as e: - error_message = f"An error occurred: {str(e)}" - ui.notify(error_message, type="negative") - logger.exception(error_message) - finally: - self.generate_button.enable() - self.progress.visible = False - - def create_zip_file(self): - logger.debug("Creating zip file of generated images") - if not self.last_generated_images: - ui.notify("No images to download", type="warning") - logger.warning("No images to zip") - return None - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - zip_filename = f"generated_images_{timestamp}.zip" - zip_path = Path(self.output_folder) / zip_filename - - with zipfile.ZipFile(zip_path, "w") as zipf: - for image_path in self.last_generated_images: - zipf.write(image_path, Path(image_path).name) - logger.info(f"Zip file created: {zip_path}") - return str(zip_path) - - def download_zip(self): - logger.debug("Downloading zip file") - zip_path = self.create_zip_file() - if zip_path: - ui.download(zip_path) - ui.notify("Downloading zip file of generated images", type="positive") - - async def update_gallery(self, image_paths): - logger.debug("Updating image gallery") - self.gallery_container.clear() - self.last_generated_images = image_paths - with self.gallery_container: - with ui.row().classes("w-full"): - with ui.grid(columns=2).classes("md:grid-cols-3 w-full gap-2"): - for image_path in image_paths: - self.lightbox.add_image( - image_path, image_path, "w-full h-full object-cover" - ) - logger.debug("Image gallery updated") - - async def download_and_display_images(self, image_urls): - logger.debug("Downloading and displaying generated images") - downloaded_images = [] - async with httpx.AsyncClient() as client: - for i, url in enumerate(image_urls): - logger.debug(f"Downloading image from {url}") - response = await client.get(url) - if response.status_code == 200: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - url_part = urllib.parse.urlparse(url).path.split("/")[-2][:8] - file_name = f"generated_image_{timestamp}_{url_part}_{i+1}.png" - file_path = Path(self.output_folder) / file_name - with open(file_path, "wb") as f: - f.write(response.content) - downloaded_images.append(str(file_path)) - logger.info(f"Image downloaded: {file_path}") - else: - logger.error(f"Failed to download image from {url}") - - await self.update_gallery(downloaded_images) - ui.notify("Images generated and downloaded successfully!", type="positive") - logger.success("Images downloaded and displayed") - - async def save_settings(self): - logger.debug("Saving settings") - for attr in self._attributes: - value = getattr(self, attr) - if attr == "models": - value = json.dumps({"user_added": list(self.user_added_models.keys())}) - set_setting("default", attr, str(value)) - - set_setting("default", "replicate_model", self.replicate_model_select.value) - - save_settings() - logger.info("Settings saved successfully") - - -async def create_gui(image_generator): - logger.debug("Creating GUI") - gui = ImageGeneratorGUI(image_generator) - gui.setup_ui() - logger.debug("GUI created") - return gui diff --git a/src/gui/__init__.py b/src/gui/__init__.py new file mode 100644 index 0000000..b051bde --- /dev/null +++ b/src/gui/__init__.py @@ -0,0 +1,6 @@ +from .lightbox import Lightbox +from .imagegenerator import ImageGeneratorGUI +from .styles import Styles +from .usermodels import UserModels +from .filehandler import FileHandler + diff --git a/src/gui/filehandler.py b/src/gui/filehandler.py new file mode 100644 index 0000000..7fef529 --- /dev/null +++ b/src/gui/filehandler.py @@ -0,0 +1,33 @@ +from nicegui import ui +from datetime import datetime +import zipfile +from loguru import logger +from pathlib import Path + + +class FileHandler: + @staticmethod + def create_zip_file(last_generated_images, output_folder): + logger.debug("Creating zip file of generated images") + if not last_generated_images: + ui.notify("No images to download", type="warning") + logger.warning("No images to zip") + return None + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + zip_filename = f"generated_images_{timestamp}.zip" + zip_path = Path(output_folder) / zip_filename + + with zipfile.ZipFile(zip_path, "w") as zipf: + for image_path in last_generated_images: + zipf.write(image_path, Path(image_path).name) + logger.info(f"Zip file created: {zip_path}") + return str(zip_path) + + @staticmethod + def download_zip(last_generated_images, output_folder): + logger.debug("Downloading zip file") + zip_path = FileHandler.create_zip_file(last_generated_images, output_folder) + if zip_path: + ui.download(zip_path) + ui.notify("Downloading zip file of generated images", type="positive") diff --git a/src/gui/imagegenerator.py b/src/gui/imagegenerator.py new file mode 100644 index 0000000..ad48136 --- /dev/null +++ b/src/gui/imagegenerator.py @@ -0,0 +1,384 @@ +import asyncio +import json +import os +import urllib.parse +from datetime import datetime +from pathlib import Path +from gui.lightbox import Lightbox +from gui.styles import Styles +from gui.panels import GUIPanels +from gui.filehandler import FileHandler +import httpx +from loguru import logger +from nicegui import ui +from util import Settings + +DOCKERIZED = os.environ.get("DOCKER_CONTAINER", False) + + +class ImageGeneratorGUI: + def __init__(self, image_generator): + logger.info("Initializing ImageGeneratorGUI") + self.image_generator = image_generator + self.api_key = Settings.get_api_key() or os.environ.get("REPLICATE_API_KEY", "") + self.last_generated_images = [] + self.custom_styles = Styles.setup_custom_styles() + self._attributes = [ + "prompt", + "flux_model", + "aspect_ratio", + "num_outputs", + "lora_scale", + "num_inference_steps", + "guidance_scale", + "output_format", + "output_quality", + "disable_safety_checker", + "width", + "height", + "seed", + "output_folder", + "replicate_model", + ] + + self.user_added_models = {} + self.prompt = Settings.get_setting("default", "prompt", "", str) + + self.flux_model = Settings.get_setting("default", "flux_model", "dev", str) + self.aspect_ratio = Settings.get_setting("default", "aspect_ratio", "1:1", str) + self.num_outputs = Settings.get_setting("default", "num_outputs", "1", int) + self.lora_scale = Settings.get_setting("default", "lora_scale", "1", float) + self.num_inference_steps = Settings.get_setting( + "default", "num_inference_steps", "28", int + ) + self.guidance_scale = Settings.get_setting( + "default", "guidance_scale", "3.5", float + ) + self.output_format = Settings.get_setting("default", "output_format", "png") + self.output_quality = Settings.get_setting( + "default", "output_quality", "80", int + ) + self.disable_safety_checker = Settings.get_setting( + "default", "disable_safety_checker", True, bool + ) + + self.width = Settings.get_setting("default", "width", "1024", int) + self.height = Settings.get_setting("default", "height", "1024", int) + self.seed = Settings.get_setting("default", "seed", "-1", int) + + self.output_folder = ( + "/app/output" + if DOCKERIZED + else Settings.get_setting("default", "output_folder", "/Downloads", str) + ) + models_json = Settings.get_setting( + "default", "models", '{"user_added": []}', str + ) + models = json.loads(models_json) + self.user_added_models = { + model: model for model in models.get("user_added", []) + } + self.model_options = list(self.user_added_models.keys()) + self.replicate_model = Settings.get_setting( + "default", "replicate_model", "", str + ) + + logger.info("ImageGeneratorGUI initialized") + + def setup_ui(self): + logger.info("Setting up UI") + ui.dark_mode(True) + self.check_api_key() + + with ui.grid().classes( + "w-full h-screen md:h-full grid-cols-1 md:grid-cols-2 gap-2 md:gap-5 p-4 md:p-6 dark:bg-[#11111b] bg-#eff1f5] md:auto-rows-min" + ): + with ui.card().classes("col-span-full modern-card flex-nowrap h-min"): + GUIPanels.setup_top_panel(self) + + with ui.card().classes("col-span-full modern-card"): + GUIPanels.setup_prompt_panel(self) + + with ui.card().classes("row-span-2 overflow-auto modern-card"): + GUIPanels.setup_left_panel(self) + + with ui.card().classes("row-span-2 overflow-auto modern-card"): + GUIPanels.setup_right_panel(self) + Styles.stylefilter(self) + logger.info("UI setup completed") + + async def open_settings_popup(self): + logger.debug("Opening settings popup") + with ui.dialog() as dialog, ui.card().classes( + "w-2/3 modern-card dark:bg-[#25292e] bg-[#818b981f]" + ): + ui.label("Settings").classes("text-2xl font-bold") + api_key_input = ui.input( + label="API Key", + placeholder="Enter Replicate API Key...", + password=True, + value=self.api_key, + ).classes("w-full mb-4") + + async def save_settings(): + logger.debug("Saving settings") + new_api_key = api_key_input.value + if new_api_key != self.api_key: + self.api_key = new_api_key + Settings.set_setting("secrets", "REPLICATE_API_KEY", new_api_key) + await self.save_settings() + os.environ["REPLICATE_API_KEY"] = new_api_key + self.image_generator.set_api_key(new_api_key) + logger.info("API key saved") + + dialog.close() + ui.notify("Settings saved successfully", type="positive") + + if not DOCKERIZED: + self.folder_input = ui.input( + label="Output Folder", value=self.output_folder + ).classes("w-full mb-4") + self.folder_input.on("change", self.update_folder_path) + ui.button("Save Settings", on_click=save_settings, color="blue-4").classes( + "mt-4" + ) + dialog.open() + + async def save_api_key(self): + logger.debug("Saving API key") + Settings.set_setting("secrets", "REPLICATE_API_KEY", self.api_key) + Settings.save_settings() + os.environ["REPLICATE_API_KEY"] = self.api_key + self.image_generator.set_api_key(self.api_key) + + async def update_folder_path(self, e): + logger.debug("Updating folder path") + if hasattr(e, "value"): + new_path = e.value + elif hasattr(e, "sender") and hasattr(e.sender, "value"): + new_path = e.sender.value + elif hasattr(e, "args") and e.args: + new_path = e.args[0] + else: + new_path = None + + if new_path is None: + logger.error("Failed to extract new path from event object") + ui.notify("Error updating folder path", type="negative") + return + + if os.path.isdir(new_path): + self.output_folder = new_path + Settings.set_setting("default", "output_folder", new_path) + Settings.save_settings() + logger.info(f"Output folder set to: {self.output_folder}") + ui.notify( + f"Output folder updated to: {self.output_folder}", type="positive" + ) + else: + logger.warning(f"Invalid folder path: {new_path}") + ui.notify( + "Invalid folder path. Please enter a valid directory.", type="negative" + ) + if hasattr(self, "folder_input"): + self.folder_input.value = self.output_folder + + async def toggle_custom_dimensions(self, e): + logger.debug(f"Toggling custom dimensions: {e.value}") + if e.value == "custom": + self.width_input.enable() + self.height_input.enable() + else: + self.width_input.disable() + self.height_input.disable() + await self.save_settings() + logger.info(f"Custom dimensions toggled: {e.value}") + + def check_api_key(self): + logger.debug("Checking API key") + if not self.api_key: + logger.warning("No Replicate API Key found.") + ui.notify( + "No Replicate API Key found. Please set it in the settings before generating images.", + type="warning", + close_button="OK", + timeout=10000, # 10 seconds + position="top", + ) + + async def reset_to_default(self): + logger.debug("Resetting parameters to default values") + for attr in self._attributes: + if attr not in ["models", "replicate_model"]: + value = Settings.get_setting("default", attr) + if value is not None: + if attr in [ + "num_outputs", + "num_inference_steps", + "width", + "height", + "seed", + "output_quality", + ]: + value = int(value) + elif attr in ["lora_scale", "guidance_scale"]: + value = float(value) + elif attr == "disable_safety_checker": + value = value.lower() == "true" + + setattr(self, attr, value) + if hasattr(self, f"{attr}_input"): + getattr(self, f"{attr}_input").value = value + elif hasattr(self, f"{attr}_select"): + getattr(self, f"{attr}_select").value = value + elif hasattr(self, f"{attr}_switch"): + getattr(self, f"{attr}_switch").value = value + + await self.save_settings() + ui.notify("Parameters reset to default values", type="info") + logger.info("Parameters reset to default values") + + async def start_generation(self): + logger.debug("Starting image generation") + if not self.api_key: + ui.notify( + "Please set your Replicate API Key in the settings.", type="negative" + ) + logger.error("Cannot start generation: No API key set.") + return + if not self.replicate_model_select.value: + ui.notify( + "Please select a Replicate model before generating images.", + type="negative", + ) + logger.warning( + "Attempted to generate images without selecting a Replicate model" + ) + return + + await asyncio.to_thread( + self.image_generator.set_model, self.replicate_model_select.value + ) + + await self.save_settings() + params = { + "prompt": self.prompt_input.value, + "flux_model": self.flux_model, + "aspect_ratio": self.aspect_ratio, + "num_outputs": self.num_outputs, + "lora_scale": self.lora_scale, + "num_inference_steps": self.num_inference_steps, + "guidance_scale": self.guidance_scale, + "output_format": self.output_format, + "output_quality": self.output_quality, + "disable_safety_checker": self.disable_safety_checker, + } + + if self.aspect_ratio == "custom": + params["width"] = self.width + params["height"] = self.height + + if self.seed != -1: + params["seed"] = self.seed + + self.generate_button.disable() + self.progress.visible = True + ui.notify("Generating images...", type="info") + logger.info(f"Generating images with params: {json.dumps(params, indent=2)}") + + try: + output = await asyncio.to_thread( + self.image_generator.generate_images, params + ) + await self.download_and_display_images(output) + logger.success(f"Images generated successfully: {output}") + except Exception as e: + error_message = f"An error occurred: {str(e)}" + ui.notify(error_message, type="negative") + logger.exception(error_message) + finally: + self.generate_button.enable() + self.progress.visible = False + + # def create_zip_file(self): + # logger.debug("Creating zip file of generated images") + # if not self.last_generated_images: + # ui.notify("No images to download", type="warning") + # logger.warning("No images to zip") + # return None + # + # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + # zip_filename = f"generated_images_{timestamp}.zip" + # zip_path = Path(self.output_folder) / zip_filename + # + # with zipfile.ZipFile(zip_path, "w") as zipf: + # for image_path in self.last_generated_images: + # zipf.write(image_path, Path(image_path).name) + # logger.info(f"Zip file created: {zip_path}") + # return str(zip_path) + # + # def download_zip(self): + # logger.debug("Downloading zip file") + # zip_path = self.create_zip_file() + # if zip_path: + # ui.download(zip_path) + # ui.notify("Downloading zip file of generated images", type="positive") + # + async def update_gallery(self, image_paths): + logger.debug("Updating image gallery") + self.gallery_container.clear() + self.last_generated_images = image_paths + with self.gallery_container: + with ui.row().classes("w-full"): + with ui.grid(columns=2).classes("md:grid-cols-3 w-full gap-2"): + for image_path in image_paths: + self.lightbox.add_image( + image_path, image_path, "w-full h-full object-cover" + ) + logger.debug("Image gallery updated") + + async def download_and_display_images(self, image_urls): + logger.debug("Downloading and displaying generated images") + downloaded_images = [] + async with httpx.AsyncClient() as client: + for i, url in enumerate(image_urls): + logger.debug(f"Downloading image from {url}") + response = await client.get(url) + if response.status_code == 200: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + url_part = urllib.parse.urlparse(url).path.split("/")[-2][:8] + file_name = f"generated_image_{timestamp}_{url_part}_{i+1}.png" + file_path = Path(self.output_folder) / file_name + with open(file_path, "wb") as f: + f.write(response.content) + downloaded_images.append(str(file_path)) + logger.info(f"Image downloaded: {file_path}") + else: + logger.error(f"Failed to download image from {url}") + + await self.update_gallery(downloaded_images) + ui.notify("Images generated and downloaded successfully!", type="positive") + logger.success("Images downloaded and displayed") + + async def save_settings(self): + logger.debug("Saving settings") + for attr in self._attributes: + value = getattr(self, attr) + if attr == "models": + value = json.dumps({"user_added": list(self.user_added_models.keys())}) + Settings.set_setting("default", attr, str(value)) + + Settings.set_setting( + "default", "replicate_model", self.replicate_model_select.value + ) + + Settings.save_settings() + logger.info("Settings saved successfully") + + +async def create_gui(image_generator): + logger.debug("Creating GUI") + gui = ImageGeneratorGUI(image_generator) + gui.setup_ui() + logger.debug("GUI created") + return gui diff --git a/src/gui/lightbox.py b/src/gui/lightbox.py new file mode 100644 index 0000000..012ec00 --- /dev/null +++ b/src/gui/lightbox.py @@ -0,0 +1,48 @@ +from nicegui import ui +from loguru import logger + + +class Lightbox: + def __init__(self): + logger.debug("Initializing Lightbox") + with ui.dialog().props("maximized").classes("bg-black") as self.dialog: + self.dialog.on_key = self._handle_key + self.large_image = ui.image().props("no-spinner fit=scale-down") + self.image_list = [] + logger.debug("Lightbox initialized") + + def add_image( + self, + thumb_url: str, + orig_url: str, + thumb_classes: str = "w-32 h-32 object-cover", + ) -> ui.button: + logger.debug(f"Adding image to Lightbox: {orig_url}") + self.image_list.append(orig_url) + button = ui.button(on_click=lambda: self._open(orig_url)).props( + "flat dense square" + ) + with button: + ui.image(thumb_url).classes(thumb_classes) + logger.debug("Image added to Lightbox") + return button + + def _handle_key(self, e) -> None: + logger.debug(f"Handling key press in Lightbox: {e.key}") + if not e.action.keydown: + return + if e.key.escape: + logger.debug("Closing Lightbox dialog") + self.dialog.close() + image_index = self.image_list.index(self.large_image.source) + if e.key.arrow_left and image_index > 0: + logger.debug("Displaying previous image") + self._open(self.image_list[image_index - 1]) + if e.key.arrow_right and image_index < len(self.image_list) - 1: + logger.debug("Displaying next image") + self._open(self.image_list[image_index + 1]) + + def _open(self, url: str) -> None: + logger.debug(f"Opening image in Lightbox: {url}") + self.large_image.set_source(url) + self.dialog.open() diff --git a/src/gui/panels.py b/src/gui/panels.py new file mode 100644 index 0000000..402bdb1 --- /dev/null +++ b/src/gui/panels.py @@ -0,0 +1,281 @@ +import asyncio +from nicegui import ui +from loguru import logger +from util.settings import Settings +from gui.lightbox import Lightbox +from gui.usermodels import UserModels +from gui.filehandler import FileHandler +from gui.styles import Styles + + +class GUIPanels: + def setup_top_panel(self): + logger.debug("Setting up top panel") + with ui.row().classes("w-full items-center"): + ui.label("Lumberjack - Replicate API Interface").classes( + "text-2xl/loose font-bold" + ) + # dark_mode = ui.dark_mode(True) + # ui.switch().bind_value(dark_mode).classes().props( + # "dense checked-icon=dark_mode unchecked-icon=light_mode color=blue-4" + # ) + ui.button( + icon="settings_suggest", + on_click=self.open_settings_popup, + color="blue-4", + ).classes("absolute-right mr-6 mt-3 mb-3") + + def setup_left_panel(self): + logger.debug("Setting up left panel") + with ui.row().classes("w-full flex-row flex-nowrap"): + self.replicate_model_select = ( + ui.select( + options=self.model_options, + label="Replicate Model", + value=self.replicate_model, + on_change=lambda e: asyncio.create_task( + self.update_replicate_model(e.value) + ), + ) + .classes("width-5/6 overflow-auto custom-select bg-[#1e1e2e]") + .tooltip("Select or manage Replicate models") + .props("filled bg-color=dark") + ) + ui.button(icon="token", color="blue-4").classes("ml-2 mt-1.2").on( + "click", UserModels.open_user_model_popup(self) + ).props("size=1.3rem") + + self.flux_model_select = ( + ui.select( + ["dev", "schnell"], + label="Flux Model", + value=Settings.get_setting("default", "flux_model", "dev"), + ) + .classes("w-full text-gray-200") + .tooltip( + "Which model to run inferences with. The dev model needs around 28 steps but the schnell model only needs around 4 steps." + ) + .bind_value(self, "flux_model") + .props("filled bg-color=dark") + ) + + with ui.row().classes("w-full flex-nowrap md:flex-wrap"): + self.aspect_ratio_select = ( + ui.select( + [ + "1:1", + "16:9", + "21:9", + "3:2", + "2:3", + "4:5", + "5:4", + "3:4", + "4:3", + "9:16", + "9:21", + "custom", + ], + label="Aspect Ratio", + value=Settings.get_setting("default", "aspect_ratio", "1:1"), + ) + .classes("w-1/2 md:w-full text-gray-200") + .bind_value(self, "aspect_ratio") + .tooltip( + "Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)" + ) + .props("filled bg-color=dark") + ) + self.aspect_ratio_select.on("change", self.toggle_custom_dimensions) + + with ui.column().classes("w-full").bind_visibility_from( + self.aspect_ratio_select, "value", value="custom" + ): + self.width_input = ( + ui.number( + "Width", + value=Settings.get_setting("default", "width", 1024, int), + min=256, + max=1440, + ) + .classes("w-full") + .bind_value(self, "width") + .tooltip( + "Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)" + ) + .props("filled bg-color=dark") + ) + self.height_input = ( + ui.number( + "Height", + value=Settings.get_setting("default", "height", 1024, int), + min=256, + max=1440, + ) + .classes("w-full") + .bind_value(self, "height") + .tooltip( + "Height of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)" + ) + .props("filled bg-color=dark") + ) + + self.num_outputs_input = ( + ui.number( + "Num Outputs", + value=Settings.get_setting("default", "num_outputs", 1, int), + min=1, + max=4, + ) + .classes("w-1/2 md:w-full") + .bind_value(self, "num_outputs") + .tooltip("Number of images to output.") + .props("filled bg-color=dark") + ) + + with ui.row().classes("w-full flex-nowrap md:flex-wrap"): + self.lora_scale_input = ( + ui.number( + "LoRA Scale", + value=float(Settings.get_setting("default", "lora_scale", 1)), + min=-1, + max=2, + step=0.1, + ) + .classes("w-1/2 md:w-full") + .tooltip( + "Determines how strongly the LoRA should be applied. Sane results between 0 and 1." + ) + .props("filled bg-color=dark color=primary") + .bind_value(self, "lora_scale") + ) + self.num_inference_steps_input = ( + ui.number( + "Num Inference Steps", + value=Settings.get_setting( + "default", "num_inference_steps", 28, int + ), + min=1, + max=50, + precision=0, + ) + .classes("w-1/2 md:w-full") + .tooltip("Number of Inference Steps") + .bind_value(self, "num_inference_steps") + .props("filled bg-color=dark") + ) + + with ui.row().classes("w-full flex-nowrap md:flex-wrap"): + self.guidance_scale_input = ( + ui.number( + "Guidance Scale", + value=float(Settings.get_setting("default", "guidance_scale", 3.5)), + min=0, + max=10, + step=0.1, + precision=2, + ) + .classes("w-1/2 md:w-full") + .tooltip("Guidance Scale for the diffusion process") + .bind_value(self, "guidance_scale") + .props("filled bg-color=dark") + ) + self.seed_input = ( + ui.number( + "Seed", + value=Settings.get_setting("default", "seed", -1, int), + min=-2147483648, + max=2147483647, + ) + .classes("w-1/2 md:w-full") + .bind_value(self, "seed") + .props("filled bg-color=dark") + ) + + with ui.row().classes("w-full flex-nowrap"): + self.output_format_select = ( + ui.select( + ["webp", "jpg", "png"], + label="Output Format", + value=Settings.get_setting("default", "output_format", "webp"), + ) + .classes("w-full") + .tooltip("Format of the output images") + .bind_value(self, "output_format") + .props("filled bg-color=dark") + ) + + self.output_quality_input = ( + ui.number( + "Output Quality", + value=Settings.get_setting("default", "output_quality", 80, int), + min=0, + max=100, + ) + .classes("w-full") + .tooltip( + "Quality when saving the output images, from 0 to 100. 100 is best quality, 0 is lowest quality. Not relevant for .png outputs" + ) + .bind_value(self, "output_quality") + .props("filled bg-color=dark") + ) + + with ui.row().classes("w-full flex-nowrap"): + self.disable_safety_checker_switch = ( + ui.switch( + "Disable Safety Checker", + value=Settings.get_setting( + "default", "disable_safety_checker", fallback="False" + ).lower() + == "true", + ) + .classes("w-1/2") + .tooltip("Disable safety checker for generated images.") + .bind_value(self, "disable_safety_checker") + .props("filled bg-color=dark color=blue-4") + ) + self.reset_button = ui.button( + "Reset Parameters", on_click=self.reset_to_default, color="#e78284" + ).classes("w-1/2 text-white font-bold py-2 px-4 rounded") + + def setup_right_panel(self): + logger.debug("Setting up right panel") + with ui.row().classes("w-full flex-nowrap"): + ui.label("Output").classes("text-center ml-4 mt-3 w-full").style( + "font-size: 230%; font-weight: bold; text-align: left;" + ) + ui.button( + "Download Images", + on_click=lambda: FileHandler.download_zip( + self.last_generated_images, self.output_folder + ), + color="blue-4", + ).classes("modern-button text-white font-bold py-2 px-4 rounded") + ui.separator() + with ui.row().classes("w-full flex-nowrap"): + self.gallery_container = ui.column().classes( + "w-full mt-4 grid grid-cols-2 gap-4" + ) + self.lightbox = Lightbox() + + def setup_prompt_panel(self): + logger.debug("Setting up prompt panel") + with ui.row().classes("w-full flex-row flex-nowrap"): + self.prompt_input = ( + ui.textarea("Prompt", value=self.prompt) + .classes("w-full shadow-lg") + .bind_value(self, "prompt") + .props("clearable filled bg-color=dark autofocus color=blue-4") + ) + self.generate_button = ( + ui.button(icon="bolt", on_click=self.start_generation, color="blue-4") + .classes("ml-2 font-bold rounded modern-button h-full") + .props("size=1.5rem") + .style("animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;") + ) + self.progress = ( + ui.linear_progress(show_value=False, size="20px") + .classes("w-full") + .props("indeterminate") + ) + self.progress.visible = False diff --git a/src/gui/styles.py b/src/gui/styles.py new file mode 100644 index 0000000..6b196c0 --- /dev/null +++ b/src/gui/styles.py @@ -0,0 +1,60 @@ +from nicegui import ui, ElementFilter +from loguru import logger + + +class Styles: + def setup_custom_styles(): + logger.debug("Setting up custom styles") + ui.add_head_html(""" + + + """) + + # ui.add_css(""" + # .bg { + # color: red; + # } + # """) + ui.colors( + dark="#303446", # Crust (Closest: Grey-9) + primary="#8aadf4", # Blue (Closest: Blue-4) + positive="#a6d189", # Green (Closest: Green-4) + negative="#e78284", # Maroon (Closest: Red-3) + secondary="#f38ba8", # Red (Closest: Red-4) + accent="#f5bde6", # Pink (Closest: Pink-3) + info="#89dceb", # Sky (Closest: Cyan-4) + warning="#f0a988", # Peach (Closest: Orange-4) + ) + + def stylefilter(self): + ElementFilter(kind=ui.label).classes("mt-0") + + ElementFilter(kind=ui.card).classes("dark:bg-[#181825] bg-[#ccd0da]") diff --git a/src/gui/usermodels.py b/src/gui/usermodels.py new file mode 100644 index 0000000..45fc531 --- /dev/null +++ b/src/gui/usermodels.py @@ -0,0 +1,112 @@ +from nicegui import ui +from loguru import logger +import asyncio +import json +from util import Settings + + +class UserModels: + @ui.refreshable + def model_list(self): + logger.debug("Refreshing model list") + for model in self.user_added_models: + with ui.row().classes("w-full justify-between items-center"): + ui.label(model) + ui.button( + icon="delete", + on_click=lambda m=model: self.confirm_delete_model(m), + color="#818b981f", + ).props("flat round color=red") + + async def open_user_model_popup(self): + logger.debug("Opening user model popup") + + async def add_model(): + await self.add_user_model(new_model_input.value) + + with ui.dialog() as dialog, ui.card(): + ui.label("Manage Replicate Models").classes("text-xl font-bold mb-4") + new_model_input = ui.input(label="Add New Model").classes("w-full mb-4") + ui.button("Add Model", on_click=add_model, color="#818b981f") + + ui.label("Current Models:").classes("mt-4 mb-2") + self.model_list() + + ui.button("Close", on_click=dialog.close, color="#818b981f").classes("mt-4") + dialog.open() + + async def add_user_model(self, new_model): + logger.debug(f"Adding user model: {new_model}") + if new_model and new_model not in self.user_added_models: + try: + latest_v = await asyncio.to_thread( + self.image_generator.get_model_version, new_model + ) + self.user_added_models[new_model] = latest_v + self.model_options = list(self.user_added_models.values()) + self.replicate_model_select.options = self.model_options + self.replicate_model_select.value = latest_v + await self.update_replicate_model(latest_v) + models_json = json.dumps( + {"user_added": list(self.user_added_models.values())} + ) + Settings.set_setting("default", "models", models_json) + save_settings() + ui.notify(f"Model '{latest_v}' added successfully", type="positive") + self.model_list.refresh() + logger.info(f"User model added: {latest_v}") + except Exception as e: + logger.error(f"Error adding model: {str(e)}") + ui.notify(f"Error adding model: {str(e)}", type="negative") + else: + logger.warning(f"Invalid model name or model already exists: {new_model}") + ui.notify("Invalid model name or model already exists", type="negative") + + async def confirm_delete_model(self, model): + logger.debug(f"Confirming deletion of model: {model}") + with ui.dialog() as confirm_dialog, ui.card(): + ui.label(f"Are you sure you want to delete the model '{model}'?").classes( + "mb-4" + ) + with ui.row(): + ui.button( + "Yes", + on_click=lambda: self.delete_user_model(model, confirm_dialog), + color="1f883d", + ).classes("mr-2") + ui.button("No", on_click=confirm_dialog.close, color="cf222e") + confirm_dialog.open() + + async def delete_user_model(self, model, confirm_dialog): + logger.debug(f"Deleting user model: {model}") + if model in self.user_added_models: + del self.user_added_models[model] + self.model_options = list(self.user_added_models.keys()) + self.replicate_model_select.options = self.model_options + if self.replicate_model_select.value == model: + self.replicate_model_select.value = None + await self.update_replicate_model(None) + models_json = json.dumps( + {"user_added": list(self.user_added_models.keys())} + ) + Settings.set_setting("default", "models", models_json) + save_settings() + ui.notify(f"Model '{model}' deleted successfully", type="positive") + confirm_dialog.close() + self.model_list.refresh() + logger.info(f"User model deleted: {model}") + else: + logger.warning(f"Cannot delete model, not found: {model}") + ui.notify("Cannot delete this model", type="negative") + + async def update_replicate_model(self, new_model): + logger.debug(f"Updating Replicate model to: {new_model}") + if new_model: + await asyncio.to_thread(self.image_generator.set_model, new_model) + self.replicate_model = new_model + await self.save_settings() + logger.info(f"Replicate model updated to: {new_model}") + self.generate_button.enable() + else: + logger.warning("No Replicate model selected") + self.generate_button.disable() diff --git a/src/gui/utilities.py b/src/gui/utilities.py new file mode 100644 index 0000000..bbdb8f5 --- /dev/null +++ b/src/gui/utilities.py @@ -0,0 +1,5 @@ +from nicegui import ui +from loguru import logger + +class Utilities: + diff --git a/src/main.py b/src/main.py index 949692a..68d1bf6 100644 --- a/src/main.py +++ b/src/main.py @@ -1,10 +1,9 @@ import sys -from config import get_api_key -from gui import create_gui from loguru import logger from nicegui import ui -from replicate_api import ImageGenerator +import util +from gui import ImageGeneratorGUI logger.add( sys.stderr, format="{time} {level} {message}", filter="my_module", level="INFO" @@ -12,17 +11,16 @@ logger.add( "app.log", format="{time} {level} {module}:{line} {message}", - level="DEBUG", + level="INFO", rotation="500 MB", compression="zip", ) logger.info("Initializing ImageGenerator") -generator = ImageGenerator() +generator = util.Replicate_API() - -api_key = get_api_key() +api_key = util.Settings.get_api_key() if api_key: generator.set_api_key(api_key) else: @@ -33,8 +31,10 @@ @ui.page("/") async def main_page(): - await create_gui(generator) - logger.info("NiceGUI server is running") + logger.debug("Creating GUI") + gui = ImageGeneratorGUI(generator) + gui.setup_ui() + logger.debug("GUI created") logger.info("Starting NiceGUI server") diff --git a/src/util/__init__.py b/src/util/__init__.py new file mode 100644 index 0000000..5f59ee9 --- /dev/null +++ b/src/util/__init__.py @@ -0,0 +1,3 @@ +# from .settings import Settings +from .replicate_api import Replicate_API +from .settings import Settings diff --git a/src/replicate_api.py b/src/util/replicate_api.py similarity index 99% rename from src/replicate_api.py rename to src/util/replicate_api.py index cafd3e0..4529ef9 100644 --- a/src/replicate_api.py +++ b/src/util/replicate_api.py @@ -8,7 +8,7 @@ load_dotenv() -class ImageGenerator: +class Replicate_API: def __init__(self): self.replicate_model = None self.api_key = None diff --git a/src/util/settings.py b/src/util/settings.py new file mode 100644 index 0000000..eec88a1 --- /dev/null +++ b/src/util/settings.py @@ -0,0 +1,80 @@ +import configparser +import os +from typing import Any, Type + +from loguru import logger + +DOCKERIZED = os.environ.get("DOCKER_CONTAINER", "False").lower() == "true" +CONFIG_DIR = "/app/settings" if DOCKERIZED else "." +DEFAULT_CONFIG_FILE = os.path.join(CONFIG_DIR, "settings.ini") +USER_CONFIG_FILE = os.path.join(CONFIG_DIR, "settings.user.ini") + +logger.info( + f"Configuration files: DEFAULT={DEFAULT_CONFIG_FILE}, USER={USER_CONFIG_FILE}" +) + +config = configparser.ConfigParser() +config.read([DEFAULT_CONFIG_FILE, USER_CONFIG_FILE]) +logger.info("Configuration files loaded") + + +class Settings: + def get_api_key(): + api_key = os.environ.get("REPLICATE_API_KEY") or config.get( + "secrets", "REPLICATE_API_KEY", fallback=None + ) + if api_key: + logger.info("API key retrieved successfully") + else: + logger.warning("No API key found") + return api_key + + def get_setting( + section: str, key: str, fallback: Any = None, value_type: Type[Any] = str + ) -> Any: + logger.info( + f"Attempting to get setting: section={section}, key={key}, fallback={fallback}, value_type={value_type}" + ) + try: + value = config.get(section, key) + logger.debug(f"Raw value retrieved: {value}") + if value_type is int: + result = int(value) + elif value_type is float: + result = float(value) + elif value_type is bool: + result = value.lower() in ("true", "yes", "1", "on") + else: + result = value + logger.info(f"Setting retrieved successfully: {result}") + return result + except (configparser.NoSectionError, configparser.NoOptionError) as e: + logger.warning( + f"Setting not found: {str(e)}. Using fallback value: {fallback}" + ) + return fallback + except ValueError as e: + logger.error( + f"Error converting setting value: {str(e)}. Using fallback value: {fallback}" + ) + return fallback + + def set_setting(section, key, value): + logger.info(f"Setting value: section={section}, key={key}, value={value}") + if not config.has_section(section): + logger.info(f"Creating new section: {section}") + config.add_section(section) + config.set(section, key, str(value)) + logger.info("Value set successfully") + + def save_settings(): + logger.info(f"Saving settings to {USER_CONFIG_FILE}") + try: + with open(USER_CONFIG_FILE, "w") as configfile: + config.write(configfile) + logger.info("Settings saved successfully") + except IOError as e: + logger.error(f"Error saving settings: {str(e)}") + + +logger.info("Config module initialized")