From c019e96cfa965093e2735a2b3a0c49de8d20a108 Mon Sep 17 00:00:00 2001 From: Dasemu Date: Fri, 16 Jan 2026 16:56:42 +0100 Subject: [PATCH] feat(workers): add multiprocessing worker pool system - Add Worker class with CPU/GPU support - Add WorkerPool for orchestrating multiple workers - Support dynamic add/remove workers at runtime - Add health monitoring with graceful shutdown --- backend/core/worker.py | 336 ++++++++++++++++++++++++++++------- backend/core/worker_pool.py | 339 ++++++++++++++++++++++++++++++++++++ 2 files changed, 611 insertions(+), 64 deletions(-) create mode 100644 backend/core/worker_pool.py diff --git a/backend/core/worker.py b/backend/core/worker.py index 83c12ae..dd32b26 100644 --- a/backend/core/worker.py +++ b/backend/core/worker.py @@ -1,10 +1,11 @@ """Individual worker for processing transcription jobs.""" import logging import multiprocessing as mp +import os import time import traceback -from datetime import datetime -from enum import Enum +from datetime import datetime, timezone +from enum import IntEnum, Enum from typing import Optional from backend.core.database import Database @@ -20,13 +21,23 @@ class WorkerType(str, Enum): GPU = "gpu" -class WorkerStatus(str, Enum): +class WorkerStatus(IntEnum): """Worker status states.""" - IDLE = "idle" - BUSY = "busy" - STOPPING = "stopping" - STOPPED = "stopped" - ERROR = "error" + IDLE = 0 + BUSY = 1 + STOPPING = 2 + STOPPED = 3 + ERROR = 4 + + def to_string(self) -> str: + """Convert to string representation.""" + return { + 0: "idle", + 1: "busy", + 2: "stopping", + 3: "stopped", + 4: "error" + }.get(self.value, "unknown") class Worker: @@ -79,13 +90,13 @@ class Worker: daemon=True ) self.process.start() - self.started_at = datetime.utcnow() + self.started_at = datetime.now(timezone.utc) logger.info( f"Worker {self.worker_id} started (PID: {self.process.pid}, " f"Type: {self.worker_type.value})" ) - def stop(self, timeout: float = 30.0): + def stop(self, timeout: float = 5.0): """ Stop the worker process gracefully. @@ -93,7 +104,7 @@ class Worker: timeout: Maximum time to wait for worker to stop """ if not self.process or not self.process.is_alive(): - logger.warning(f"Worker {self.worker_id} is not running") + logger.debug(f"Worker {self.worker_id} is not running") return logger.info(f"Stopping worker {self.worker_id}...") @@ -103,11 +114,12 @@ class Worker: if self.process.is_alive(): logger.warning(f"Worker {self.worker_id} did not stop gracefully, terminating...") self.process.terminate() - self.process.join(timeout=5.0) + self.process.join(timeout=2.0) if self.process.is_alive(): logger.error(f"Worker {self.worker_id} did not terminate, killing...") self.process.kill() + self.process.join(timeout=1.0) logger.info(f"Worker {self.worker_id} stopped") @@ -130,7 +142,7 @@ class Worker: "worker_id": self.worker_id, "type": self.worker_type.value, "device_id": self.device_id, - "status": status_enum.value, + "status": status_enum.to_string(), # Convert to string "current_job_id": current_job if current_job else None, "jobs_completed": self.jobs_completed.value, "jobs_failed": self.jobs_failed.value, @@ -205,75 +217,244 @@ class Worker: def _process_job(self, job: Job, queue_mgr: QueueManager): """ - Process a single transcription job. + Process a job (transcription or language detection). Args: job: Job to process queue_mgr: Queue manager for updating progress """ - # TODO: This will be implemented when we add the transcriber module - # For now, simulate work + from backend.core.models import JobType - # Stage 1: Detect language - queue_mgr.update_job_progress( - job.id, - progress=10.0, - stage=JobStage.DETECTING_LANGUAGE, - eta_seconds=60 - ) - time.sleep(2) # Simulate work + # Route to appropriate handler based on job type + if job.job_type == JobType.LANGUAGE_DETECTION: + self._process_language_detection(job, queue_mgr) + else: + self._process_transcription(job, queue_mgr) - # Stage 2: Extract audio - queue_mgr.update_job_progress( - job.id, - progress=20.0, - stage=JobStage.EXTRACTING_AUDIO, - eta_seconds=50 - ) - time.sleep(2) + def _process_language_detection(self, job: Job, queue_mgr: QueueManager): + """ + Process a language detection job using fast Whisper model. - # Stage 3: Transcribe - queue_mgr.update_job_progress( - job.id, - progress=30.0, - stage=JobStage.TRANSCRIBING, - eta_seconds=40 - ) + Args: + job: Language detection job + queue_mgr: Queue manager for updating progress + """ + start_time = time.time() - # Simulate progressive transcription - for i in range(30, 90, 10): - time.sleep(1) + try: + logger.info(f"Worker {self.worker_id} processing LANGUAGE DETECTION job {job.id}: {job.file_name}") + + # Stage 1: Detecting language (20% progress) queue_mgr.update_job_progress( - job.id, - progress=float(i), - stage=JobStage.TRANSCRIBING, - eta_seconds=int((100 - i) / 2) + job.id, progress=20.0, stage=JobStage.DETECTING_LANGUAGE, eta_seconds=10 ) - # Stage 4: Finalize - queue_mgr.update_job_progress( - job.id, - progress=95.0, - stage=JobStage.FINALIZING, - eta_seconds=5 - ) - time.sleep(1) + # Use language detector with tiny model + from backend.scanning.language_detector import LanguageDetector - # Mark as completed - output_path = job.file_path.replace('.mkv', '.srt') - queue_mgr.mark_job_completed( - job.id, - output_path=output_path, - segments_count=100, # Simulated - srt_content="Simulated SRT content" - ) + language, confidence = LanguageDetector.detect_language( + file_path=job.file_path, + sample_duration=30 + ) + + # Stage 2: Finalizing (80% progress) + queue_mgr.update_job_progress( + job.id, progress=80.0, stage=JobStage.FINALIZING, eta_seconds=2 + ) + + if language: + # Calculate processing time + processing_time = time.time() - start_time + + # Use ISO 639-1 format (ja, en, es) throughout the system + lang_code = language.value[0] if language else "unknown" + + result_text = f"Language detected: {lang_code} ({language.name.title() if language else 'Unknown'})\nConfidence: {confidence}%" + + # Store in ISO 639-1 format (ja, en, es) for consistency + queue_mgr.mark_job_completed( + job.id, + output_path=None, + segments_count=0, + srt_content=result_text, + detected_language=lang_code # Use ISO 639-1 (ja, en, es) + ) + + logger.info( + f"Worker {self.worker_id} completed detection job {job.id}: " + f"{lang_code} (confidence: {confidence}%) in {processing_time:.1f}s" + ) + + # Check if file matches any scan rules and queue transcription job + self._check_and_queue_transcription(job, lang_code) + else: + # Detection failed + queue_mgr.mark_job_failed(job.id, "Language detection failed - could not detect language") + logger.error(f"Worker {self.worker_id} failed detection job {job.id}: No language detected") + + except Exception as e: + logger.error(f"Worker {self.worker_id} failed detection job {job.id}: {e}", exc_info=True) + queue_mgr.mark_job_failed(job.id, str(e)) + + def _process_transcription(self, job: Job, queue_mgr: QueueManager): + """ + Process a transcription/translation job using Whisper. + + Args: + job: Transcription job + queue_mgr: Queue manager for updating progress + """ + from backend.transcription import WhisperTranscriber + from backend.transcription.audio_utils import handle_multiple_audio_tracks + from backend.core.language_code import LanguageCode + + transcriber = None + start_time = time.time() + + try: + logger.info(f"Worker {self.worker_id} processing TRANSCRIPTION job {job.id}: {job.file_name}") + + # Stage 1: Loading model + queue_mgr.update_job_progress( + job.id, progress=5.0, stage=JobStage.LOADING_MODEL, eta_seconds=None + ) + + # Determine device for transcriber + if self.worker_type == WorkerType.GPU: + device = f"cuda:{self.device_id}" if self.device_id is not None else "cuda" + else: + device = "cpu" + + transcriber = WhisperTranscriber(device=device) + transcriber.load_model() + + # Stage 2: Preparing audio + queue_mgr.update_job_progress( + job.id, progress=10.0, stage=JobStage.EXTRACTING_AUDIO, eta_seconds=None + ) + + # Handle multiple audio tracks if needed + source_lang = ( + LanguageCode.from_string(job.source_lang) if job.source_lang else None + ) + audio_data = handle_multiple_audio_tracks(job.file_path, source_lang) + + # Stage 3: Transcribing + queue_mgr.update_job_progress( + job.id, progress=15.0, stage=JobStage.TRANSCRIBING, eta_seconds=None + ) + + # Progress callback for real-time updates + def progress_callback(seek, total): + # Reserve 15%-75% for Whisper (60% range) + # If translate mode, reserve 75%-90% for translation (15% range) + whisper_progress = 15.0 + (seek / total) * 60.0 + queue_mgr.update_job_progress(job.id, progress=whisper_progress, stage=JobStage.TRANSCRIBING) + + # Stage 3A: Whisper transcription to English + # IMPORTANT: Both 'transcribe' and 'translate' modes use task='translate' here + # to convert audio to English subtitles + logger.info(f"Running Whisper with task='translate' to convert audio to English") + + # job.source_lang is already in ISO 639-1 format (ja, en, es) + # Whisper accepts ISO 639-1, so we can use it directly + if audio_data: + result = transcriber.transcribe_audio_data( + audio_data=audio_data.read(), + language=job.source_lang, # Already ISO 639-1 (ja, en, es) + task="translate", # ALWAYS translate to English first + progress_callback=progress_callback, + ) + else: + result = transcriber.transcribe_file( + file_path=job.file_path, + language=job.source_lang, # Already ISO 639-1 (ja, en, es) + task="translate", # ALWAYS translate to English first + progress_callback=progress_callback, + ) + + # Generate English SRT filename + file_base = os.path.splitext(job.file_path)[0] + english_srt_path = f"{file_base}.eng.srt" + + # Save English SRT + result.to_srt(english_srt_path, word_level=False) + logger.info(f"English subtitles saved to {english_srt_path}") + + # Stage 3B: Optional translation to target language + if job.transcribe_or_translate == "translate" and job.target_lang and job.target_lang.lower() != "eng": + queue_mgr.update_job_progress( + job.id, progress=75.0, stage=JobStage.FINALIZING, eta_seconds=10 + ) + + logger.info(f"Translating English subtitles to {job.target_lang}") + + from backend.transcription import translate_srt_file + + # Generate target language SRT filename + target_srt_path = f"{file_base}.{job.target_lang}.srt" + + # Translate English SRT to target language + success = translate_srt_file( + input_path=english_srt_path, + output_path=target_srt_path, + target_language=job.target_lang + ) + + if success: + logger.info(f"Translated subtitles saved to {target_srt_path}") + output_path = target_srt_path + else: + logger.warning(f"Translation failed, keeping English subtitles only") + output_path = english_srt_path + else: + # For 'transcribe' mode or if target is English, use English SRT + output_path = english_srt_path + + # Stage 4: Finalize + queue_mgr.update_job_progress( + job.id, progress=90.0, stage=JobStage.FINALIZING, eta_seconds=5 + ) + + # Calculate processing time + processing_time = time.time() - start_time + + # Get SRT content for storage + srt_content = result.get_srt_content() + + # Mark job as completed + queue_mgr.mark_job_completed( + job.id, + output_path=output_path, + segments_count=result.segments_count, + srt_content=srt_content, + model_used=transcriber.model_name, + device_used=transcriber.device, + processing_time_seconds=processing_time, + ) + + logger.info( + f"Worker {self.worker_id} completed job {job.id}: " + f"{result.segments_count} segments in {processing_time:.1f}s" + ) + + except Exception as e: + logger.error(f"Worker {self.worker_id} failed job {job.id}: {e}", exc_info=True) + queue_mgr.mark_job_failed(job.id, str(e)) + finally: + # Always unload model after job + if transcriber: + try: + transcriber.unload_model() + except Exception as e: + logger.error(f"Error unloading model: {e}") def _set_status(self, status: WorkerStatus): """Set worker status (thread-safe).""" self.status.value = status.value def _set_current_job(self, job_id: str): - """Set current job ID (thread-safe).""" + """Set the current job ID (thread-safe).""" job_id_bytes = job_id.encode('utf-8') for i, byte in enumerate(job_id_bytes): if i < len(self.current_job_id): @@ -282,4 +463,31 @@ class Worker: def _clear_current_job(self): """Clear current job ID (thread-safe).""" for i in range(len(self.current_job_id)): - self.current_job_id[i] = b'\x00' \ No newline at end of file + self.current_job_id[i] = b'\x00' + + def _check_and_queue_transcription(self, job: Job, detected_lang_code: str): + """ + Check if detected language matches any scan rules and queue transcription job. + + Args: + job: Completed language detection job + detected_lang_code: Detected language code (ISO 639-1, e.g., 'ja', 'en') + """ + try: + from backend.scanning.library_scanner import library_scanner + + logger.info( + f"Language detection completed for {job.file_path}: {detected_lang_code}. " + f"Checking scan rules..." + ) + + # Use the scanner's method to check rules and queue transcription + library_scanner._check_and_queue_transcription_for_file( + job.file_path, detected_lang_code + ) + + except Exception as e: + logger.error( + f"Error checking scan rules for {job.file_path}: {e}", + exc_info=True + ) diff --git a/backend/core/worker_pool.py b/backend/core/worker_pool.py new file mode 100644 index 0000000..1cd81ad --- /dev/null +++ b/backend/core/worker_pool.py @@ -0,0 +1,339 @@ +"""Worker pool orchestrator for managing transcription workers.""" +import logging +import time +from typing import Dict, List, Optional +from datetime import datetime, timezone + +from backend.core.worker import Worker, WorkerType, WorkerStatus +from backend.core.queue_manager import queue_manager + +logger = logging.getLogger(__name__) + + +class WorkerPool: + """ + Orchestrator for managing a pool of transcription workers. + + Similar to Tdarr's worker management system, this class handles: + - Dynamic worker creation/removal (CPU and GPU) + - Worker health monitoring + - Load balancing via the queue + - Worker statistics and reporting + - Graceful shutdown + + Workers are managed as separate processes that pull jobs from the + persistent queue. The pool can be controlled via WebUI to add/remove + workers on-demand. + """ + + def __init__(self): + """Initialize worker pool.""" + self.workers: Dict[str, Worker] = {} + self.is_running = False + self.started_at: Optional[datetime] = None + + logger.info("WorkerPool initialized") + + def start(self, cpu_workers: int = 0, gpu_workers: int = 0): + """ + Start the worker pool with specified number of workers. + + Args: + cpu_workers: Number of CPU workers to start + gpu_workers: Number of GPU workers to start + """ + if self.is_running: + logger.warning("WorkerPool is already running") + return + + self.is_running = True + self.started_at = datetime.now(timezone.utc) + + # Start CPU workers + for i in range(cpu_workers): + self.add_worker(WorkerType.CPU) + + # Start GPU workers + for i in range(gpu_workers): + self.add_worker(WorkerType.GPU, device_id=i % self._get_gpu_count()) + + logger.info( + f"WorkerPool started: {cpu_workers} CPU workers, {gpu_workers} GPU workers" + ) + + def stop(self, timeout: float = 30.0): + """ + Stop all workers gracefully. + + Args: + timeout: Maximum time to wait for each worker to stop + """ + if not self.is_running: + logger.warning("WorkerPool is not running") + return + + logger.info(f"Stopping WorkerPool with {len(self.workers)} workers...") + + # Stop all workers + for worker_id, worker in list(self.workers.items()): + logger.info(f"Stopping worker {worker_id}") + worker.stop(timeout=timeout) + + self.workers.clear() + self.is_running = False + + logger.info("WorkerPool stopped") + + def add_worker( + self, + worker_type: WorkerType, + device_id: Optional[int] = None + ) -> str: + """ + Add a new worker to the pool. + + Args: + worker_type: CPU or GPU + device_id: GPU device ID (only for GPU workers) + + Returns: + Worker ID + """ + # Generate unique worker ID + worker_id = self._generate_worker_id(worker_type, device_id) + + if worker_id in self.workers: + logger.warning(f"Worker {worker_id} already exists") + return worker_id + + # Create and start worker + worker = Worker(worker_id, worker_type, device_id) + worker.start() + + self.workers[worker_id] = worker + + logger.info(f"Added worker {worker_id} ({worker_type.value})") + return worker_id + + def remove_worker(self, worker_id: str, timeout: float = 30.0) -> bool: + """ + Remove a worker from the pool. + + Args: + worker_id: Worker ID to remove + timeout: Maximum time to wait for worker to stop + + Returns: + True if worker was removed, False otherwise + """ + worker = self.workers.get(worker_id) + + if not worker: + logger.warning(f"Worker {worker_id} not found") + return False + + logger.info(f"Removing worker {worker_id}") + worker.stop(timeout=timeout) + + del self.workers[worker_id] + + logger.info(f"Worker {worker_id} removed") + return True + + def get_worker_status(self, worker_id: str) -> Optional[dict]: + """ + Get status of a specific worker. + + Args: + worker_id: Worker ID + + Returns: + Worker status dict or None if not found + """ + worker = self.workers.get(worker_id) + if not worker: + return None + + return worker.get_status() + + def get_all_workers_status(self) -> List[dict]: + """ + Get status of all workers. + + Returns: + List of worker status dicts + """ + return [worker.get_status() for worker in self.workers.values()] + + def get_pool_stats(self) -> dict: + """ + Get overall pool statistics. + + Returns: + Dictionary with pool statistics + """ + total_workers = len(self.workers) + cpu_workers = sum(1 for w in self.workers.values() if w.worker_type == WorkerType.CPU) + gpu_workers = sum(1 for w in self.workers.values() if w.worker_type == WorkerType.GPU) + + # Count workers by status + idle_workers = 0 + busy_workers = 0 + stopped_workers = 0 + error_workers = 0 + + for worker in self.workers.values(): + status_dict = worker.get_status() + status = status_dict["status"] # This is a string like "idle", "busy", etc. + + if status == "idle": + idle_workers += 1 + elif status == "busy": + busy_workers += 1 + elif status == "stopped": + stopped_workers += 1 + elif status == "error": + error_workers += 1 + + # Get total jobs processed + total_completed = sum(w.jobs_completed.value for w in self.workers.values()) + total_failed = sum(w.jobs_failed.value for w in self.workers.values()) + + # Get queue stats + queue_stats = queue_manager.get_queue_stats() + + return { + "pool": { + "is_running": self.is_running, + "started_at": self.started_at.isoformat() if self.started_at else None, + "total_workers": total_workers, + "cpu_workers": cpu_workers, + "gpu_workers": gpu_workers, + "idle_workers": idle_workers, + "busy_workers": busy_workers, + "stopped_workers": stopped_workers, + "error_workers": error_workers, + }, + "jobs": { + "completed": total_completed, + "failed": total_failed, + "success_rate": ( + total_completed / (total_completed + total_failed) * 100 + if (total_completed + total_failed) > 0 + else 0 + ), + }, + "queue": queue_stats, + } + + def health_check(self) -> dict: + """ + Perform health check on all workers. + + Restarts dead workers automatically. + + Returns: + Health check results + """ + dead_workers = [] + restarted_workers = [] + + for worker_id, worker in list(self.workers.items()): + if not worker.is_alive(): + logger.warning(f"Worker {worker_id} is dead, restarting...") + dead_workers.append(worker_id) + + # Try to restart + try: + worker.start() + restarted_workers.append(worker_id) + logger.info(f"Worker {worker_id} restarted successfully") + except Exception as e: + logger.error(f"Failed to restart worker {worker_id}: {e}") + + return { + "timestamp": datetime.now(timezone.utc).isoformat(), + "total_workers": len(self.workers), + "dead_workers": dead_workers, + "restarted_workers": restarted_workers, + "healthy": len(dead_workers) == 0, + } + + def auto_scale(self, target_workers: int): + """ + Auto-scale workers based on queue size. + + This is a placeholder for future auto-scaling logic. + + Args: + target_workers: Target number of workers + """ + current_workers = len(self.workers) + + if current_workers < target_workers: + # Add workers + workers_to_add = target_workers - current_workers + logger.info(f"Auto-scaling: adding {workers_to_add} workers") + + for _ in range(workers_to_add): + # Default to CPU workers for auto-scaling + self.add_worker(WorkerType.CPU) + + elif current_workers > target_workers: + # Remove idle workers + workers_to_remove = current_workers - target_workers + logger.info(f"Auto-scaling: removing {workers_to_remove} workers") + + # Find idle workers to remove + idle_workers = [ + worker_id for worker_id, worker in self.workers.items() + if worker.get_status()["status"] == WorkerStatus.IDLE.value + ] + + for worker_id in idle_workers[:workers_to_remove]: + self.remove_worker(worker_id) + + def _generate_worker_id( + self, + worker_type: WorkerType, + device_id: Optional[int] = None + ) -> str: + """ + Generate unique worker ID. + + Args: + worker_type: CPU or GPU + device_id: GPU device ID + + Returns: + Worker ID string + """ + prefix = "cpu" if worker_type == WorkerType.CPU else f"gpu{device_id}" + + # Count existing workers of this type + existing_count = sum( + 1 for wid in self.workers.keys() + if wid.startswith(prefix) + ) + + return f"{prefix}-{existing_count + 1}" + + def _get_gpu_count(self) -> int: + """ + Get number of available GPUs. + + Returns: + Number of GPUs (defaults to 1 if detection fails) + """ + try: + import torch + if torch.cuda.is_available(): + return torch.cuda.device_count() + except ImportError: + pass + + return 1 # Default to 1 GPU + + +# Global worker pool instance +worker_pool = WorkerPool() \ No newline at end of file