diff --git a/.gitignore b/.gitignore
index 15a5c92..987619a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -157,3 +157,4 @@ cython_debug/
*code-workspace
settings.user.ini
+.aider*
diff --git a/requirements.txt b/requirements.txt
index b93730e..826cb8e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
-replicate==0.33.0
+replicate==0.34.1
python-dotenv==1.0.1
token_count==0.2.1
loguru==0.7.2
-nicegui==2.1.0
+nicegui==2.3.0
httpx==0.27.2
dynaconf==3.2.6
toml==0.10.2
diff --git a/src/config.py b/src/config.py
deleted file mode 100644
index ea20b8f..0000000
--- a/src/config.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import configparser
-import os
-from typing import Any, Type
-
-from loguru import logger
-
-DOCKERIZED = os.environ.get("DOCKER_CONTAINER", "False").lower() == "true"
-CONFIG_DIR = "/app/settings" if DOCKERIZED else "."
-DEFAULT_CONFIG_FILE = os.path.join(CONFIG_DIR, "settings.ini")
-USER_CONFIG_FILE = os.path.join(CONFIG_DIR, "settings.user.ini")
-
-logger.info(
- f"Configuration files: DEFAULT={DEFAULT_CONFIG_FILE}, USER={USER_CONFIG_FILE}"
-)
-
-config = configparser.ConfigParser()
-config.read([DEFAULT_CONFIG_FILE, USER_CONFIG_FILE])
-logger.info("Configuration files loaded")
-
-
-def get_api_key():
- api_key = os.environ.get("REPLICATE_API_KEY") or config.get(
- "secrets", "REPLICATE_API_KEY", fallback=None
- )
- if api_key:
- logger.info("API key retrieved successfully")
- else:
- logger.warning("No API key found")
- return api_key
-
-
-def get_setting(
- section: str, key: str, fallback: Any = None, value_type: Type[Any] = str
-) -> Any:
- logger.info(
- f"Attempting to get setting: section={section}, key={key}, fallback={fallback}, value_type={value_type}"
- )
- try:
- value = config.get(section, key)
- logger.debug(f"Raw value retrieved: {value}")
- if value_type == int:
- result = int(value)
- elif value_type == float:
- result = float(value)
- elif value_type == bool:
- result = value.lower() in ("true", "yes", "1", "on")
- else:
- result = value
- logger.info(f"Setting retrieved successfully: {result}")
- return result
- except (configparser.NoSectionError, configparser.NoOptionError) as e:
- logger.warning(f"Setting not found: {str(e)}. Using fallback value: {fallback}")
- return fallback
- except ValueError as e:
- logger.error(
- f"Error converting setting value: {str(e)}. Using fallback value: {fallback}"
- )
- return fallback
-
-
-def set_setting(section, key, value):
- logger.info(f"Setting value: section={section}, key={key}, value={value}")
- if not config.has_section(section):
- logger.info(f"Creating new section: {section}")
- config.add_section(section)
- config.set(section, key, str(value))
- logger.info("Value set successfully")
-
-
-def save_settings():
- logger.info(f"Saving settings to {USER_CONFIG_FILE}")
- try:
- with open(USER_CONFIG_FILE, "w") as configfile:
- config.write(configfile)
- logger.info("Settings saved successfully")
- except IOError as e:
- logger.error(f"Error saving settings: {str(e)}")
-
-
-logger.info("Config module initialized")
diff --git a/src/gui.py b/src/gui.py
deleted file mode 100644
index 5e84df3..0000000
--- a/src/gui.py
+++ /dev/null
@@ -1,828 +0,0 @@
-import asyncio
-import json
-import os
-import urllib.parse
-import zipfile
-from datetime import datetime
-from pathlib import Path
-
-import httpx
-from config import get_api_key, get_setting, save_settings, set_setting
-from loguru import logger
-from nicegui import ui
-
-DOCKERIZED = os.environ.get("DOCKER_CONTAINER", False)
-
-
-class Lightbox:
- def __init__(self):
- logger.debug("Initializing Lightbox")
- with ui.dialog().props("maximized").classes("bg-black") as self.dialog:
- self.dialog.on_key = self._handle_key
- self.large_image = ui.image().props("no-spinner fit=scale-down")
- self.image_list = []
- logger.debug("Lightbox initialized")
-
- def add_image(
- self,
- thumb_url: str,
- orig_url: str,
- thumb_classes: str = "w-32 h-32 object-cover",
- ) -> ui.button:
- logger.debug(f"Adding image to Lightbox: {orig_url}")
- self.image_list.append(orig_url)
- button = ui.button(on_click=lambda: self._open(orig_url)).props(
- "flat dense square"
- )
- with button:
- ui.image(thumb_url).classes(thumb_classes)
- logger.debug("Image added to Lightbox")
- return button
-
- def _handle_key(self, e) -> None:
- logger.debug(f"Handling key press in Lightbox: {e.key}")
- if not e.action.keydown:
- return
- if e.key.escape:
- logger.debug("Closing Lightbox dialog")
- self.dialog.close()
- image_index = self.image_list.index(self.large_image.source)
- if e.key.arrow_left and image_index > 0:
- logger.debug("Displaying previous image")
- self._open(self.image_list[image_index - 1])
- if e.key.arrow_right and image_index < len(self.image_list) - 1:
- logger.debug("Displaying next image")
- self._open(self.image_list[image_index + 1])
-
- def _open(self, url: str) -> None:
- logger.debug(f"Opening image in Lightbox: {url}")
- self.large_image.set_source(url)
- self.dialog.open()
-
-
-class ImageGeneratorGUI:
- def __init__(self, image_generator):
- logger.info("Initializing ImageGeneratorGUI")
- self.image_generator = image_generator
- self.api_key = get_api_key() or os.environ.get("REPLICATE_API_KEY", "")
- self.last_generated_images = []
- self.setup_custom_styles()
- self._attributes = [
- "prompt",
- "flux_model",
- "aspect_ratio",
- "num_outputs",
- "lora_scale",
- "num_inference_steps",
- "guidance_scale",
- "output_format",
- "output_quality",
- "disable_safety_checker",
- "width",
- "height",
- "seed",
- "output_folder",
- "replicate_model",
- ]
-
- self.user_added_models = {}
- self.prompt = get_setting("default", "prompt", "", str)
-
- self.flux_model = get_setting("default", "flux_model", "dev", str)
- self.aspect_ratio = get_setting("default", "aspect_ratio", "1:1", str)
- self.num_outputs = get_setting("default", "num_outputs", "1", int)
- self.lora_scale = get_setting("default", "lora_scale", "1", float)
- self.num_inference_steps = get_setting(
- "default", "num_inference_steps", "28", int
- )
- self.guidance_scale = get_setting("default", "guidance_scale", "3.5", float)
- self.output_format = get_setting("default", "output_format", "png")
- self.output_quality = get_setting("default", "output_quality", "80", int)
- self.disable_safety_checker = get_setting(
- "default", "disable_safety_checker", True, bool
- )
-
- self.width = get_setting("default", "width", "1024", int)
- self.height = get_setting("default", "height", "1024", int)
- self.seed = get_setting("default", "seed", "-1", int)
-
- self.output_folder = (
- "/app/output"
- if DOCKERIZED
- else get_setting("default", "output_folder", "/Downloads", str)
- )
- models_json = get_setting("default", "models", '{"user_added": []}', str)
- models = json.loads(models_json)
- self.user_added_models = {
- model: model for model in models.get("user_added", [])
- }
- self.model_options = list(self.user_added_models.keys())
- self.replicate_model = get_setting("default", "replicate_model", "", str)
-
- logger.info("ImageGeneratorGUI initialized")
-
- def setup_custom_styles(self):
- logger.debug("Setting up custom styles")
- ui.add_head_html("""
-
-
- """)
-
- def setup_ui(self):
- logger.info("Setting up UI")
- ui.dark_mode()
- self.check_api_key()
-
- with ui.grid().classes(
- "w-full h-screen md:h-full grid-cols-1 md:grid-cols-2 gap-2 md:gap-5 p-4 md:p-6 dark:bg-[#1f2328] bg-[#ffffff] md:auto-rows-min"
- ):
- with ui.card().classes(
- "col-span-full modern-card dark:bg-[#25292e] bg-[#818b981f] flex-nowrap h-min"
- ):
- self.setup_top_panel()
-
- with ui.card().classes(
- "col-span-full modern-card dark:bg-[#25292e] bg-[#818b981f]"
- ):
- self.setup_prompt_panel()
-
- with ui.card().classes(
- "row-span-2 overflow-auto modern-card dark:bg-[#25292e] bg-[#818b981f]"
- ):
- self.setup_left_panel()
-
- with ui.card().classes(
- "row-span-2 overflow-auto modern-card dark:bg-[#25292e] bg-[#818b981f]"
- ):
- self.setup_right_panel()
-
- logger.info("UI setup completed")
-
- def setup_top_panel(self):
- logger.debug("Setting up top panel")
- with ui.row().classes("w-full items-center"):
- ui.label("Lumberjack - Replicate API Interface").classes(
- "text-2xl/loose font-bold"
- )
- dark_mode = ui.dark_mode(True)
- ui.switch().bind_value(dark_mode).classes().props(
- "dense checked-icon=dark_mode unchecked-icon=light_mode color=blue-7"
- )
- ui.button(
- icon="settings_suggest",
- on_click=self.open_settings_popup,
- color="#0969da",
- ).classes("absolute-right mr-6 mt-3 mb-3")
-
- def setup_left_panel(self):
- logger.debug("Setting up left panel")
- with ui.row().classes("w-full flex-row flex-nowrap"):
- self.replicate_model_select = (
- ui.select(
- options=self.model_options,
- label="Replicate Model",
- value=self.replicate_model,
- on_change=lambda e: asyncio.create_task(
- self.update_replicate_model(e.value)
- ),
- )
- .classes("width-5/6 overflow-auto custom-select")
- .tooltip("Select or manage Replicate models")
- .props("filled")
- )
- ui.button(icon="token", color="#0969da").classes("ml-2 mt-1.2").on(
- "click", self.open_user_model_popup
- ).props("size=1.3rem")
-
- self.flux_model_select = (
- ui.select(
- ["dev", "schnell"],
- label="Flux Model",
- value=get_setting("default", "flux_model", "dev"),
- )
- .classes("w-full text-gray-200")
- .tooltip(
- "Which model to run inferences with. The dev model needs around 28 steps but the schnell model only needs around 4 steps."
- )
- .bind_value(self, "flux_model")
- .props("filled")
- )
-
- with ui.row().classes("w-full flex-nowrap md:flex-wrap"):
- self.aspect_ratio_select = (
- ui.select(
- [
- "1:1",
- "16:9",
- "21:9",
- "3:2",
- "2:3",
- "4:5",
- "5:4",
- "3:4",
- "4:3",
- "9:16",
- "9:21",
- "custom",
- ],
- label="Aspect Ratio",
- value=get_setting("default", "aspect_ratio", "1:1"),
- )
- .classes("w-1/2 md:w-full text-gray-200")
- .bind_value(self, "aspect_ratio")
- .tooltip(
- "Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)"
- )
- .props("filled")
- )
- self.aspect_ratio_select.on("change", self.toggle_custom_dimensions)
-
- with ui.column().classes("w-full").bind_visibility_from(
- self.aspect_ratio_select, "value", value="custom"
- ):
- self.width_input = (
- ui.number(
- "Width",
- value=get_setting("default", "width", 1024, int),
- min=256,
- max=1440,
- )
- .classes("w-full")
- .bind_value(self, "width")
- .tooltip(
- "Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)"
- )
- )
- self.height_input = (
- ui.number(
- "Height",
- value=get_setting("default", "height", 1024, int),
- min=256,
- max=1440,
- )
- .classes("w-full")
- .bind_value(self, "height")
- .tooltip(
- "Height of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)"
- )
- )
-
- self.num_outputs_input = (
- ui.number(
- "Num Outputs",
- value=get_setting("default", "num_outputs", 1, int),
- min=1,
- max=4,
- )
- .classes("w-1/2 md:w-full")
- .bind_value(self, "num_outputs")
- .tooltip("Number of images to output.")
- .props("filled")
- )
-
- with ui.row().classes("w-full flex-nowrap md:flex-wrap"):
- self.lora_scale_input = (
- ui.number(
- "LoRA Scale",
- value=float(get_setting("default", "lora_scale", 1)),
- min=-1,
- max=2,
- step=0.1,
- )
- .classes("w-1/2 md:w-full")
- .tooltip(
- "Determines how strongly the LoRA should be applied. Sane results between 0 and 1."
- )
- .props("filled")
- .bind_value(self, "lora_scale")
- )
- self.num_inference_steps_input = (
- ui.number(
- "Num Inference Steps",
- value=get_setting("default", "num_inference_steps", 28, int),
- min=1,
- max=50,
- )
- .classes("w-1/2 md:w-full")
- .tooltip("Number of Inference Steps")
- .bind_value(self, "num_inference_steps")
- .props("filled")
- )
-
- with ui.row().classes("w-full flex-nowrap md:flex-wrap"):
- self.guidance_scale_input = (
- ui.number(
- "Guidance Scale",
- value=float(get_setting("default", "guidance_scale", 3.5)),
- min=0,
- max=10,
- step=0.1,
- precision=2,
- )
- .classes("w-1/2 md:w-full")
- .tooltip("Guidance Scale for the diffusion process")
- .bind_value(self, "guidance_scale")
- .props("filled")
- )
- self.seed_input = (
- ui.number(
- "Seed",
- value=get_setting("default", "seed", -1, int),
- min=-2147483648,
- max=2147483647,
- )
- .classes("w-1/2 md:w-full")
- .bind_value(self, "seed")
- .props("filled")
- )
-
- with ui.row().classes("w-full flex-nowrap"):
- self.output_format_select = (
- ui.select(
- ["webp", "jpg", "png"],
- label="Output Format",
- value=get_setting("default", "output_format", "webp"),
- )
- .classes("w-full")
- .tooltip("Format of the output images")
- .bind_value(self, "output_format")
- .props("filled")
- )
-
- self.output_quality_input = (
- ui.number(
- "Output Quality",
- value=get_setting("default", "output_quality", 80, int),
- min=0,
- max=100,
- )
- .classes("w-full")
- .tooltip(
- "Quality when saving the output images, from 0 to 100. 100 is best quality, 0 is lowest quality. Not relevant for .png outputs"
- )
- .bind_value(self, "output_quality")
- .props("filled")
- )
-
- with ui.row().classes("w-full flex-nowrap"):
- self.disable_safety_checker_switch = (
- ui.switch(
- "Disable Safety Checker",
- value=get_setting(
- "default", "disable_safety_checker", fallback="False"
- ).lower()
- == "true",
- )
- .classes("w-1/2")
- .tooltip("Disable safety checker for generated images.")
- .bind_value(self, "disable_safety_checker")
- .props("filled color=blue-8")
- )
- self.reset_button = ui.button(
- "Reset Parameters", on_click=self.reset_to_default, color="#cf222e"
- ).classes("w-1/2 text-white font-bold py-2 px-4 rounded")
-
- def setup_right_panel(self):
- logger.debug("Setting up right panel")
- with ui.row().classes("w-full flex-nowrap"):
- ui.label("Output").classes("text-center ml-4 mt-3 w-full").style(
- "font-size: 230%; font-weight: bold; text-align: left;"
- )
- ui.button(
- "Download Images", on_click=self.download_zip, color="#0969da"
- ).classes("modern-button text-white font-bold py-2 px-4 rounded")
- ui.separator()
- with ui.row().classes("w-full flex-nowrap"):
- self.gallery_container = ui.column().classes(
- "w-full mt-4 grid grid-cols-2 gap-4"
- )
- self.lightbox = Lightbox()
-
- def setup_prompt_panel(self):
- logger.debug("Setting up prompt panel")
- with ui.row().classes("w-full flex-row flex-nowrap"):
- self.prompt_input = (
- ui.textarea("Prompt", value=self.prompt)
- .classes("w-full text-gray-200 shadow-lg")
- .bind_value(self, "prompt")
- .props("clearable filled autofocus")
- )
- self.generate_button = (
- ui.button(icon="bolt", on_click=self.start_generation, color="#0969da")
- .classes("ml-2 font-bold rounded modern-button h-full")
- .props("size=1.5rem")
- .style("animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;")
- )
- self.progress = (
- ui.linear_progress(show_value=False, size="20px")
- .classes("w-full")
- .props("indeterminate")
- )
- self.progress.visible = False
-
- async def open_settings_popup(self):
- logger.debug("Opening settings popup")
- with ui.dialog() as dialog, ui.card().classes(
- "w-2/3 modern-card dark:bg-[#25292e] bg-[#818b981f]"
- ):
- ui.label("Settings").classes("text-2xl font-bold")
- api_key_input = ui.input(
- label="API Key",
- placeholder="Enter Replicate API Key...",
- password=True,
- value=self.api_key,
- ).classes("w-full mb-4")
-
- async def save_settings():
- logger.debug("Saving settings")
- new_api_key = api_key_input.value
- if new_api_key != self.api_key:
- self.api_key = new_api_key
- set_setting("secrets", "REPLICATE_API_KEY", new_api_key)
- await self.save_settings()
- os.environ["REPLICATE_API_KEY"] = new_api_key
- self.image_generator.set_api_key(new_api_key)
- logger.info("API key saved")
-
- dialog.close()
- ui.notify("Settings saved successfully", type="positive")
-
- if not DOCKERIZED:
- self.folder_input = ui.input(
- label="Output Folder", value=self.output_folder
- ).classes("w-full mb-4")
- self.folder_input.on("change", self.update_folder_path)
- ui.button(
- "Save Settings", on_click=save_settings, color="#818b981f"
- ).classes("mt-4")
- dialog.open()
-
- async def save_api_key(self):
- logger.debug("Saving API key")
- set_setting("secrets", "REPLICATE_API_KEY", self.api_key)
- save_settings()
- os.environ["REPLICATE_API_KEY"] = self.api_key
- self.image_generator.set_api_key(self.api_key)
-
- @ui.refreshable
- def model_list(self):
- logger.debug("Refreshing model list")
- for model in self.user_added_models:
- with ui.row().classes("w-full justify-between items-center"):
- ui.label(model)
- ui.button(
- icon="delete",
- on_click=lambda m=model: self.confirm_delete_model(m),
- color="#818b981f",
- ).props("flat round color=red")
-
- async def open_user_model_popup(self):
- logger.debug("Opening user model popup")
-
- async def add_model():
- await self.add_user_model(new_model_input.value)
-
- with ui.dialog() as dialog, ui.card():
- ui.label("Manage Replicate Models").classes("text-xl font-bold mb-4")
- new_model_input = ui.input(label="Add New Model").classes("w-full mb-4")
- ui.button("Add Model", on_click=add_model, color="#818b981f")
-
- ui.label("Current Models:").classes("mt-4 mb-2")
- self.model_list()
-
- ui.button("Close", on_click=dialog.close, color="#818b981f").classes("mt-4")
- dialog.open()
-
- async def add_user_model(self, new_model):
- logger.debug(f"Adding user model: {new_model}")
- if new_model and new_model not in self.user_added_models:
- try:
- latest_v = await asyncio.to_thread(
- self.image_generator.get_model_version, new_model
- )
- self.user_added_models[new_model] = latest_v
- self.model_options = list(self.user_added_models.values())
- self.replicate_model_select.options = self.model_options
- self.replicate_model_select.value = latest_v
- await self.update_replicate_model(latest_v)
- models_json = json.dumps(
- {"user_added": list(self.user_added_models.values())}
- )
- set_setting("default", "models", models_json)
- save_settings()
- ui.notify(f"Model '{latest_v}' added successfully", type="positive")
- self.model_list.refresh()
- logger.info(f"User model added: {latest_v}")
- except Exception as e:
- logger.error(f"Error adding model: {str(e)}")
- ui.notify(f"Error adding model: {str(e)}", type="negative")
- else:
- logger.warning(f"Invalid model name or model already exists: {new_model}")
- ui.notify("Invalid model name or model already exists", type="negative")
-
- async def confirm_delete_model(self, model):
- logger.debug(f"Confirming deletion of model: {model}")
- with ui.dialog() as confirm_dialog, ui.card():
- ui.label(f"Are you sure you want to delete the model '{model}'?").classes(
- "mb-4"
- )
- with ui.row():
- ui.button(
- "Yes",
- on_click=lambda: self.delete_user_model(model, confirm_dialog),
- color="1f883d",
- ).classes("mr-2")
- ui.button("No", on_click=confirm_dialog.close, color="cf222e")
- confirm_dialog.open()
-
- async def delete_user_model(self, model, confirm_dialog):
- logger.debug(f"Deleting user model: {model}")
- if model in self.user_added_models:
- del self.user_added_models[model]
- self.model_options = list(self.user_added_models.keys())
- self.replicate_model_select.options = self.model_options
- if self.replicate_model_select.value == model:
- self.replicate_model_select.value = None
- await self.update_replicate_model(None)
- models_json = json.dumps(
- {"user_added": list(self.user_added_models.keys())}
- )
- set_setting("default", "models", models_json)
- save_settings()
- ui.notify(f"Model '{model}' deleted successfully", type="positive")
- confirm_dialog.close()
- self.model_list.refresh()
- logger.info(f"User model deleted: {model}")
- else:
- logger.warning(f"Cannot delete model, not found: {model}")
- ui.notify("Cannot delete this model", type="negative")
-
- async def update_replicate_model(self, new_model):
- logger.debug(f"Updating Replicate model to: {new_model}")
- if new_model:
- await asyncio.to_thread(self.image_generator.set_model, new_model)
- self.replicate_model = new_model
- await self.save_settings()
- logger.info(f"Replicate model updated to: {new_model}")
- self.generate_button.enable()
- else:
- logger.warning("No Replicate model selected")
- self.generate_button.disable()
-
- async def update_folder_path(self, e):
- logger.debug("Updating folder path")
- if hasattr(e, "value"):
- new_path = e.value
- elif hasattr(e, "sender") and hasattr(e.sender, "value"):
- new_path = e.sender.value
- elif hasattr(e, "args") and e.args:
- new_path = e.args[0]
- else:
- new_path = None
-
- if new_path is None:
- logger.error("Failed to extract new path from event object")
- ui.notify("Error updating folder path", type="negative")
- return
-
- if os.path.isdir(new_path):
- self.output_folder = new_path
- set_setting("default", "output_folder", new_path)
- save_settings()
- logger.info(f"Output folder set to: {self.output_folder}")
- ui.notify(
- f"Output folder updated to: {self.output_folder}", type="positive"
- )
- else:
- logger.warning(f"Invalid folder path: {new_path}")
- ui.notify(
- "Invalid folder path. Please enter a valid directory.", type="negative"
- )
- if hasattr(self, "folder_input"):
- self.folder_input.value = self.output_folder
-
- async def toggle_custom_dimensions(self, e):
- logger.debug(f"Toggling custom dimensions: {e.value}")
- if e.value == "custom":
- self.width_input.enable()
- self.height_input.enable()
- else:
- self.width_input.disable()
- self.height_input.disable()
- await self.save_settings()
- logger.info(f"Custom dimensions toggled: {e.value}")
-
- def check_api_key(self):
- logger.debug("Checking API key")
- if not self.api_key:
- logger.warning("No Replicate API Key found.")
- ui.notify(
- "No Replicate API Key found. Please set it in the settings before generating images.",
- type="warning",
- close_button="OK",
- timeout=10000, # 10 seconds
- position="top",
- )
-
- async def reset_to_default(self):
- logger.debug("Resetting parameters to default values")
- for attr in self._attributes:
- if attr not in ["models", "replicate_model"]:
- value = get_setting("default", attr)
- if value is not None:
- if attr in [
- "num_outputs",
- "num_inference_steps",
- "width",
- "height",
- "seed",
- "output_quality",
- ]:
- value = int(value)
- elif attr in ["lora_scale", "guidance_scale"]:
- value = float(value)
- elif attr == "disable_safety_checker":
- value = value.lower() == "true"
-
- setattr(self, attr, value)
- if hasattr(self, f"{attr}_input"):
- getattr(self, f"{attr}_input").value = value
- elif hasattr(self, f"{attr}_select"):
- getattr(self, f"{attr}_select").value = value
- elif hasattr(self, f"{attr}_switch"):
- getattr(self, f"{attr}_switch").value = value
-
- await self.save_settings()
- ui.notify("Parameters reset to default values", type="info")
- logger.info("Parameters reset to default values")
-
- async def start_generation(self):
- logger.debug("Starting image generation")
- if not self.api_key:
- ui.notify(
- "Please set your Replicate API Key in the settings.", type="negative"
- )
- logger.error("Cannot start generation: No API key set.")
- return
- if not self.replicate_model_select.value:
- ui.notify(
- "Please select a Replicate model before generating images.",
- type="negative",
- )
- logger.warning(
- "Attempted to generate images without selecting a Replicate model"
- )
- return
-
- await asyncio.to_thread(
- self.image_generator.set_model, self.replicate_model_select.value
- )
-
- await self.save_settings()
- params = {
- "prompt": self.prompt_input.value,
- "flux_model": self.flux_model,
- "aspect_ratio": self.aspect_ratio,
- "num_outputs": self.num_outputs,
- "lora_scale": self.lora_scale,
- "num_inference_steps": self.num_inference_steps,
- "guidance_scale": self.guidance_scale,
- "output_format": self.output_format,
- "output_quality": self.output_quality,
- "disable_safety_checker": self.disable_safety_checker,
- }
-
- if self.aspect_ratio == "custom":
- params["width"] = self.width
- params["height"] = self.height
-
- if self.seed != -1:
- params["seed"] = self.seed
-
- self.generate_button.disable()
- self.progress.visible = True
- ui.notify("Generating images...", type="info")
- logger.info(f"Generating images with params: {json.dumps(params, indent=2)}")
-
- try:
- output = await asyncio.to_thread(
- self.image_generator.generate_images, params
- )
- await self.download_and_display_images(output)
- logger.success(f"Images generated successfully: {output}")
- except Exception as e:
- error_message = f"An error occurred: {str(e)}"
- ui.notify(error_message, type="negative")
- logger.exception(error_message)
- finally:
- self.generate_button.enable()
- self.progress.visible = False
-
- def create_zip_file(self):
- logger.debug("Creating zip file of generated images")
- if not self.last_generated_images:
- ui.notify("No images to download", type="warning")
- logger.warning("No images to zip")
- return None
-
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- zip_filename = f"generated_images_{timestamp}.zip"
- zip_path = Path(self.output_folder) / zip_filename
-
- with zipfile.ZipFile(zip_path, "w") as zipf:
- for image_path in self.last_generated_images:
- zipf.write(image_path, Path(image_path).name)
- logger.info(f"Zip file created: {zip_path}")
- return str(zip_path)
-
- def download_zip(self):
- logger.debug("Downloading zip file")
- zip_path = self.create_zip_file()
- if zip_path:
- ui.download(zip_path)
- ui.notify("Downloading zip file of generated images", type="positive")
-
- async def update_gallery(self, image_paths):
- logger.debug("Updating image gallery")
- self.gallery_container.clear()
- self.last_generated_images = image_paths
- with self.gallery_container:
- with ui.row().classes("w-full"):
- with ui.grid(columns=2).classes("md:grid-cols-3 w-full gap-2"):
- for image_path in image_paths:
- self.lightbox.add_image(
- image_path, image_path, "w-full h-full object-cover"
- )
- logger.debug("Image gallery updated")
-
- async def download_and_display_images(self, image_urls):
- logger.debug("Downloading and displaying generated images")
- downloaded_images = []
- async with httpx.AsyncClient() as client:
- for i, url in enumerate(image_urls):
- logger.debug(f"Downloading image from {url}")
- response = await client.get(url)
- if response.status_code == 200:
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- url_part = urllib.parse.urlparse(url).path.split("/")[-2][:8]
- file_name = f"generated_image_{timestamp}_{url_part}_{i+1}.png"
- file_path = Path(self.output_folder) / file_name
- with open(file_path, "wb") as f:
- f.write(response.content)
- downloaded_images.append(str(file_path))
- logger.info(f"Image downloaded: {file_path}")
- else:
- logger.error(f"Failed to download image from {url}")
-
- await self.update_gallery(downloaded_images)
- ui.notify("Images generated and downloaded successfully!", type="positive")
- logger.success("Images downloaded and displayed")
-
- async def save_settings(self):
- logger.debug("Saving settings")
- for attr in self._attributes:
- value = getattr(self, attr)
- if attr == "models":
- value = json.dumps({"user_added": list(self.user_added_models.keys())})
- set_setting("default", attr, str(value))
-
- set_setting("default", "replicate_model", self.replicate_model_select.value)
-
- save_settings()
- logger.info("Settings saved successfully")
-
-
-async def create_gui(image_generator):
- logger.debug("Creating GUI")
- gui = ImageGeneratorGUI(image_generator)
- gui.setup_ui()
- logger.debug("GUI created")
- return gui
diff --git a/src/gui/__init__.py b/src/gui/__init__.py
new file mode 100644
index 0000000..b051bde
--- /dev/null
+++ b/src/gui/__init__.py
@@ -0,0 +1,6 @@
+from .lightbox import Lightbox
+from .imagegenerator import ImageGeneratorGUI
+from .styles import Styles
+from .usermodels import UserModels
+from .filehandler import FileHandler
+
diff --git a/src/gui/filehandler.py b/src/gui/filehandler.py
new file mode 100644
index 0000000..7fef529
--- /dev/null
+++ b/src/gui/filehandler.py
@@ -0,0 +1,33 @@
+from nicegui import ui
+from datetime import datetime
+import zipfile
+from loguru import logger
+from pathlib import Path
+
+
+class FileHandler:
+ @staticmethod
+ def create_zip_file(last_generated_images, output_folder):
+ logger.debug("Creating zip file of generated images")
+ if not last_generated_images:
+ ui.notify("No images to download", type="warning")
+ logger.warning("No images to zip")
+ return None
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ zip_filename = f"generated_images_{timestamp}.zip"
+ zip_path = Path(output_folder) / zip_filename
+
+ with zipfile.ZipFile(zip_path, "w") as zipf:
+ for image_path in last_generated_images:
+ zipf.write(image_path, Path(image_path).name)
+ logger.info(f"Zip file created: {zip_path}")
+ return str(zip_path)
+
+ @staticmethod
+ def download_zip(last_generated_images, output_folder):
+ logger.debug("Downloading zip file")
+ zip_path = FileHandler.create_zip_file(last_generated_images, output_folder)
+ if zip_path:
+ ui.download(zip_path)
+ ui.notify("Downloading zip file of generated images", type="positive")
diff --git a/src/gui/imagegenerator.py b/src/gui/imagegenerator.py
new file mode 100644
index 0000000..ad48136
--- /dev/null
+++ b/src/gui/imagegenerator.py
@@ -0,0 +1,384 @@
+import asyncio
+import json
+import os
+import urllib.parse
+from datetime import datetime
+from pathlib import Path
+from gui.lightbox import Lightbox
+from gui.styles import Styles
+from gui.panels import GUIPanels
+from gui.filehandler import FileHandler
+import httpx
+from loguru import logger
+from nicegui import ui
+from util import Settings
+
+DOCKERIZED = os.environ.get("DOCKER_CONTAINER", False)
+
+
+class ImageGeneratorGUI:
+ def __init__(self, image_generator):
+ logger.info("Initializing ImageGeneratorGUI")
+ self.image_generator = image_generator
+ self.api_key = Settings.get_api_key() or os.environ.get("REPLICATE_API_KEY", "")
+ self.last_generated_images = []
+ self.custom_styles = Styles.setup_custom_styles()
+ self._attributes = [
+ "prompt",
+ "flux_model",
+ "aspect_ratio",
+ "num_outputs",
+ "lora_scale",
+ "num_inference_steps",
+ "guidance_scale",
+ "output_format",
+ "output_quality",
+ "disable_safety_checker",
+ "width",
+ "height",
+ "seed",
+ "output_folder",
+ "replicate_model",
+ ]
+
+ self.user_added_models = {}
+ self.prompt = Settings.get_setting("default", "prompt", "", str)
+
+ self.flux_model = Settings.get_setting("default", "flux_model", "dev", str)
+ self.aspect_ratio = Settings.get_setting("default", "aspect_ratio", "1:1", str)
+ self.num_outputs = Settings.get_setting("default", "num_outputs", "1", int)
+ self.lora_scale = Settings.get_setting("default", "lora_scale", "1", float)
+ self.num_inference_steps = Settings.get_setting(
+ "default", "num_inference_steps", "28", int
+ )
+ self.guidance_scale = Settings.get_setting(
+ "default", "guidance_scale", "3.5", float
+ )
+ self.output_format = Settings.get_setting("default", "output_format", "png")
+ self.output_quality = Settings.get_setting(
+ "default", "output_quality", "80", int
+ )
+ self.disable_safety_checker = Settings.get_setting(
+ "default", "disable_safety_checker", True, bool
+ )
+
+ self.width = Settings.get_setting("default", "width", "1024", int)
+ self.height = Settings.get_setting("default", "height", "1024", int)
+ self.seed = Settings.get_setting("default", "seed", "-1", int)
+
+ self.output_folder = (
+ "/app/output"
+ if DOCKERIZED
+ else Settings.get_setting("default", "output_folder", "/Downloads", str)
+ )
+ models_json = Settings.get_setting(
+ "default", "models", '{"user_added": []}', str
+ )
+ models = json.loads(models_json)
+ self.user_added_models = {
+ model: model for model in models.get("user_added", [])
+ }
+ self.model_options = list(self.user_added_models.keys())
+ self.replicate_model = Settings.get_setting(
+ "default", "replicate_model", "", str
+ )
+
+ logger.info("ImageGeneratorGUI initialized")
+
+ def setup_ui(self):
+ logger.info("Setting up UI")
+ ui.dark_mode(True)
+ self.check_api_key()
+
+ with ui.grid().classes(
+ "w-full h-screen md:h-full grid-cols-1 md:grid-cols-2 gap-2 md:gap-5 p-4 md:p-6 dark:bg-[#11111b] bg-#eff1f5] md:auto-rows-min"
+ ):
+ with ui.card().classes("col-span-full modern-card flex-nowrap h-min"):
+ GUIPanels.setup_top_panel(self)
+
+ with ui.card().classes("col-span-full modern-card"):
+ GUIPanels.setup_prompt_panel(self)
+
+ with ui.card().classes("row-span-2 overflow-auto modern-card"):
+ GUIPanels.setup_left_panel(self)
+
+ with ui.card().classes("row-span-2 overflow-auto modern-card"):
+ GUIPanels.setup_right_panel(self)
+ Styles.stylefilter(self)
+ logger.info("UI setup completed")
+
+ async def open_settings_popup(self):
+ logger.debug("Opening settings popup")
+ with ui.dialog() as dialog, ui.card().classes(
+ "w-2/3 modern-card dark:bg-[#25292e] bg-[#818b981f]"
+ ):
+ ui.label("Settings").classes("text-2xl font-bold")
+ api_key_input = ui.input(
+ label="API Key",
+ placeholder="Enter Replicate API Key...",
+ password=True,
+ value=self.api_key,
+ ).classes("w-full mb-4")
+
+ async def save_settings():
+ logger.debug("Saving settings")
+ new_api_key = api_key_input.value
+ if new_api_key != self.api_key:
+ self.api_key = new_api_key
+ Settings.set_setting("secrets", "REPLICATE_API_KEY", new_api_key)
+ await self.save_settings()
+ os.environ["REPLICATE_API_KEY"] = new_api_key
+ self.image_generator.set_api_key(new_api_key)
+ logger.info("API key saved")
+
+ dialog.close()
+ ui.notify("Settings saved successfully", type="positive")
+
+ if not DOCKERIZED:
+ self.folder_input = ui.input(
+ label="Output Folder", value=self.output_folder
+ ).classes("w-full mb-4")
+ self.folder_input.on("change", self.update_folder_path)
+ ui.button("Save Settings", on_click=save_settings, color="blue-4").classes(
+ "mt-4"
+ )
+ dialog.open()
+
+ async def save_api_key(self):
+ logger.debug("Saving API key")
+ Settings.set_setting("secrets", "REPLICATE_API_KEY", self.api_key)
+ Settings.save_settings()
+ os.environ["REPLICATE_API_KEY"] = self.api_key
+ self.image_generator.set_api_key(self.api_key)
+
+ async def update_folder_path(self, e):
+ logger.debug("Updating folder path")
+ if hasattr(e, "value"):
+ new_path = e.value
+ elif hasattr(e, "sender") and hasattr(e.sender, "value"):
+ new_path = e.sender.value
+ elif hasattr(e, "args") and e.args:
+ new_path = e.args[0]
+ else:
+ new_path = None
+
+ if new_path is None:
+ logger.error("Failed to extract new path from event object")
+ ui.notify("Error updating folder path", type="negative")
+ return
+
+ if os.path.isdir(new_path):
+ self.output_folder = new_path
+ Settings.set_setting("default", "output_folder", new_path)
+ Settings.save_settings()
+ logger.info(f"Output folder set to: {self.output_folder}")
+ ui.notify(
+ f"Output folder updated to: {self.output_folder}", type="positive"
+ )
+ else:
+ logger.warning(f"Invalid folder path: {new_path}")
+ ui.notify(
+ "Invalid folder path. Please enter a valid directory.", type="negative"
+ )
+ if hasattr(self, "folder_input"):
+ self.folder_input.value = self.output_folder
+
+ async def toggle_custom_dimensions(self, e):
+ logger.debug(f"Toggling custom dimensions: {e.value}")
+ if e.value == "custom":
+ self.width_input.enable()
+ self.height_input.enable()
+ else:
+ self.width_input.disable()
+ self.height_input.disable()
+ await self.save_settings()
+ logger.info(f"Custom dimensions toggled: {e.value}")
+
+ def check_api_key(self):
+ logger.debug("Checking API key")
+ if not self.api_key:
+ logger.warning("No Replicate API Key found.")
+ ui.notify(
+ "No Replicate API Key found. Please set it in the settings before generating images.",
+ type="warning",
+ close_button="OK",
+ timeout=10000, # 10 seconds
+ position="top",
+ )
+
+ async def reset_to_default(self):
+ logger.debug("Resetting parameters to default values")
+ for attr in self._attributes:
+ if attr not in ["models", "replicate_model"]:
+ value = Settings.get_setting("default", attr)
+ if value is not None:
+ if attr in [
+ "num_outputs",
+ "num_inference_steps",
+ "width",
+ "height",
+ "seed",
+ "output_quality",
+ ]:
+ value = int(value)
+ elif attr in ["lora_scale", "guidance_scale"]:
+ value = float(value)
+ elif attr == "disable_safety_checker":
+ value = value.lower() == "true"
+
+ setattr(self, attr, value)
+ if hasattr(self, f"{attr}_input"):
+ getattr(self, f"{attr}_input").value = value
+ elif hasattr(self, f"{attr}_select"):
+ getattr(self, f"{attr}_select").value = value
+ elif hasattr(self, f"{attr}_switch"):
+ getattr(self, f"{attr}_switch").value = value
+
+ await self.save_settings()
+ ui.notify("Parameters reset to default values", type="info")
+ logger.info("Parameters reset to default values")
+
+ async def start_generation(self):
+ logger.debug("Starting image generation")
+ if not self.api_key:
+ ui.notify(
+ "Please set your Replicate API Key in the settings.", type="negative"
+ )
+ logger.error("Cannot start generation: No API key set.")
+ return
+ if not self.replicate_model_select.value:
+ ui.notify(
+ "Please select a Replicate model before generating images.",
+ type="negative",
+ )
+ logger.warning(
+ "Attempted to generate images without selecting a Replicate model"
+ )
+ return
+
+ await asyncio.to_thread(
+ self.image_generator.set_model, self.replicate_model_select.value
+ )
+
+ await self.save_settings()
+ params = {
+ "prompt": self.prompt_input.value,
+ "flux_model": self.flux_model,
+ "aspect_ratio": self.aspect_ratio,
+ "num_outputs": self.num_outputs,
+ "lora_scale": self.lora_scale,
+ "num_inference_steps": self.num_inference_steps,
+ "guidance_scale": self.guidance_scale,
+ "output_format": self.output_format,
+ "output_quality": self.output_quality,
+ "disable_safety_checker": self.disable_safety_checker,
+ }
+
+ if self.aspect_ratio == "custom":
+ params["width"] = self.width
+ params["height"] = self.height
+
+ if self.seed != -1:
+ params["seed"] = self.seed
+
+ self.generate_button.disable()
+ self.progress.visible = True
+ ui.notify("Generating images...", type="info")
+ logger.info(f"Generating images with params: {json.dumps(params, indent=2)}")
+
+ try:
+ output = await asyncio.to_thread(
+ self.image_generator.generate_images, params
+ )
+ await self.download_and_display_images(output)
+ logger.success(f"Images generated successfully: {output}")
+ except Exception as e:
+ error_message = f"An error occurred: {str(e)}"
+ ui.notify(error_message, type="negative")
+ logger.exception(error_message)
+ finally:
+ self.generate_button.enable()
+ self.progress.visible = False
+
+ # def create_zip_file(self):
+ # logger.debug("Creating zip file of generated images")
+ # if not self.last_generated_images:
+ # ui.notify("No images to download", type="warning")
+ # logger.warning("No images to zip")
+ # return None
+ #
+ # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ # zip_filename = f"generated_images_{timestamp}.zip"
+ # zip_path = Path(self.output_folder) / zip_filename
+ #
+ # with zipfile.ZipFile(zip_path, "w") as zipf:
+ # for image_path in self.last_generated_images:
+ # zipf.write(image_path, Path(image_path).name)
+ # logger.info(f"Zip file created: {zip_path}")
+ # return str(zip_path)
+ #
+ # def download_zip(self):
+ # logger.debug("Downloading zip file")
+ # zip_path = self.create_zip_file()
+ # if zip_path:
+ # ui.download(zip_path)
+ # ui.notify("Downloading zip file of generated images", type="positive")
+ #
+ async def update_gallery(self, image_paths):
+ logger.debug("Updating image gallery")
+ self.gallery_container.clear()
+ self.last_generated_images = image_paths
+ with self.gallery_container:
+ with ui.row().classes("w-full"):
+ with ui.grid(columns=2).classes("md:grid-cols-3 w-full gap-2"):
+ for image_path in image_paths:
+ self.lightbox.add_image(
+ image_path, image_path, "w-full h-full object-cover"
+ )
+ logger.debug("Image gallery updated")
+
+ async def download_and_display_images(self, image_urls):
+ logger.debug("Downloading and displaying generated images")
+ downloaded_images = []
+ async with httpx.AsyncClient() as client:
+ for i, url in enumerate(image_urls):
+ logger.debug(f"Downloading image from {url}")
+ response = await client.get(url)
+ if response.status_code == 200:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ url_part = urllib.parse.urlparse(url).path.split("/")[-2][:8]
+ file_name = f"generated_image_{timestamp}_{url_part}_{i+1}.png"
+ file_path = Path(self.output_folder) / file_name
+ with open(file_path, "wb") as f:
+ f.write(response.content)
+ downloaded_images.append(str(file_path))
+ logger.info(f"Image downloaded: {file_path}")
+ else:
+ logger.error(f"Failed to download image from {url}")
+
+ await self.update_gallery(downloaded_images)
+ ui.notify("Images generated and downloaded successfully!", type="positive")
+ logger.success("Images downloaded and displayed")
+
+ async def save_settings(self):
+ logger.debug("Saving settings")
+ for attr in self._attributes:
+ value = getattr(self, attr)
+ if attr == "models":
+ value = json.dumps({"user_added": list(self.user_added_models.keys())})
+ Settings.set_setting("default", attr, str(value))
+
+ Settings.set_setting(
+ "default", "replicate_model", self.replicate_model_select.value
+ )
+
+ Settings.save_settings()
+ logger.info("Settings saved successfully")
+
+
+async def create_gui(image_generator):
+ logger.debug("Creating GUI")
+ gui = ImageGeneratorGUI(image_generator)
+ gui.setup_ui()
+ logger.debug("GUI created")
+ return gui
diff --git a/src/gui/lightbox.py b/src/gui/lightbox.py
new file mode 100644
index 0000000..012ec00
--- /dev/null
+++ b/src/gui/lightbox.py
@@ -0,0 +1,48 @@
+from nicegui import ui
+from loguru import logger
+
+
+class Lightbox:
+ def __init__(self):
+ logger.debug("Initializing Lightbox")
+ with ui.dialog().props("maximized").classes("bg-black") as self.dialog:
+ self.dialog.on_key = self._handle_key
+ self.large_image = ui.image().props("no-spinner fit=scale-down")
+ self.image_list = []
+ logger.debug("Lightbox initialized")
+
+ def add_image(
+ self,
+ thumb_url: str,
+ orig_url: str,
+ thumb_classes: str = "w-32 h-32 object-cover",
+ ) -> ui.button:
+ logger.debug(f"Adding image to Lightbox: {orig_url}")
+ self.image_list.append(orig_url)
+ button = ui.button(on_click=lambda: self._open(orig_url)).props(
+ "flat dense square"
+ )
+ with button:
+ ui.image(thumb_url).classes(thumb_classes)
+ logger.debug("Image added to Lightbox")
+ return button
+
+ def _handle_key(self, e) -> None:
+ logger.debug(f"Handling key press in Lightbox: {e.key}")
+ if not e.action.keydown:
+ return
+ if e.key.escape:
+ logger.debug("Closing Lightbox dialog")
+ self.dialog.close()
+ image_index = self.image_list.index(self.large_image.source)
+ if e.key.arrow_left and image_index > 0:
+ logger.debug("Displaying previous image")
+ self._open(self.image_list[image_index - 1])
+ if e.key.arrow_right and image_index < len(self.image_list) - 1:
+ logger.debug("Displaying next image")
+ self._open(self.image_list[image_index + 1])
+
+ def _open(self, url: str) -> None:
+ logger.debug(f"Opening image in Lightbox: {url}")
+ self.large_image.set_source(url)
+ self.dialog.open()
diff --git a/src/gui/panels.py b/src/gui/panels.py
new file mode 100644
index 0000000..402bdb1
--- /dev/null
+++ b/src/gui/panels.py
@@ -0,0 +1,281 @@
+import asyncio
+from nicegui import ui
+from loguru import logger
+from util.settings import Settings
+from gui.lightbox import Lightbox
+from gui.usermodels import UserModels
+from gui.filehandler import FileHandler
+from gui.styles import Styles
+
+
+class GUIPanels:
+ def setup_top_panel(self):
+ logger.debug("Setting up top panel")
+ with ui.row().classes("w-full items-center"):
+ ui.label("Lumberjack - Replicate API Interface").classes(
+ "text-2xl/loose font-bold"
+ )
+ # dark_mode = ui.dark_mode(True)
+ # ui.switch().bind_value(dark_mode).classes().props(
+ # "dense checked-icon=dark_mode unchecked-icon=light_mode color=blue-4"
+ # )
+ ui.button(
+ icon="settings_suggest",
+ on_click=self.open_settings_popup,
+ color="blue-4",
+ ).classes("absolute-right mr-6 mt-3 mb-3")
+
+ def setup_left_panel(self):
+ logger.debug("Setting up left panel")
+ with ui.row().classes("w-full flex-row flex-nowrap"):
+ self.replicate_model_select = (
+ ui.select(
+ options=self.model_options,
+ label="Replicate Model",
+ value=self.replicate_model,
+ on_change=lambda e: asyncio.create_task(
+ self.update_replicate_model(e.value)
+ ),
+ )
+ .classes("width-5/6 overflow-auto custom-select bg-[#1e1e2e]")
+ .tooltip("Select or manage Replicate models")
+ .props("filled bg-color=dark")
+ )
+ ui.button(icon="token", color="blue-4").classes("ml-2 mt-1.2").on(
+ "click", UserModels.open_user_model_popup(self)
+ ).props("size=1.3rem")
+
+ self.flux_model_select = (
+ ui.select(
+ ["dev", "schnell"],
+ label="Flux Model",
+ value=Settings.get_setting("default", "flux_model", "dev"),
+ )
+ .classes("w-full text-gray-200")
+ .tooltip(
+ "Which model to run inferences with. The dev model needs around 28 steps but the schnell model only needs around 4 steps."
+ )
+ .bind_value(self, "flux_model")
+ .props("filled bg-color=dark")
+ )
+
+ with ui.row().classes("w-full flex-nowrap md:flex-wrap"):
+ self.aspect_ratio_select = (
+ ui.select(
+ [
+ "1:1",
+ "16:9",
+ "21:9",
+ "3:2",
+ "2:3",
+ "4:5",
+ "5:4",
+ "3:4",
+ "4:3",
+ "9:16",
+ "9:21",
+ "custom",
+ ],
+ label="Aspect Ratio",
+ value=Settings.get_setting("default", "aspect_ratio", "1:1"),
+ )
+ .classes("w-1/2 md:w-full text-gray-200")
+ .bind_value(self, "aspect_ratio")
+ .tooltip(
+ "Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)"
+ )
+ .props("filled bg-color=dark")
+ )
+ self.aspect_ratio_select.on("change", self.toggle_custom_dimensions)
+
+ with ui.column().classes("w-full").bind_visibility_from(
+ self.aspect_ratio_select, "value", value="custom"
+ ):
+ self.width_input = (
+ ui.number(
+ "Width",
+ value=Settings.get_setting("default", "width", 1024, int),
+ min=256,
+ max=1440,
+ )
+ .classes("w-full")
+ .bind_value(self, "width")
+ .tooltip(
+ "Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)"
+ )
+ .props("filled bg-color=dark")
+ )
+ self.height_input = (
+ ui.number(
+ "Height",
+ value=Settings.get_setting("default", "height", 1024, int),
+ min=256,
+ max=1440,
+ )
+ .classes("w-full")
+ .bind_value(self, "height")
+ .tooltip(
+ "Height of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded to nearest multiple of 16)"
+ )
+ .props("filled bg-color=dark")
+ )
+
+ self.num_outputs_input = (
+ ui.number(
+ "Num Outputs",
+ value=Settings.get_setting("default", "num_outputs", 1, int),
+ min=1,
+ max=4,
+ )
+ .classes("w-1/2 md:w-full")
+ .bind_value(self, "num_outputs")
+ .tooltip("Number of images to output.")
+ .props("filled bg-color=dark")
+ )
+
+ with ui.row().classes("w-full flex-nowrap md:flex-wrap"):
+ self.lora_scale_input = (
+ ui.number(
+ "LoRA Scale",
+ value=float(Settings.get_setting("default", "lora_scale", 1)),
+ min=-1,
+ max=2,
+ step=0.1,
+ )
+ .classes("w-1/2 md:w-full")
+ .tooltip(
+ "Determines how strongly the LoRA should be applied. Sane results between 0 and 1."
+ )
+ .props("filled bg-color=dark color=primary")
+ .bind_value(self, "lora_scale")
+ )
+ self.num_inference_steps_input = (
+ ui.number(
+ "Num Inference Steps",
+ value=Settings.get_setting(
+ "default", "num_inference_steps", 28, int
+ ),
+ min=1,
+ max=50,
+ precision=0,
+ )
+ .classes("w-1/2 md:w-full")
+ .tooltip("Number of Inference Steps")
+ .bind_value(self, "num_inference_steps")
+ .props("filled bg-color=dark")
+ )
+
+ with ui.row().classes("w-full flex-nowrap md:flex-wrap"):
+ self.guidance_scale_input = (
+ ui.number(
+ "Guidance Scale",
+ value=float(Settings.get_setting("default", "guidance_scale", 3.5)),
+ min=0,
+ max=10,
+ step=0.1,
+ precision=2,
+ )
+ .classes("w-1/2 md:w-full")
+ .tooltip("Guidance Scale for the diffusion process")
+ .bind_value(self, "guidance_scale")
+ .props("filled bg-color=dark")
+ )
+ self.seed_input = (
+ ui.number(
+ "Seed",
+ value=Settings.get_setting("default", "seed", -1, int),
+ min=-2147483648,
+ max=2147483647,
+ )
+ .classes("w-1/2 md:w-full")
+ .bind_value(self, "seed")
+ .props("filled bg-color=dark")
+ )
+
+ with ui.row().classes("w-full flex-nowrap"):
+ self.output_format_select = (
+ ui.select(
+ ["webp", "jpg", "png"],
+ label="Output Format",
+ value=Settings.get_setting("default", "output_format", "webp"),
+ )
+ .classes("w-full")
+ .tooltip("Format of the output images")
+ .bind_value(self, "output_format")
+ .props("filled bg-color=dark")
+ )
+
+ self.output_quality_input = (
+ ui.number(
+ "Output Quality",
+ value=Settings.get_setting("default", "output_quality", 80, int),
+ min=0,
+ max=100,
+ )
+ .classes("w-full")
+ .tooltip(
+ "Quality when saving the output images, from 0 to 100. 100 is best quality, 0 is lowest quality. Not relevant for .png outputs"
+ )
+ .bind_value(self, "output_quality")
+ .props("filled bg-color=dark")
+ )
+
+ with ui.row().classes("w-full flex-nowrap"):
+ self.disable_safety_checker_switch = (
+ ui.switch(
+ "Disable Safety Checker",
+ value=Settings.get_setting(
+ "default", "disable_safety_checker", fallback="False"
+ ).lower()
+ == "true",
+ )
+ .classes("w-1/2")
+ .tooltip("Disable safety checker for generated images.")
+ .bind_value(self, "disable_safety_checker")
+ .props("filled bg-color=dark color=blue-4")
+ )
+ self.reset_button = ui.button(
+ "Reset Parameters", on_click=self.reset_to_default, color="#e78284"
+ ).classes("w-1/2 text-white font-bold py-2 px-4 rounded")
+
+ def setup_right_panel(self):
+ logger.debug("Setting up right panel")
+ with ui.row().classes("w-full flex-nowrap"):
+ ui.label("Output").classes("text-center ml-4 mt-3 w-full").style(
+ "font-size: 230%; font-weight: bold; text-align: left;"
+ )
+ ui.button(
+ "Download Images",
+ on_click=lambda: FileHandler.download_zip(
+ self.last_generated_images, self.output_folder
+ ),
+ color="blue-4",
+ ).classes("modern-button text-white font-bold py-2 px-4 rounded")
+ ui.separator()
+ with ui.row().classes("w-full flex-nowrap"):
+ self.gallery_container = ui.column().classes(
+ "w-full mt-4 grid grid-cols-2 gap-4"
+ )
+ self.lightbox = Lightbox()
+
+ def setup_prompt_panel(self):
+ logger.debug("Setting up prompt panel")
+ with ui.row().classes("w-full flex-row flex-nowrap"):
+ self.prompt_input = (
+ ui.textarea("Prompt", value=self.prompt)
+ .classes("w-full shadow-lg")
+ .bind_value(self, "prompt")
+ .props("clearable filled bg-color=dark autofocus color=blue-4")
+ )
+ self.generate_button = (
+ ui.button(icon="bolt", on_click=self.start_generation, color="blue-4")
+ .classes("ml-2 font-bold rounded modern-button h-full")
+ .props("size=1.5rem")
+ .style("animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;")
+ )
+ self.progress = (
+ ui.linear_progress(show_value=False, size="20px")
+ .classes("w-full")
+ .props("indeterminate")
+ )
+ self.progress.visible = False
diff --git a/src/gui/styles.py b/src/gui/styles.py
new file mode 100644
index 0000000..6b196c0
--- /dev/null
+++ b/src/gui/styles.py
@@ -0,0 +1,60 @@
+from nicegui import ui, ElementFilter
+from loguru import logger
+
+
+class Styles:
+ def setup_custom_styles():
+ logger.debug("Setting up custom styles")
+ ui.add_head_html("""
+
+
+ """)
+
+ # ui.add_css("""
+ # .bg {
+ # color: red;
+ # }
+ # """)
+ ui.colors(
+ dark="#303446", # Crust (Closest: Grey-9)
+ primary="#8aadf4", # Blue (Closest: Blue-4)
+ positive="#a6d189", # Green (Closest: Green-4)
+ negative="#e78284", # Maroon (Closest: Red-3)
+ secondary="#f38ba8", # Red (Closest: Red-4)
+ accent="#f5bde6", # Pink (Closest: Pink-3)
+ info="#89dceb", # Sky (Closest: Cyan-4)
+ warning="#f0a988", # Peach (Closest: Orange-4)
+ )
+
+ def stylefilter(self):
+ ElementFilter(kind=ui.label).classes("mt-0")
+
+ ElementFilter(kind=ui.card).classes("dark:bg-[#181825] bg-[#ccd0da]")
diff --git a/src/gui/usermodels.py b/src/gui/usermodels.py
new file mode 100644
index 0000000..45fc531
--- /dev/null
+++ b/src/gui/usermodels.py
@@ -0,0 +1,112 @@
+from nicegui import ui
+from loguru import logger
+import asyncio
+import json
+from util import Settings
+
+
+class UserModels:
+ @ui.refreshable
+ def model_list(self):
+ logger.debug("Refreshing model list")
+ for model in self.user_added_models:
+ with ui.row().classes("w-full justify-between items-center"):
+ ui.label(model)
+ ui.button(
+ icon="delete",
+ on_click=lambda m=model: self.confirm_delete_model(m),
+ color="#818b981f",
+ ).props("flat round color=red")
+
+ async def open_user_model_popup(self):
+ logger.debug("Opening user model popup")
+
+ async def add_model():
+ await self.add_user_model(new_model_input.value)
+
+ with ui.dialog() as dialog, ui.card():
+ ui.label("Manage Replicate Models").classes("text-xl font-bold mb-4")
+ new_model_input = ui.input(label="Add New Model").classes("w-full mb-4")
+ ui.button("Add Model", on_click=add_model, color="#818b981f")
+
+ ui.label("Current Models:").classes("mt-4 mb-2")
+ self.model_list()
+
+ ui.button("Close", on_click=dialog.close, color="#818b981f").classes("mt-4")
+ dialog.open()
+
+ async def add_user_model(self, new_model):
+ logger.debug(f"Adding user model: {new_model}")
+ if new_model and new_model not in self.user_added_models:
+ try:
+ latest_v = await asyncio.to_thread(
+ self.image_generator.get_model_version, new_model
+ )
+ self.user_added_models[new_model] = latest_v
+ self.model_options = list(self.user_added_models.values())
+ self.replicate_model_select.options = self.model_options
+ self.replicate_model_select.value = latest_v
+ await self.update_replicate_model(latest_v)
+ models_json = json.dumps(
+ {"user_added": list(self.user_added_models.values())}
+ )
+ Settings.set_setting("default", "models", models_json)
+ save_settings()
+ ui.notify(f"Model '{latest_v}' added successfully", type="positive")
+ self.model_list.refresh()
+ logger.info(f"User model added: {latest_v}")
+ except Exception as e:
+ logger.error(f"Error adding model: {str(e)}")
+ ui.notify(f"Error adding model: {str(e)}", type="negative")
+ else:
+ logger.warning(f"Invalid model name or model already exists: {new_model}")
+ ui.notify("Invalid model name or model already exists", type="negative")
+
+ async def confirm_delete_model(self, model):
+ logger.debug(f"Confirming deletion of model: {model}")
+ with ui.dialog() as confirm_dialog, ui.card():
+ ui.label(f"Are you sure you want to delete the model '{model}'?").classes(
+ "mb-4"
+ )
+ with ui.row():
+ ui.button(
+ "Yes",
+ on_click=lambda: self.delete_user_model(model, confirm_dialog),
+ color="1f883d",
+ ).classes("mr-2")
+ ui.button("No", on_click=confirm_dialog.close, color="cf222e")
+ confirm_dialog.open()
+
+ async def delete_user_model(self, model, confirm_dialog):
+ logger.debug(f"Deleting user model: {model}")
+ if model in self.user_added_models:
+ del self.user_added_models[model]
+ self.model_options = list(self.user_added_models.keys())
+ self.replicate_model_select.options = self.model_options
+ if self.replicate_model_select.value == model:
+ self.replicate_model_select.value = None
+ await self.update_replicate_model(None)
+ models_json = json.dumps(
+ {"user_added": list(self.user_added_models.keys())}
+ )
+ Settings.set_setting("default", "models", models_json)
+ save_settings()
+ ui.notify(f"Model '{model}' deleted successfully", type="positive")
+ confirm_dialog.close()
+ self.model_list.refresh()
+ logger.info(f"User model deleted: {model}")
+ else:
+ logger.warning(f"Cannot delete model, not found: {model}")
+ ui.notify("Cannot delete this model", type="negative")
+
+ async def update_replicate_model(self, new_model):
+ logger.debug(f"Updating Replicate model to: {new_model}")
+ if new_model:
+ await asyncio.to_thread(self.image_generator.set_model, new_model)
+ self.replicate_model = new_model
+ await self.save_settings()
+ logger.info(f"Replicate model updated to: {new_model}")
+ self.generate_button.enable()
+ else:
+ logger.warning("No Replicate model selected")
+ self.generate_button.disable()
diff --git a/src/gui/utilities.py b/src/gui/utilities.py
new file mode 100644
index 0000000..bbdb8f5
--- /dev/null
+++ b/src/gui/utilities.py
@@ -0,0 +1,5 @@
+from nicegui import ui
+from loguru import logger
+
+class Utilities:
+
diff --git a/src/main.py b/src/main.py
index 949692a..68d1bf6 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,10 +1,9 @@
import sys
-from config import get_api_key
-from gui import create_gui
from loguru import logger
from nicegui import ui
-from replicate_api import ImageGenerator
+import util
+from gui import ImageGeneratorGUI
logger.add(
sys.stderr, format="{time} {level} {message}", filter="my_module", level="INFO"
@@ -12,17 +11,16 @@
logger.add(
"app.log",
format="{time} {level} {module}:{line} {message}",
- level="DEBUG",
+ level="INFO",
rotation="500 MB",
compression="zip",
)
logger.info("Initializing ImageGenerator")
-generator = ImageGenerator()
+generator = util.Replicate_API()
-
-api_key = get_api_key()
+api_key = util.Settings.get_api_key()
if api_key:
generator.set_api_key(api_key)
else:
@@ -33,8 +31,10 @@
@ui.page("/")
async def main_page():
- await create_gui(generator)
- logger.info("NiceGUI server is running")
+ logger.debug("Creating GUI")
+ gui = ImageGeneratorGUI(generator)
+ gui.setup_ui()
+ logger.debug("GUI created")
logger.info("Starting NiceGUI server")
diff --git a/src/util/__init__.py b/src/util/__init__.py
new file mode 100644
index 0000000..5f59ee9
--- /dev/null
+++ b/src/util/__init__.py
@@ -0,0 +1,3 @@
+# from .settings import Settings
+from .replicate_api import Replicate_API
+from .settings import Settings
diff --git a/src/replicate_api.py b/src/util/replicate_api.py
similarity index 99%
rename from src/replicate_api.py
rename to src/util/replicate_api.py
index cafd3e0..4529ef9 100644
--- a/src/replicate_api.py
+++ b/src/util/replicate_api.py
@@ -8,7 +8,7 @@
load_dotenv()
-class ImageGenerator:
+class Replicate_API:
def __init__(self):
self.replicate_model = None
self.api_key = None
diff --git a/src/util/settings.py b/src/util/settings.py
new file mode 100644
index 0000000..eec88a1
--- /dev/null
+++ b/src/util/settings.py
@@ -0,0 +1,80 @@
+import configparser
+import os
+from typing import Any, Type
+
+from loguru import logger
+
+DOCKERIZED = os.environ.get("DOCKER_CONTAINER", "False").lower() == "true"
+CONFIG_DIR = "/app/settings" if DOCKERIZED else "."
+DEFAULT_CONFIG_FILE = os.path.join(CONFIG_DIR, "settings.ini")
+USER_CONFIG_FILE = os.path.join(CONFIG_DIR, "settings.user.ini")
+
+logger.info(
+ f"Configuration files: DEFAULT={DEFAULT_CONFIG_FILE}, USER={USER_CONFIG_FILE}"
+)
+
+config = configparser.ConfigParser()
+config.read([DEFAULT_CONFIG_FILE, USER_CONFIG_FILE])
+logger.info("Configuration files loaded")
+
+
+class Settings:
+ def get_api_key():
+ api_key = os.environ.get("REPLICATE_API_KEY") or config.get(
+ "secrets", "REPLICATE_API_KEY", fallback=None
+ )
+ if api_key:
+ logger.info("API key retrieved successfully")
+ else:
+ logger.warning("No API key found")
+ return api_key
+
+ def get_setting(
+ section: str, key: str, fallback: Any = None, value_type: Type[Any] = str
+ ) -> Any:
+ logger.info(
+ f"Attempting to get setting: section={section}, key={key}, fallback={fallback}, value_type={value_type}"
+ )
+ try:
+ value = config.get(section, key)
+ logger.debug(f"Raw value retrieved: {value}")
+ if value_type is int:
+ result = int(value)
+ elif value_type is float:
+ result = float(value)
+ elif value_type is bool:
+ result = value.lower() in ("true", "yes", "1", "on")
+ else:
+ result = value
+ logger.info(f"Setting retrieved successfully: {result}")
+ return result
+ except (configparser.NoSectionError, configparser.NoOptionError) as e:
+ logger.warning(
+ f"Setting not found: {str(e)}. Using fallback value: {fallback}"
+ )
+ return fallback
+ except ValueError as e:
+ logger.error(
+ f"Error converting setting value: {str(e)}. Using fallback value: {fallback}"
+ )
+ return fallback
+
+ def set_setting(section, key, value):
+ logger.info(f"Setting value: section={section}, key={key}, value={value}")
+ if not config.has_section(section):
+ logger.info(f"Creating new section: {section}")
+ config.add_section(section)
+ config.set(section, key, str(value))
+ logger.info("Value set successfully")
+
+ def save_settings():
+ logger.info(f"Saving settings to {USER_CONFIG_FILE}")
+ try:
+ with open(USER_CONFIG_FILE, "w") as configfile:
+ config.write(configfile)
+ logger.info("Settings saved successfully")
+ except IOError as e:
+ logger.error(f"Error saving settings: {str(e)}")
+
+
+logger.info("Config module initialized")