From 4ff613cdb70e37ac618cb5b3cff35d4eea7e2dc1 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Thu, 23 Oct 2025 08:05:33 +0200 Subject: [PATCH 01/19] Add progress Signed-off-by: Ashwin Vaidya --- .../src/api/endpoints/job_endpoints.py | 20 +- application/backend/src/db/schema.py | 1 + .../backend/src/pydantic_models/job.py | 5 + .../backend/src/services/job_service.py | 63 +++++- .../backend/src/services/training_service.py | 41 +++- application/backend/src/utils/callbacks.py | 182 ++++++++++++++++++ .../statusbar/items/progressbar.component.tsx | 182 ++++++++++++++++++ .../inspect/statusbar/statusbar.component.tsx | 16 ++ application/ui/src/routes/inspect/inspect.tsx | 42 ++-- 9 files changed, 519 insertions(+), 33 deletions(-) create mode 100644 application/backend/src/utils/callbacks.py create mode 100644 application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx create mode 100644 application/ui/src/features/inspect/statusbar/statusbar.component.tsx diff --git a/application/backend/src/api/endpoints/job_endpoints.py b/application/backend/src/api/endpoints/job_endpoints.py index a0fba9149b..33bc6e46e7 100644 --- a/application/backend/src/api/endpoints/job_endpoints.py +++ b/application/backend/src/api/endpoints/job_endpoints.py @@ -10,7 +10,7 @@ from api.dependencies import get_job_id, get_job_service from api.endpoints import API_PREFIX from pydantic_models import JobList -from pydantic_models.job import JobSubmitted, TrainJobPayload +from pydantic_models.job import JobCancelled, JobSubmitted, TrainJobPayload from services import JobService job_api_prefix_url = API_PREFIX + "/jobs" @@ -42,3 +42,21 @@ async def get_job_logs( ) -> StreamingResponse: """Endpoint to get the logs of a job by its ID""" return StreamingResponse(job_service.stream_logs(job_id=job_id), media_type="text/event-stream") + + +@job_router.get("/{job_id}/progress") +async def get_job_progress( + job_id: Annotated[UUID, Depends(get_job_id)], + job_service: Annotated[JobService, Depends(get_job_service)], +) -> StreamingResponse: + """Endpoint to get the progress of a job by its ID""" + return StreamingResponse(job_service.stream_progress(job_id=job_id), media_type="text/event-stream") + + +@job_router.post("/{job_id}/cancel") +async def cancel_job( + job_id: Annotated[UUID, Depends(get_job_id)], + job_service: Annotated[JobService, Depends(get_job_service)], +) -> JobCancelled: + """Endpoint to cancel a job by its ID""" + return await job_service.cancel_job(job_id=job_id) diff --git a/application/backend/src/db/schema.py b/application/backend/src/db/schema.py index 566dd8cfe2..57205f2ca9 100644 --- a/application/backend/src/db/schema.py +++ b/application/backend/src/db/schema.py @@ -58,6 +58,7 @@ class JobDB(Base): id: Mapped[str] = mapped_column(primary_key=True, default=lambda: str(uuid4())) project_id: Mapped[str] = mapped_column(ForeignKey("projects.id")) type: Mapped[str] = mapped_column(String(64), nullable=False) + stage: Mapped[str] = mapped_column(String(64), nullable=False) # training, validation, test progress: Mapped[int] = mapped_column(nullable=False) status: Mapped[str] = mapped_column(String(64), nullable=False) message: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/application/backend/src/pydantic_models/job.py b/application/backend/src/pydantic_models/job.py index 8e90c237f1..27caa74ea5 100644 --- a/application/backend/src/pydantic_models/job.py +++ b/application/backend/src/pydantic_models/job.py @@ -28,6 +28,7 @@ class Job(BaseIDModel): type: JobType = JobType.TRAINING progress: int = Field(default=0, ge=0, le=100, description="Progress percentage from 0 to 100") status: JobStatus = JobStatus.PENDING + stage: str = "idle" payload: dict message: str = "Job created" start_time: datetime | None = None @@ -46,6 +47,10 @@ class JobSubmitted(BaseModel): job_id: UUID +class JobCancelled(BaseModel): + job_id: UUID + + class TrainJobPayload(BaseModel): project_id: UUID = Field(exclude=True) model_name: str diff --git a/application/backend/src/services/job_service.py b/application/backend/src/services/job_service.py index 9079740c43..a537138d21 100644 --- a/application/backend/src/services/job_service.py +++ b/application/backend/src/services/job_service.py @@ -1,16 +1,20 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import asyncio +import json import os +from collections.abc import Coroutine +from typing import Any from uuid import UUID import anyio from sqlalchemy.exc import IntegrityError +from starlette.responses import AsyncContentStream from db import get_async_db_session_ctx from exceptions import DuplicateJobException, ResourceNotFoundException from pydantic_models import Job, JobList, JobType -from pydantic_models.job import JobStatus, JobSubmitted, TrainJobPayload +from pydantic_models.job import JobCancelled, JobStatus, JobSubmitted, TrainJobPayload from repositories import JobRepository @@ -54,7 +58,11 @@ async def get_pending_train_job() -> Job | None: @staticmethod async def update_job_status( - job_id: UUID, status: JobStatus, message: str | None = None, progress: int | None = None + job_id: UUID, + status: JobStatus, + message: str | None = None, + progress: int | None = None, + stage: str | None = None, ) -> None: async with get_async_db_session_ctx() as session: repo = JobRepository(session) @@ -67,8 +75,17 @@ async def update_job_status( progress_ = 100 if status is JobStatus.COMPLETED else progress if progress_ is not None: updates["progress"] = progress_ + if stage is not None: + updates["stage"] = stage await repo.update(job, updates) + @classmethod + async def is_job_still_running(cls, job_id: UUID | str) -> bool: + job = await cls.get_job_by_id(job_id=job_id) + if job is None: + raise ResourceNotFoundException(resource_id=job_id, resource_name="job") + return job.status == JobStatus.RUNNING + @classmethod async def stream_logs(cls, job_id: UUID | str): from core.logging import get_job_logs_path @@ -77,12 +94,6 @@ async def stream_logs(cls, job_id: UUID | str): if not os.path.exists(log_file): raise ResourceNotFoundException(resource_id=job_id, resource_name="job_logs") - async def is_job_still_running(): - job = await cls.get_job_by_id(job_id=job_id) - if job is None: - raise ResourceNotFoundException(resource_id=job_id, resource_name="job") - return job.status == JobStatus.RUNNING - # Cache job status and only check every 2 seconds status_check_interval = 2.0 # seconds last_status_check = 0.0 @@ -95,7 +106,7 @@ async def is_job_still_running(): now = loop.time() # Only check job status every status_check_interval seconds if now - last_status_check > status_check_interval: - cached_still_running = await is_job_still_running() + cached_still_running = await cls.is_job_still_running(job_id=job_id) last_status_check = now still_running = cached_still_running if not line: @@ -107,3 +118,37 @@ async def is_job_still_running(): else: break yield line + + @classmethod + async def stream_progress(cls, job_id: UUID | str) -> Coroutine[Any, Any, AsyncContentStream]: + """Stream the progress of a job by its ID""" + loop = asyncio.get_running_loop() + status_check_interval = 2.0 # seconds + last_status_check = 0.0 + cached_still_running = True + still_running = True + async with get_async_db_session_ctx() as session: + repo = JobRepository(session) + job = await repo.get_by_id(job_id) + if job is None: + raise ResourceNotFoundException(resource_id=job_id, resource_name="job") + while still_running: + now = loop.time() + if now - last_status_check > status_check_interval: + cached_still_running = await cls.is_job_still_running(job_id=job_id) + last_status_check = now + still_running = cached_still_running + yield json.dumps({"progress": job.progress, "stage": job.stage}) + await asyncio.sleep(0.5) + + @classmethod + async def cancel_job(cls, job_id: UUID | str) -> JobCancelled: + """Cancel a job by its ID""" + async with get_async_db_session_ctx() as session: + repo = JobRepository(session) + job = await repo.get_by_id(job_id) + if job is None: + raise ResourceNotFoundException(resource_id=job_id, resource_name="job") + + await repo.update(job, {"status": JobStatus.CANCELED}) + return JobCancelled(job_id=job.id) diff --git a/application/backend/src/services/training_service.py b/application/backend/src/services/training_service.py index 0ede80bae3..63ba0752a9 100644 --- a/application/backend/src/services/training_service.py +++ b/application/backend/src/services/training_service.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio from contextlib import redirect_stdout +from uuid import UUID from anomalib.data import Folder from anomalib.data.utils import TestSplitMode @@ -15,6 +16,7 @@ from repositories.binary_repo import ImageBinaryRepository, ModelBinaryRepository from services import ModelService from services.job_service import JobService +from utils.callbacks import GetiInspectProgressCallback, ProgressSyncParams from utils.experiment_loggers import TrackioLogger @@ -61,30 +63,40 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model: model_name = job.payload.get("model_name") if model_name is None: raise ValueError(f"Job {job.id} payload must contain 'model_name'") - + model_service = ModelService() model = Model( project_id=project_id, name=str(model_name), ) + synchronization_parameters = ProgressSyncParams() logger.info(f"Training model `{model_name}` for job `{job.id}`") try: + synchronization_task = asyncio.create_task( + cls._sync_progress_with_db( + job_service=job_service, job_id=job.id, synchronization_parameters=synchronization_parameters + ) + ) # Use asyncio.to_thread to keep event loop responsive # TODO: Consider ProcessPoolExecutor for true parallelism with multiple jobs - trained_model = await asyncio.to_thread(cls._train_model, model) + trained_model = await asyncio.to_thread(cls._train_model, model, synchronization_parameters) if trained_model is None: raise ValueError("Training failed - model is None") await job_service.update_job_status( job_id=job.id, status=JobStatus.COMPLETED, message="Training completed successfully" ) + logger.info("Syncing progress with db stopped") + synchronization_task.cancel() return await model_service.create_model(trained_model) except Exception as e: logger.exception("Failed to train pending training job: %s", e) await job_service.update_job_status( job_id=job.id, status=JobStatus.FAILED, message=f"Failed with exception: {str(e)}" ) + logger.info("Syncing progress with db stopped") + synchronization_task.cancel() if model.export_path: logger.warning(f"Deleting partially created model with id: {model.id}") model_binary_repo = ModelBinaryRepository(project_id=project_id, model_id=model.id) @@ -93,7 +105,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model: raise e @staticmethod - def _train_model(model: Model) -> Model | None: + def _train_model(model: Model, synchronization_parameters: ProgressSyncParams) -> Model | None: """ Execute CPU-intensive model training using anomalib. @@ -103,6 +115,7 @@ def _train_model(model: Model) -> Model | None: Args: model: Model object with training configuration + synchronization_parameters: Parameters for synchronization between the main process and the training process Returns: Model: Trained model with updated export_path and is_ready=True @@ -131,7 +144,9 @@ def _train_model(model: Model) -> Model | None: engine = Engine( default_root_dir=model.export_path, logger=[trackio, tensorboard], + devices=[0], max_epochs=10, + callbacks=[GetiInspectProgressCallback(synchronization_parameters)], ) # Execute training and export @@ -139,7 +154,7 @@ def _train_model(model: Model) -> Model | None: # Capture pytorch stdout logs into logger with redirect_stdout(LoggerStdoutWriter()): # type: ignore[type-var] - engine.train(model=anomalib_model, datamodule=datamodule) + engine.fit(model=anomalib_model, datamodule=datamodule) export_path = engine.export( model=anomalib_model, export_type=export_format, @@ -150,6 +165,24 @@ def _train_model(model: Model) -> Model | None: model.is_ready = True return model + @classmethod + async def _sync_progress_with_db( + cls, + job_service: JobService, + job_id: UUID, + synchronization_parameters: ProgressSyncParams, + ) -> None: + while True: + progress = synchronization_parameters.get_progress() + stage = synchronization_parameters.get_stage() + if not await job_service.is_job_still_running(job_id=job_id): + logger.info("Job cancelled, stopping progress sync") + synchronization_parameters.set_cancel_training_event() + break + logger.info(f"Syncing progress with db: {progress}% - {stage}") + await job_service.update_job_status(job_id=job_id, status=JobStatus.RUNNING, progress=progress, stage=stage) + await asyncio.sleep(0.1) + @staticmethod async def abort_orphan_jobs() -> None: """ diff --git a/application/backend/src/utils/callbacks.py b/application/backend/src/utils/callbacks.py new file mode 100644 index 0000000000..adfa3ea3d4 --- /dev/null +++ b/application/backend/src/utils/callbacks.py @@ -0,0 +1,182 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Lightning callback for sending progress to the frontend via the Plugin API.""" + +from __future__ import annotations + +import logging +from ctypes import c_char_p +from multiprocessing import Event, Value +from typing import TYPE_CHECKING, Any + +from lightning.pytorch.callbacks import Callback + +if TYPE_CHECKING: + from lightning.pytorch import LightningModule, Trainer + from lightning.pytorch.trainer.states import RunningStage + +logger = logging.getLogger(__name__) + + +class ProgressSyncParams: + def __init__(self) -> None: + self.progress = Value("f", 0.0) + self.stage = Value(c_char_p, b"idle") + self.cancel_training_event = Event() + + def set_stage(self, stage: str) -> None: + with self.stage.get_lock(): + self.stage.value = stage.encode("utf-8") + logger.info("Set stage: %s", stage) + + def get_stage(self) -> str: + return self.stage.value.decode("utf-8") + + def set_progress(self, progress: float) -> None: + with self.progress.get_lock(): + self.progress.value = progress + logger.info("Set progress: %s", progress) + + def get_progress(self) -> float: + return self.progress.value + + def set_cancel_training_event(self) -> None: + self.cancel_training_event.set() + logger.info("Set cancel training event") + + +class GetiInspectProgressCallback(Callback): + """Callback for displaying training/validation/testing progress in the Geti Inspect UI. + + This callback sends progress events through a multiprocessing queue that the + main process polls and broadcasts via WebSocket to connected frontend clients. + + Args: + synchronization_parameters: Parameters for synchronization between the main process and the training process + + Example: + trainer = Trainer(callbacks=[GetiInspectProgressCallback(synchronization_parameters=ProgressSyncParams())]) + """ + + def __init__(self, synchronization_parameters: ProgressSyncParams) -> None: + """Initialize the callback with synchronization parameters. + Args: + synchronization_parameters: Parameters for synchronization between the main process and the training process + """ + self.synchronization_parameters = synchronization_parameters + + def _check_cancel_training(self, trainer: Trainer) -> None: + """Check if training should be canceled.""" + if self.synchronization_parameters.cancel_training_event.is_set(): + trainer.should_stop = True + + def _send_progress(self, progress: float, stage: RunningStage) -> None: + """Send progress update to frontend via event queue. + Puts a generic event message into the multiprocessing queue which will + be picked up by the main process and broadcast via WebSocket. + Args: + progress: Progress value between 0.0 and 1.0 + stage: The current training stage + """ + # Convert progress to percentage (0-100) + progress_percent = int(progress * 100) + + try: + logger.info("Sent progress: %s - %d%%", stage.name, progress_percent) + self.synchronization_parameters.set_progress(progress_percent) + self.synchronization_parameters.set_stage(stage.name) + except Exception as e: + logger.warning("Failed to send progress to event queue: %s", e) + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + self._send_progress(0, stage) + + def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: RunningStage) -> None: + self._send_progress(1.0, stage) + + # Training callbacks + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when training starts.""" + self._send_progress(0, trainer.state.stage) + self._check_cancel_training(trainer) + + def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None: + """Called when a training batch starts.""" + self._check_cancel_training(trainer) + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when a training epoch ends.""" + progress = (trainer.current_epoch + 1) / trainer.max_epochs + self._send_progress(progress, trainer.state.stage) + self._check_cancel_training(trainer) + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when training ends.""" + self._send_progress(1.0, trainer.state.stage) + self._check_cancel_training(trainer) + + # Validation callbacks + def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when validation starts.""" + self._check_cancel_training(trainer) + + def on_validation_batch_start( + self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Called when a validation batch starts.""" + self._check_cancel_training(trainer) + + def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when a validation epoch ends.""" + self._check_cancel_training(trainer) + + def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when validation ends.""" + self._check_cancel_training(trainer) + + # Test callbacks + def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when testing starts.""" + self._send_progress(0, trainer.state.stage) + self._check_cancel_training(trainer) + + def on_test_batch_start( + self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Called when a test batch starts.""" + self._check_cancel_training(trainer) + + def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when a test epoch ends.""" + progress = (trainer.current_epoch + 1) / trainer.max_epochs if trainer.max_epochs else 0.5 + self._send_progress(progress, trainer.state.stage) + self._check_cancel_training(trainer) + + def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when testing ends.""" + self._send_progress(1.0, trainer.state.stage) + self._check_cancel_training(trainer) + + # Predict callbacks + def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when prediction starts.""" + self._send_progress(0, trainer.state.stage) + self._check_cancel_training(trainer) + + def on_predict_batch_start( + self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Called when a prediction batch starts.""" + self._check_cancel_training(trainer) + + def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when a prediction epoch ends.""" + progress = (trainer.current_epoch + 1) / trainer.max_epochs if trainer.max_epochs else 0.5 + self._send_progress(progress, trainer.state.stage) + self._check_cancel_training(trainer) + + def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Called when prediction ends.""" + self._send_progress(1.0, trainer.state.stage) + self._check_cancel_training(trainer) diff --git a/application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx b/application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx new file mode 100644 index 0000000000..29be3c8c7c --- /dev/null +++ b/application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx @@ -0,0 +1,182 @@ +import React, { useEffect, useState } from 'react'; + +import { $api, API_BASE_URL } from '@geti-inspect/api'; +import { SchemaJob as Job } from '@geti-inspect/api/spec'; +import { useProjectIdentifier } from '@geti-inspect/hooks'; +import { Flex, ProgressBar, Text } from '@geti/ui'; +import { CanceledIcon, WaitingIcon } from '@geti/ui/icons'; + +function IdleItem(): React.ReactNode { + return ( + + + + Idle + + + ); +} + +function TrainingStatusItem(progress: number, stage: string, onCancel?: () => void): React.ReactNode { + // Determine color based on stage + let bgcolor = 'var(--spectrum-global-color-blue-600)'; + let fgcolor = '#fff'; + if (stage.toLowerCase().includes('valid')) { + bgcolor = 'var(--spectrum-global-color-yellow-600)'; + fgcolor = '#000'; + } else if (stage.toLowerCase().includes('test')) { + bgcolor = 'var(--spectrum-global-color-green-600)'; + fgcolor = '#fff'; + } else if (stage.toLowerCase().includes('train') || stage.toLowerCase().includes('fit')) { + bgcolor = 'var(--spectrum-global-color-blue-600)'; + fgcolor = '#fff'; + } + + return ( +
+ + + + {stage} + + + +
+ ); +} + +function getElement(status: string, stage: string, progress: number, onCancel?: () => void): React.ReactNode { + if (status === 'running') { + return TrainingStatusItem(progress, stage, onCancel); + } + return IdleItem(); +} + +export const ProgressBarItem = () => { + const { projectId } = useProjectIdentifier(); + const [progress, setProgress] = useState(0); + const [stage, setStage] = useState(''); + const [jobStatus, setJobStatus] = useState('idle'); + const [currentJobId, setCurrentJobId] = useState(null); + + // Fetch the current running job for the project + const { data: jobsData } = $api.useQuery('get', '/api/jobs', undefined, { + refetchInterval: 2000, // Refetch every 2 seconds to check for new jobs + }); + + const cancelJobMutation = $api.useMutation('post', '/api/jobs/{job_id}/cancel'); + + // Find the running or pending job for this project + useEffect(() => { + if (!jobsData?.jobs) { + return; + } + + const runningJob = jobsData.jobs.find( + (job: Job) => job.project_id === projectId && (job.status === 'running' || job.status === 'pending') + ); + + if (runningJob) { + setCurrentJobId(runningJob.id ?? null); + setJobStatus(runningJob.status); + setProgress(runningJob.progress); + // Use stage from job if available, otherwise fallback to status + setStage(runningJob.stage || runningJob.status); + } else { + setCurrentJobId(null); + setJobStatus('idle'); + setProgress(0); + setStage(''); + } + }, [jobsData, projectId]); + + // Connect to SSE for progress updates when there's a running job + useEffect(() => { + if (!currentJobId || jobStatus !== 'running') { + return; + } + + const eventSource = new EventSource(`${API_BASE_URL}/api/jobs/${currentJobId}/progress`); + + eventSource.onmessage = (event) => { + try { + const data = JSON.parse(event.data); + if (data.progress !== undefined) { + setProgress(data.progress); + } + if (data.stage !== undefined) { + setStage(data.stage); + } + } catch (error) { + console.error('Failed to parse progress data:', error); + } + }; + + eventSource.onerror = (error) => { + console.error('EventSource error:', error); + eventSource.close(); + }; + + return () => { + eventSource.close(); + }; + }, [currentJobId, jobStatus]); + + const handleCancel = async () => { + if (!currentJobId) { + return; + } + + try { + await cancelJobMutation.mutateAsync({ + params: { + path: { + job_id: currentJobId, + }, + }, + }); + console.log('Job cancelled successfully'); + } catch (error) { + console.error('Failed to cancel job:', error); + } + }; + + return getElement(jobStatus, stage, progress, handleCancel); +}; diff --git a/application/ui/src/features/inspect/statusbar/statusbar.component.tsx b/application/ui/src/features/inspect/statusbar/statusbar.component.tsx new file mode 100644 index 0000000000..5d37881ac0 --- /dev/null +++ b/application/ui/src/features/inspect/statusbar/statusbar.component.tsx @@ -0,0 +1,16 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +import { Flex, View } from '@geti/ui'; + +import { ProgressBarItem } from './items/progressbar.component'; + +export const StatusBar = () => { + return ( + + + + + + ); +}; diff --git a/application/ui/src/routes/inspect/inspect.tsx b/application/ui/src/routes/inspect/inspect.tsx index 59e2e32206..65e43761b2 100644 --- a/application/ui/src/routes/inspect/inspect.tsx +++ b/application/ui/src/routes/inspect/inspect.tsx @@ -8,30 +8,34 @@ import { InferenceProvider } from '../../features/inspect/inference-provider.com import { InferenceResult } from '../../features/inspect/inference-result.component'; import { SelectedMediaItemProvider } from '../../features/inspect/selected-media-item-provider.component'; import { Sidebar } from '../../features/inspect/sidebar.component'; +import { StatusBar } from '../../features/inspect/statusbar/statusbar.component'; import { Toolbar } from '../../features/inspect/toolbar'; export const Inspect = () => { const { projectId } = useProjectIdentifier(); return ( - - - - - - - - - +
+ + + + + + + + + + +
); }; From f3855a2259ddf6062ff74b7e6170ee81a9519838 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Fri, 31 Oct 2025 09:43:15 +0100 Subject: [PATCH 02/19] Merge fixes Signed-off-by: Ashwin Vaidya --- application/ui/package-lock.json | 82 ++++++++++--------- application/ui/src/routes/inspect/inspect.tsx | 36 ++++---- 2 files changed, 62 insertions(+), 56 deletions(-) diff --git a/application/ui/package-lock.json b/application/ui/package-lock.json index 628145dd28..2bf26144ad 100644 --- a/application/ui/package-lock.json +++ b/application/ui/package-lock.json @@ -3208,13 +3208,13 @@ } }, "node_modules/@eslint/plugin-kit": { - "version": "0.3.2", - "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.3.2.tgz", - "integrity": "sha512-4SaFZCNfJqvk/kenHpI8xvN42DMaoycy4PzKc5otHxRswww1kAt82OlBuwRVLofCACCTZEcla2Ydxv8scMXaTg==", + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.3.5.tgz", + "integrity": "sha512-Z5kJ+wU3oA7MMIqVR9tyZRtjYPr4OC004Q4Rw7pgOKUOKkJfZ3O24nz3WYfGRpMDNmcOi3TwQOmgm7B7Tpii0w==", "dev": true, "license": "Apache-2.0", "dependencies": { - "@eslint/core": "^0.15.0", + "@eslint/core": "^0.15.2", "levn": "^0.4.1" }, "engines": { @@ -3222,9 +3222,9 @@ } }, "node_modules/@eslint/plugin-kit/node_modules/@eslint/core": { - "version": "0.15.0", - "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.15.0.tgz", - "integrity": "sha512-b7ePw78tEWWkpgZCDYkbqDOP8dmM6qe+AOC6iuJqlq1R/0ahMAeH3qynpnqKFGkMltrp44ohV4ubGyvLX28tzw==", + "version": "0.15.2", + "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.15.2.tgz", + "integrity": "sha512-78Md3/Rrxh83gCxoUc0EiciuOHsIITzLy53m3d9UyiW8y9Dj2D29FeETqyKA+BRK76tnTp6RXWb3pCay8Oyomg==", "dev": true, "license": "Apache-2.0", "dependencies": { @@ -4049,13 +4049,13 @@ "license": "MIT" }, "node_modules/@playwright/test": { - "version": "1.54.1", - "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.54.1.tgz", - "integrity": "sha512-FS8hQ12acieG2dYSksmLOF7BNxnVf2afRJdCuM1eMSxj6QTSE6G4InGF7oApGgDb65MX7AwMVlIkpru0yZA4Xw==", + "version": "1.56.1", + "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.56.1.tgz", + "integrity": "sha512-vSMYtL/zOcFpvJCW71Q/OEGQb7KYBPAdKh35WNSkaZA75JlAO8ED8UN6GUNTm3drWomcbcqRPFqQbLae8yBTdg==", "dev": true, "license": "Apache-2.0", "dependencies": { - "playwright": "1.54.1" + "playwright": "1.56.1" }, "bin": { "playwright": "cli.js" @@ -16923,13 +16923,13 @@ } }, "node_modules/playwright": { - "version": "1.54.1", - "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.54.1.tgz", - "integrity": "sha512-peWpSwIBmSLi6aW2auvrUtf2DqY16YYcCMO8rTVx486jKmDTJg7UAhyrraP98GB8BoPURZP8+nxO7TSd4cPr5g==", + "version": "1.56.1", + "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.56.1.tgz", + "integrity": "sha512-aFi5B0WovBHTEvpM3DzXTUaeN6eN0qWnTkKx4NQaH4Wvcmc153PdaY2UBdSYKaGYw+UyWXSVyxDUg5DoPEttjw==", "dev": true, "license": "Apache-2.0", "dependencies": { - "playwright-core": "1.54.1" + "playwright-core": "1.56.1" }, "bin": { "playwright": "cli.js" @@ -16942,9 +16942,9 @@ } }, "node_modules/playwright-core": { - "version": "1.54.1", - "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.54.1.tgz", - "integrity": "sha512-Nbjs2zjj0htNhzgiy5wu+3w09YetDx5pkrpI/kZotDlDUaYk0HVA5xrBVPdow4SAUIlhgKcJeJg4GRKW6xHusA==", + "version": "1.56.1", + "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.56.1.tgz", + "integrity": "sha512-hutraynyn31F+Bifme+Ps9Vq59hKuUCz7H1kDOcBs+2oGguKkWTU50bBWrtz34OUWmIwpBTWDxaRPXrIXkgvmQ==", "dev": true, "license": "Apache-2.0", "bin": { @@ -19231,14 +19231,14 @@ "license": "MIT" }, "node_modules/tinyglobby": { - "version": "0.2.14", - "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.14.tgz", - "integrity": "sha512-tX5e7OM1HnYr2+a2C/4V0htOcSQcoSTH9KgJnVvNm5zm/cyEWKJ7j7YutsH9CxMdtOkkLFy2AHrMci9IM8IPZQ==", + "version": "0.2.15", + "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", + "integrity": "sha512-j2Zq4NyQYG5XMST4cbs02Ak8iJUdxRM0XI5QyxXuZOzKOINmWurp3smXu3y5wDcJrptwpSjgXHzIQxR0omXljQ==", "dev": true, "license": "MIT", "dependencies": { - "fdir": "^6.4.4", - "picomatch": "^4.0.2" + "fdir": "^6.5.0", + "picomatch": "^4.0.3" }, "engines": { "node": ">=12.0.0" @@ -19248,11 +19248,14 @@ } }, "node_modules/tinyglobby/node_modules/fdir": { - "version": "6.4.6", - "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.4.6.tgz", - "integrity": "sha512-hiFoqpyZcfNm1yc4u8oWCf9A2c4D3QjCrks3zmoVKVxpQRzmPNar1hUJcBG2RQHvEVGDN+Jm81ZheVLAQMK6+w==", + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.5.0.tgz", + "integrity": "sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==", "dev": true, "license": "MIT", + "engines": { + "node": ">=12.0.0" + }, "peerDependencies": { "picomatch": "^3 || ^4" }, @@ -19263,9 +19266,9 @@ } }, "node_modules/tinyglobby/node_modules/picomatch": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.2.tgz", - "integrity": "sha512-M7BAV6Rlcy5u+m6oPhAPFgJTzAioX/6B0DxyvDlo9l8+T3nLKbrczg2WLUyzd45L8RqfUMyGPzekbMvX2Ldkwg==", + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", + "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", "engines": { @@ -20008,18 +20011,18 @@ } }, "node_modules/vite": { - "version": "7.0.6", - "resolved": "https://registry.npmjs.org/vite/-/vite-7.0.6.tgz", - "integrity": "sha512-MHFiOENNBd+Bd9uvc8GEsIzdkn1JxMmEeYX35tI3fv0sJBUTfW5tQsoaOwuY4KhBI09A3dUJ/DXf2yxPVPUceg==", + "version": "7.1.12", + "resolved": "https://registry.npmjs.org/vite/-/vite-7.1.12.tgz", + "integrity": "sha512-ZWyE8YXEXqJrrSLvYgrRP7p62OziLW7xI5HYGWFzOvupfAlrLvURSzv/FyGyy0eidogEM3ujU+kUG1zuHgb6Ug==", "dev": true, "license": "MIT", "dependencies": { "esbuild": "^0.25.0", - "fdir": "^6.4.6", + "fdir": "^6.5.0", "picomatch": "^4.0.3", "postcss": "^8.5.6", - "rollup": "^4.40.0", - "tinyglobby": "^0.2.14" + "rollup": "^4.43.0", + "tinyglobby": "^0.2.15" }, "bin": { "vite": "bin/vite.js" @@ -20121,11 +20124,14 @@ } }, "node_modules/vite/node_modules/fdir": { - "version": "6.4.6", - "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.4.6.tgz", - "integrity": "sha512-hiFoqpyZcfNm1yc4u8oWCf9A2c4D3QjCrks3zmoVKVxpQRzmPNar1hUJcBG2RQHvEVGDN+Jm81ZheVLAQMK6+w==", + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.5.0.tgz", + "integrity": "sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==", "dev": true, "license": "MIT", + "engines": { + "node": ">=12.0.0" + }, "peerDependencies": { "picomatch": "^3 || ^4" }, diff --git a/application/ui/src/routes/inspect/inspect.tsx b/application/ui/src/routes/inspect/inspect.tsx index ad9521755e..61f131d8d5 100644 --- a/application/ui/src/routes/inspect/inspect.tsx +++ b/application/ui/src/routes/inspect/inspect.tsx @@ -16,26 +16,26 @@ export const Inspect = () => { return (
- + - - - - + + + + + - - - + +
); }; From 12130d273c69f6b84f680903607af064372152f5 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Fri, 31 Oct 2025 14:00:25 +0100 Subject: [PATCH 03/19] Fix progress bar Signed-off-by: Ashwin Vaidya --- .../src/api/endpoints/job_endpoints.py | 5 +- application/backend/src/db/schema.py | 6 +- .../backend/src/pydantic_models/job.py | 24 ++++- .../backend/src/services/job_service.py | 27 +++--- .../backend/src/services/training_service.py | 40 +++++---- application/backend/src/utils/callbacks.py | 88 ++++++++++--------- .../statusbar/items/progressbar.component.tsx | 32 +++++-- 7 files changed, 130 insertions(+), 92 deletions(-) diff --git a/application/backend/src/api/endpoints/job_endpoints.py b/application/backend/src/api/endpoints/job_endpoints.py index 4aac68d0f1..bb01e187fb 100644 --- a/application/backend/src/api/endpoints/job_endpoints.py +++ b/application/backend/src/api/endpoints/job_endpoints.py @@ -6,7 +6,6 @@ from fastapi import APIRouter, Body, Depends, status from sse_starlette import EventSourceResponse -from starlette.responses import StreamingResponse from api.dependencies import get_job_id, get_job_service from api.endpoints import API_PREFIX @@ -49,9 +48,9 @@ async def get_job_logs( async def get_job_progress( job_id: Annotated[UUID, Depends(get_job_id)], job_service: Annotated[JobService, Depends(get_job_service)], -) -> StreamingResponse: +) -> EventSourceResponse: """Endpoint to get the progress of a job by its ID""" - return StreamingResponse(job_service.stream_progress(job_id=job_id), media_type="text/event-stream") + return EventSourceResponse(job_service.stream_progress(job_id=job_id)) @job_router.post("/{job_id}:cancel", status_code=status.HTTP_202_ACCEPTED) diff --git a/application/backend/src/db/schema.py b/application/backend/src/db/schema.py index b3c3b2665c..c2a895602e 100644 --- a/application/backend/src/db/schema.py +++ b/application/backend/src/db/schema.py @@ -4,10 +4,12 @@ from datetime import datetime from uuid import uuid4 -from sqlalchemy import JSON, Boolean, DateTime, Float, ForeignKey, String, Text +from sqlalchemy import JSON, Boolean, DateTime, Enum, Float, ForeignKey, String, Text from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.sql import func +from pydantic_models.job import JobStage + class Base(DeclarativeBase): pass @@ -59,7 +61,7 @@ class JobDB(Base): id: Mapped[str] = mapped_column(primary_key=True, default=lambda: str(uuid4())) project_id: Mapped[str] = mapped_column(ForeignKey("projects.id")) type: Mapped[str] = mapped_column(String(64), nullable=False) - stage: Mapped[str] = mapped_column(String(64), nullable=False) # training, validation, test + stage: Mapped[JobStage] = mapped_column(Enum(JobStage), nullable=False) progress: Mapped[int] = mapped_column(nullable=False) status: Mapped[str] = mapped_column(String(64), nullable=False) message: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/application/backend/src/pydantic_models/job.py b/application/backend/src/pydantic_models/job.py index 93b21683ee..4c1bf920ca 100644 --- a/application/backend/src/pydantic_models/job.py +++ b/application/backend/src/pydantic_models/job.py @@ -5,7 +5,7 @@ from typing import Any from uuid import UUID -from pydantic import BaseModel, Field, field_serializer +from pydantic import BaseModel, Field, computed_field, field_serializer from pydantic_models.base import BaseIDModel @@ -23,12 +23,26 @@ class JobStatus(StrEnum): CANCELED = "canceled" +class JobStage(StrEnum): + """Job stages follow PyTorch Lightning stages with the addition of idle stage. + + See ``lightning.pytorch.trainer.states.RunningStage`` for more details. + """ + + IDLE = "idle" + TRAINING = "train" + SANITY_CHECKING = "sanity_check" + VALIDATING = "validate" + TESTING = "test" + PREDICTING = "predict" + + class Job(BaseIDModel): project_id: UUID type: JobType = JobType.TRAINING progress: int = Field(default=0, ge=0, le=100, description="Progress percentage from 0 to 100") status: JobStatus = JobStatus.PENDING - stage: str = "idle" + stage: JobStage = JobStage.IDLE payload: dict message: str = "Job created" start_time: datetime | None = None @@ -50,8 +64,12 @@ class JobSubmitted(BaseModel): class JobCancelled(BaseModel): job_id: UUID + @computed_field + def message(self) -> str: + return f" Job with ID `{self.job_id}' marked as cancelled." + class TrainJobPayload(BaseModel): project_id: UUID = Field(exclude=True) model_name: str - device: str | None + device: str | None = None diff --git a/application/backend/src/services/job_service.py b/application/backend/src/services/job_service.py index 2aee4fe3fe..a56e60c5d1 100644 --- a/application/backend/src/services/job_service.py +++ b/application/backend/src/services/job_service.py @@ -3,6 +3,7 @@ import asyncio import datetime import json +import logging import os from collections.abc import AsyncGenerator, Coroutine from typing import Any @@ -16,9 +17,11 @@ from db import get_async_db_session_ctx from exceptions import DuplicateJobException, ResourceNotFoundException from pydantic_models import Job, JobList, JobType -from pydantic_models.job import JobCancelled, JobStatus, JobSubmitted, TrainJobPayload +from pydantic_models.job import JobCancelled, JobStage, JobStatus, JobSubmitted, TrainJobPayload from repositories import JobRepository +logger = logging.getLogger(__name__) + class JobService: @staticmethod @@ -64,7 +67,7 @@ async def update_job_status( status: JobStatus, message: str | None = None, progress: int | None = None, - stage: str | None = None, + stage: JobStage | None = None, ) -> None: async with get_async_db_session_ctx() as session: repo = JobRepository(session) @@ -128,24 +131,14 @@ async def stream_logs(cls, job_id: UUID | str) -> AsyncGenerator[ServerSentEvent @classmethod async def stream_progress(cls, job_id: UUID | str) -> Coroutine[Any, Any, AsyncContentStream]: """Stream the progress of a job by its ID""" - loop = asyncio.get_running_loop() - status_check_interval = 2.0 # seconds - last_status_check = 0.0 - cached_still_running = True still_running = True - async with get_async_db_session_ctx() as session: - repo = JobRepository(session) - job = await repo.get_by_id(job_id) + while still_running: + job = await cls.get_job_by_id(job_id=job_id) if job is None: raise ResourceNotFoundException(resource_id=job_id, resource_name="job") - while still_running: - now = loop.time() - if now - last_status_check > status_check_interval: - cached_still_running = await cls.is_job_still_running(job_id=job_id) - last_status_check = now - still_running = cached_still_running - yield json.dumps({"progress": job.progress, "stage": job.stage}) - await asyncio.sleep(0.5) + yield ServerSentEvent(data=json.dumps({"progress": job.progress, "stage": job.stage})) + still_running = job.status in {JobStatus.RUNNING, JobStatus.PENDING} + await asyncio.sleep(0.5) @classmethod async def cancel_job(cls, job_id: UUID | str) -> JobCancelled: diff --git a/application/backend/src/services/training_service.py b/application/backend/src/services/training_service.py index 3a6d0f7390..7361976f4a 100644 --- a/application/backend/src/services/training_service.py +++ b/application/backend/src/services/training_service.py @@ -84,7 +84,10 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model: # Use asyncio.to_thread to keep event loop responsive # TODO: Consider ProcessPoolExecutor for true parallelism with multiple jobs trained_model = await asyncio.to_thread( - cls._train_model, model=model, device=device, synchronization_parameters=synchronization_parameters + cls._train_model, + model=model, + device=device, + synchronization_parameters=synchronization_parameters, ) if trained_model is None: raise ValueError("Training failed - model is None") @@ -92,7 +95,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model: await job_service.update_job_status( job_id=job.id, status=JobStatus.COMPLETED, message="Training completed successfully" ) - logger.info("Syncing progress with db stopped") + logger.debug("Syncing progress with db stopped") synchronization_task.cancel() return await model_service.create_model(trained_model) except Exception as e: @@ -100,7 +103,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model: await job_service.update_job_status( job_id=job.id, status=JobStatus.FAILED, message=f"Failed with exception: {str(e)}" ) - logger.info("Syncing progress with db stopped") + logger.debug("Syncing progress with db stopped") synchronization_task.cancel() if model.export_path: logger.warning(f"Deleting partially created model with id: {model.id}") @@ -136,8 +139,9 @@ def _train_model( f"Device '{device}' is not supported for training. " f"Supported devices: {', '.join(Devices.training_devices())}" ) + device = device or "auto" - logger.info(f"Training on device: {device or 'auto'}") + logger.info(f"Training on device: {device}") model_binary_repo = ModelBinaryRepository(project_id=model.project_id, model_id=model.id) image_binary_repo = ImageBinaryRepository(project_id=model.project_id) @@ -161,7 +165,7 @@ def _train_model( engine = Engine( default_root_dir=model.export_path, logger=[trackio, tensorboard], - devices=[0], + devices=[0], # Only single GPU training is supported for now max_epochs=10, callbacks=[GetiInspectProgressCallback(synchronization_parameters)], accelerator=device, @@ -197,16 +201,22 @@ async def _sync_progress_with_db( job_id: UUID, synchronization_parameters: ProgressSyncParams, ) -> None: - while True: - progress = synchronization_parameters.get_progress() - stage = synchronization_parameters.get_stage() - if not await job_service.is_job_still_running(job_id=job_id): - logger.info("Job cancelled, stopping progress sync") - synchronization_parameters.set_cancel_training_event() - break - logger.info(f"Syncing progress with db: {progress}% - {stage}") - await job_service.update_job_status(job_id=job_id, status=JobStatus.RUNNING, progress=progress, stage=stage) - await asyncio.sleep(0.1) + try: + while True: + progress: int = synchronization_parameters.progress + stage = synchronization_parameters.stage + if not await job_service.is_job_still_running(job_id=job_id): + logger.debug("Job cancelled, stopping progress sync") + synchronization_parameters.set_cancel_training_event() + break + logger.debug(f"Syncing progress with db: {progress}% - {stage}") + await job_service.update_job_status( + job_id=job_id, status=JobStatus.RUNNING, progress=progress, stage=stage + ) + await asyncio.sleep(0.5) + except Exception as e: + logger.exception("Failed to sync progress with db: %s", e) + raise @staticmethod async def abort_orphan_jobs() -> None: diff --git a/application/backend/src/utils/callbacks.py b/application/backend/src/utils/callbacks.py index adfa3ea3d4..0bdbd43be1 100644 --- a/application/backend/src/utils/callbacks.py +++ b/application/backend/src/utils/callbacks.py @@ -6,44 +6,52 @@ from __future__ import annotations import logging -from ctypes import c_char_p -from multiprocessing import Event, Value +import threading from typing import TYPE_CHECKING, Any from lightning.pytorch.callbacks import Callback +from pydantic_models.job import JobStage + if TYPE_CHECKING: from lightning.pytorch import LightningModule, Trainer - from lightning.pytorch.trainer.states import RunningStage logger = logging.getLogger(__name__) class ProgressSyncParams: def __init__(self) -> None: - self.progress = Value("f", 0.0) - self.stage = Value(c_char_p, b"idle") - self.cancel_training_event = Event() - - def set_stage(self, stage: str) -> None: - with self.stage.get_lock(): - self.stage.value = stage.encode("utf-8") - logger.info("Set stage: %s", stage) - - def get_stage(self) -> str: - return self.stage.value.decode("utf-8") - - def set_progress(self, progress: float) -> None: - with self.progress.get_lock(): - self.progress.value = progress - logger.info("Set progress: %s", progress) - - def get_progress(self) -> float: - return self.progress.value + self._progress = 0 + self._stage = JobStage.IDLE + self._lock = threading.Lock() + self.cancel_training_event = threading.Event() + + @property + def stage(self) -> JobStage: + with self._lock: + return self._stage + + @stage.setter + def stage(self, stage: JobStage) -> None: + with self._lock: + self._stage = stage + logger.debug("Stage updated: %s", stage) + + @property + def progress(self) -> int: + with self._lock: + return self._progress + + @progress.setter + def progress(self, progress: int) -> None: + with self._lock: + self._progress = progress + logger.debug("Progress updated: %s", progress) def set_cancel_training_event(self) -> None: - self.cancel_training_event.set() - logger.info("Set cancel training event") + with self._lock: + self.cancel_training_event.set() + logger.debug("Set cancel training event") class GetiInspectProgressCallback(Callback): @@ -71,7 +79,7 @@ def _check_cancel_training(self, trainer: Trainer) -> None: if self.synchronization_parameters.cancel_training_event.is_set(): trainer.should_stop = True - def _send_progress(self, progress: float, stage: RunningStage) -> None: + def _send_progress(self, progress: float, stage: JobStage) -> None: """Send progress update to frontend via event queue. Puts a generic event message into the multiprocessing queue which will be picked up by the main process and broadcast via WebSocket. @@ -83,22 +91,16 @@ def _send_progress(self, progress: float, stage: RunningStage) -> None: progress_percent = int(progress * 100) try: - logger.info("Sent progress: %s - %d%%", stage.name, progress_percent) - self.synchronization_parameters.set_progress(progress_percent) - self.synchronization_parameters.set_stage(stage.name) + logger.debug("Sent progress: %s - %d%%", stage, progress_percent) + self.synchronization_parameters.progress = progress_percent + self.synchronization_parameters.stage = stage except Exception as e: logger.warning("Failed to send progress to event queue: %s", e) - def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: - self._send_progress(0, stage) - - def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: RunningStage) -> None: - self._send_progress(1.0, stage) - # Training callbacks def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Called when training starts.""" - self._send_progress(0, trainer.state.stage) + self._send_progress(0, JobStage.TRAINING) self._check_cancel_training(trainer) def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None: @@ -108,12 +110,12 @@ def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, bat def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Called when a training epoch ends.""" progress = (trainer.current_epoch + 1) / trainer.max_epochs - self._send_progress(progress, trainer.state.stage) + self._send_progress(progress, JobStage.TRAINING) self._check_cancel_training(trainer) def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Called when training ends.""" - self._send_progress(1.0, trainer.state.stage) + self._send_progress(1.0, JobStage.TRAINING) self._check_cancel_training(trainer) # Validation callbacks @@ -138,7 +140,7 @@ def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> Non # Test callbacks def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Called when testing starts.""" - self._send_progress(0, trainer.state.stage) + self._send_progress(0, JobStage.TESTING) self._check_cancel_training(trainer) def on_test_batch_start( @@ -150,18 +152,18 @@ def on_test_batch_start( def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Called when a test epoch ends.""" progress = (trainer.current_epoch + 1) / trainer.max_epochs if trainer.max_epochs else 0.5 - self._send_progress(progress, trainer.state.stage) + self._send_progress(progress, JobStage.TESTING) self._check_cancel_training(trainer) def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Called when testing ends.""" - self._send_progress(1.0, trainer.state.stage) + self._send_progress(1.0, JobStage.TESTING) self._check_cancel_training(trainer) # Predict callbacks def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Called when prediction starts.""" - self._send_progress(0, trainer.state.stage) + self._send_progress(0, JobStage.PREDICTING) self._check_cancel_training(trainer) def on_predict_batch_start( @@ -173,10 +175,10 @@ def on_predict_batch_start( def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Called when a prediction epoch ends.""" progress = (trainer.current_epoch + 1) / trainer.max_epochs if trainer.max_epochs else 0.5 - self._send_progress(progress, trainer.state.stage) + self._send_progress(progress, JobStage.PREDICTING) self._check_cancel_training(trainer) def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Called when prediction ends.""" - self._send_progress(1.0, trainer.state.stage) + self._send_progress(1.0, JobStage.PREDICTING) self._check_cancel_training(trainer) diff --git a/application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx b/application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx index 29be3c8c7c..f7f66d0145 100644 --- a/application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx +++ b/application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx @@ -52,7 +52,7 @@ function TrainingStatusItem(progress: number, stage: string, onCancel?: () => vo + + {stage} + + + + + ); +}; + +const useCurrentJob = () => { + const { data: jobsData } = $api.useSuspenseQuery('get', '/api/jobs', undefined, { + refetchInterval: 5000, + }); + + const { projectId } = useProjectIdentifier(); + const runningJob = jobsData.jobs.find( + (job: Job) => job.project_id === projectId && (job.status === 'running' || job.status === 'pending') + ); + + return runningJob; +}; + +export const ProgressBarItem = () => { + const trainingJob = useCurrentJob(); + + if (trainingJob !== undefined) { + return ; + } + + return ; +}; + +export const Footer = () => { + return ( + + + + + + ); +}; diff --git a/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx index 347bbc9253..b3ef175f2e 100644 --- a/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx +++ b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx @@ -17,56 +17,7 @@ import { } from '@geti/ui'; import { LogsIcon } from '@geti/ui/icons'; import { queryOptions, experimental_streamedQuery as streamedQuery, useQuery } from '@tanstack/react-query'; - -// Connect to an SSE endpoint and yield its messages -function fetchSSE(url: string) { - return { - async *[Symbol.asyncIterator]() { - const eventSource = new EventSource(url); - - try { - let { promise, resolve, reject } = Promise.withResolvers(); - - eventSource.onmessage = (event) => { - if (event.data === 'DONE' || event.data.includes('COMPLETED')) { - eventSource.close(); - resolve('DONE'); - return; - } - resolve(event.data); - }; - - eventSource.onerror = (error) => { - eventSource.close(); - reject(new Error('EventSource failed: ' + error)); - }; - - // Keep yielding data as it comes in - while (true) { - const message = await promise; - - // If server sends 'DONE' message or similar, break the loop - if (message === 'DONE') { - break; - } - - try { - const data = JSON.parse(message); - if (data['text']) { - yield data['text']; - } - } catch { - console.error('Could not parse message:', message); - } - - ({ promise, resolve, reject } = Promise.withResolvers()); - } - } finally { - eventSource.close(); - } - }, - }; -} +import { fetchSSE } from 'src/api/fetch-sse'; const JobLogsDialogContent = ({ jobId }: { jobId: string }) => { const query = useQuery( @@ -81,7 +32,7 @@ const JobLogsDialogContent = ({ jobId }: { jobId: string }) => { return ( - {query.data?.map((line, idx) => {line})} + {query.data?.map((line, idx) => {line.text})} ); }; diff --git a/application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx b/application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx deleted file mode 100644 index f7f66d0145..0000000000 --- a/application/ui/src/features/inspect/statusbar/items/progressbar.component.tsx +++ /dev/null @@ -1,196 +0,0 @@ -import React, { useEffect, useState } from 'react'; - -import { $api, API_BASE_URL } from '@geti-inspect/api'; -import { SchemaJob as Job } from '@geti-inspect/api/spec'; -import { useProjectIdentifier } from '@geti-inspect/hooks'; -import { Flex, ProgressBar, Text } from '@geti/ui'; -import { CanceledIcon, WaitingIcon } from '@geti/ui/icons'; - -function IdleItem(): React.ReactNode { - return ( - - - - Idle - - - ); -} - -function TrainingStatusItem(progress: number, stage: string, onCancel?: () => void): React.ReactNode { - // Determine color based on stage - let bgcolor = 'var(--spectrum-global-color-blue-600)'; - let fgcolor = '#fff'; - if (stage.toLowerCase().includes('valid')) { - bgcolor = 'var(--spectrum-global-color-yellow-600)'; - fgcolor = '#000'; - } else if (stage.toLowerCase().includes('test')) { - bgcolor = 'var(--spectrum-global-color-green-600)'; - fgcolor = '#fff'; - } else if (stage.toLowerCase().includes('train') || stage.toLowerCase().includes('fit')) { - bgcolor = 'var(--spectrum-global-color-blue-600)'; - fgcolor = '#fff'; - } - - return ( -
- - - - {stage} - - - -
- ); -} - -function getElement(status: string, stage: string, progress: number, onCancel?: () => void): React.ReactNode { - if (status === 'running') { - return TrainingStatusItem(progress, stage, onCancel); - } - return IdleItem(); -} - -export const ProgressBarItem = () => { - const { projectId } = useProjectIdentifier(); - const [progress, setProgress] = useState(0); - const [stage, setStage] = useState(''); - const [jobStatus, setJobStatus] = useState('idle'); - const [currentJobId, setCurrentJobId] = useState(null); - const [sseConnected, setSSEConnected] = useState(false); - - // Fetch the current running job for the project - const { data: jobsData } = $api.useQuery('get', '/api/jobs', undefined, { - refetchInterval: 5000, // Refetch every 5 seconds to check for new jobs - }); - - const cancelJobMutation = $api.useMutation('post', '/api/jobs/{job_id}:cancel'); - - // Find the running or pending job for this project - useEffect(() => { - if (!jobsData?.jobs) { - return; - } - - const runningJob = jobsData.jobs.find( - (job: Job) => job.project_id === projectId && (job.status === 'running' || job.status === 'pending') - ); - - if (runningJob) { - setCurrentJobId(runningJob.id ?? null); - setJobStatus(runningJob.status); - - // Only use polling data if SSE is not connected (fallback) - if (!sseConnected) { - setProgress(runningJob.progress); - setStage(runningJob.stage || runningJob.status); - } - } else { - // No running job found - reset everything - setCurrentJobId(null); - setJobStatus('idle'); - setProgress(0); - setStage(''); - setSSEConnected(false); - } - }, [jobsData, projectId, sseConnected]); - - // Connect to SSE for progress updates when there's a running job - useEffect(() => { - if (!currentJobId || jobStatus !== 'running') { - setSSEConnected(false); - return; - } - - const eventSource = new EventSource(`${API_BASE_URL}/api/jobs/${currentJobId}/progress`); - - eventSource.onopen = () => { - setSSEConnected(true); - console.debug('SSE connected'); - }; - - eventSource.onmessage = (event) => { - try { - const data = JSON.parse(event.data); - console.debug('SSE data:', data); - if (data.progress !== undefined) { - setProgress(data.progress); - } - if (data.stage !== undefined) { - setStage(data.stage); - } - } catch (error) { - console.error('Failed to parse progress data:', error); - } - }; - - eventSource.onerror = (error) => { - setSSEConnected(false); - eventSource.close(); - }; - - return () => { - setSSEConnected(false); - eventSource.close(); - }; - }, [currentJobId, jobStatus]); - - const handleCancel = async () => { - if (!currentJobId) { - return; - } - - try { - await cancelJobMutation.mutateAsync({ - params: { - path: { - job_id: currentJobId, - }, - }, - }); - console.info('Job cancelled successfully'); - } catch (error) { - console.error('Failed to cancel job:', error); - } - }; - - return getElement(jobStatus, stage, progress, handleCancel); -}; diff --git a/application/ui/src/features/inspect/statusbar/statusbar.component.tsx b/application/ui/src/features/inspect/statusbar/statusbar.component.tsx deleted file mode 100644 index 5d37881ac0..0000000000 --- a/application/ui/src/features/inspect/statusbar/statusbar.component.tsx +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (C) 2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -import { Flex, View } from '@geti/ui'; - -import { ProgressBarItem } from './items/progressbar.component'; - -export const StatusBar = () => { - return ( - - - - - - ); -}; diff --git a/application/ui/src/routes/inspect/inspect.tsx b/application/ui/src/routes/inspect/inspect.tsx index 61f131d8d5..738038a81b 100644 --- a/application/ui/src/routes/inspect/inspect.tsx +++ b/application/ui/src/routes/inspect/inspect.tsx @@ -4,38 +4,36 @@ import { useProjectIdentifier } from '@geti-inspect/hooks'; import { Grid } from '@geti/ui'; +import { Footer } from '../../features/inspect/footer/footer.component'; import { InferenceProvider } from '../../features/inspect/inference-provider.component'; import { InferenceResult } from '../../features/inspect/inference-result.component'; import { SelectedMediaItemProvider } from '../../features/inspect/selected-media-item-provider.component'; import { Sidebar } from '../../features/inspect/sidebar.component'; -import { StatusBar } from '../../features/inspect/statusbar/statusbar.component'; import { Toolbar } from '../../features/inspect/toolbar'; export const Inspect = () => { const { projectId } = useProjectIdentifier(); return ( -
- - - - - - - - - - -
+ + + + + + +