feat(workers): add multiprocessing worker pool system

- Add Worker class with CPU/GPU support
- Add WorkerPool for orchestrating multiple workers
- Support adding and removing workers dynamically at runtime
- Add worker health monitoring and graceful shutdown
2026-01-16 16:56:42 +01:00
parent cbf5ef9623
commit c019e96cfa
2 changed files with 611 additions and 64 deletions
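
A rough usage sketch of the API this commit introduces, pieced together from the two files below; how and where the application actually wires this up is not part of the commit:

from backend.core.worker_pool import worker_pool
from backend.core.worker import WorkerType

# Start the pool with two CPU workers and one GPU worker.
worker_pool.start(cpu_workers=2, gpu_workers=1)

# Scale up at runtime: bind an extra GPU worker to device 0.
gpu_worker_id = worker_pool.add_worker(WorkerType.GPU, device_id=0)

# Aggregate statistics and a health report (dead workers get restarted).
stats = worker_pool.get_pool_stats()
print(stats["pool"]["busy_workers"], stats["jobs"]["success_rate"])
print(worker_pool.health_check()["healthy"])

# Scale down, then shut the whole pool down gracefully.
worker_pool.remove_worker(gpu_worker_id)
worker_pool.stop(timeout=30.0)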

backend/core/worker.py

@@ -1,10 +1,11 @@
"""Individual worker for processing transcription jobs."""
import logging
import multiprocessing as mp
import os
import time
import traceback
from datetime import datetime
from enum import Enum
from datetime import datetime, timezone
from enum import IntEnum, Enum
from typing import Optional
from backend.core.database import Database
@@ -20,13 +21,23 @@ class WorkerType(str, Enum):
GPU = "gpu"
class WorkerStatus(str, Enum):
class WorkerStatus(IntEnum):
"""Worker status states."""
IDLE = "idle"
BUSY = "busy"
STOPPING = "stopping"
STOPPED = "stopped"
ERROR = "error"
IDLE = 0
BUSY = 1
STOPPING = 2
STOPPED = 3
ERROR = 4
def to_string(self) -> str:
"""Convert to string representation."""
return {
0: "idle",
1: "busy",
2: "stopping",
3: "stopped",
4: "error"
}.get(self.value, "unknown")
class Worker:
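
The reason `WorkerStatus` moves from a `str` Enum to an `IntEnum`: the live status is shared with the parent process through a multiprocessing value (`_set_status` below assigns `status.value` to it), and a shared ctypes slot holds a number rather than a string; `to_string()` restores the readable form for the status API. A minimal sketch of that round trip, assuming the shared value is declared as `mp.Value('i', ...)` (the declaration itself is not in this hunk):

import multiprocessing as mp

# Shared integer slot; the IntEnum member is itself a valid int initializer.
status = mp.Value("i", WorkerStatus.IDLE)

# Worker process: update the shared status as work progresses.
status.value = WorkerStatus.BUSY

# Parent process: convert the raw int back into the API-facing string.
print(WorkerStatus(status.value).to_string())  # "busy"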
@@ -79,13 +90,13 @@ class Worker:
daemon=True
)
self.process.start()
self.started_at = datetime.utcnow()
self.started_at = datetime.now(timezone.utc)
logger.info(
f"Worker {self.worker_id} started (PID: {self.process.pid}, "
f"Type: {self.worker_type.value})"
)
def stop(self, timeout: float = 30.0):
def stop(self, timeout: float = 5.0):
"""
Stop the worker process gracefully.
@@ -93,7 +104,7 @@ class Worker:
timeout: Maximum time to wait for worker to stop
"""
if not self.process or not self.process.is_alive():
logger.warning(f"Worker {self.worker_id} is not running")
logger.debug(f"Worker {self.worker_id} is not running")
return
logger.info(f"Stopping worker {self.worker_id}...")
@@ -103,11 +114,12 @@ class Worker:
if self.process.is_alive():
logger.warning(f"Worker {self.worker_id} did not stop gracefully, terminating...")
self.process.terminate()
self.process.join(timeout=5.0)
self.process.join(timeout=2.0)
if self.process.is_alive():
logger.error(f"Worker {self.worker_id} did not terminate, killing...")
self.process.kill()
self.process.join(timeout=1.0)
logger.info(f"Worker {self.worker_id} stopped")
@@ -130,7 +142,7 @@ class Worker:
"worker_id": self.worker_id,
"type": self.worker_type.value,
"device_id": self.device_id,
"status": status_enum.value,
"status": status_enum.to_string(), # Convert to string
"current_job_id": current_job if current_job else None,
"jobs_completed": self.jobs_completed.value,
"jobs_failed": self.jobs_failed.value,
@@ -205,75 +217,244 @@ class Worker:
def _process_job(self, job: Job, queue_mgr: QueueManager):
"""
Process a single transcription job.
Process a job (transcription or language detection).
Args:
job: Job to process
queue_mgr: Queue manager for updating progress
"""
# TODO: This will be implemented when we add the transcriber module
# For now, simulate work
from backend.core.models import JobType
# Stage 1: Detect language
queue_mgr.update_job_progress(
job.id,
progress=10.0,
stage=JobStage.DETECTING_LANGUAGE,
eta_seconds=60
)
time.sleep(2) # Simulate work
# Route to appropriate handler based on job type
if job.job_type == JobType.LANGUAGE_DETECTION:
self._process_language_detection(job, queue_mgr)
else:
self._process_transcription(job, queue_mgr)
# Stage 2: Extract audio
queue_mgr.update_job_progress(
job.id,
progress=20.0,
stage=JobStage.EXTRACTING_AUDIO,
eta_seconds=50
)
time.sleep(2)
def _process_language_detection(self, job: Job, queue_mgr: QueueManager):
"""
Process a language detection job using fast Whisper model.
# Stage 3: Transcribe
Args:
job: Language detection job
queue_mgr: Queue manager for updating progress
"""
start_time = time.time()
try:
logger.info(f"Worker {self.worker_id} processing LANGUAGE DETECTION job {job.id}: {job.file_name}")
# Stage 1: Detecting language (20% progress)
queue_mgr.update_job_progress(
job.id,
progress=30.0,
stage=JobStage.TRANSCRIBING,
eta_seconds=40
job.id, progress=20.0, stage=JobStage.DETECTING_LANGUAGE, eta_seconds=10
)
# Simulate progressive transcription
for i in range(30, 90, 10):
time.sleep(1)
queue_mgr.update_job_progress(
job.id,
progress=float(i),
stage=JobStage.TRANSCRIBING,
eta_seconds=int((100 - i) / 2)
# Use language detector with tiny model
from backend.scanning.language_detector import LanguageDetector
language, confidence = LanguageDetector.detect_language(
file_path=job.file_path,
sample_duration=30
)
# Stage 2: Finalizing (80% progress)
queue_mgr.update_job_progress(
job.id, progress=80.0, stage=JobStage.FINALIZING, eta_seconds=2
)
if language:
# Calculate processing time
processing_time = time.time() - start_time
# Use ISO 639-1 format (ja, en, es) throughout the system
lang_code = language.value[0] if language else "unknown"
result_text = f"Language detected: {lang_code} ({language.name.title() if language else 'Unknown'})\nConfidence: {confidence}%"
# Store in ISO 639-1 format (ja, en, es) for consistency
queue_mgr.mark_job_completed(
job.id,
output_path=None,
segments_count=0,
srt_content=result_text,
detected_language=lang_code # Use ISO 639-1 (ja, en, es)
)
logger.info(
f"Worker {self.worker_id} completed detection job {job.id}: "
f"{lang_code} (confidence: {confidence}%) in {processing_time:.1f}s"
)
# Check if file matches any scan rules and queue transcription job
self._check_and_queue_transcription(job, lang_code)
else:
# Detection failed
queue_mgr.mark_job_failed(job.id, "Language detection failed - could not detect language")
logger.error(f"Worker {self.worker_id} failed detection job {job.id}: No language detected")
except Exception as e:
logger.error(f"Worker {self.worker_id} failed detection job {job.id}: {e}", exc_info=True)
queue_mgr.mark_job_failed(job.id, str(e))
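
For context on `language.value[0]` and `language.name` above: the detector is assumed to return an enum member whose value tuple starts with the ISO 639-1 code, roughly in the shape sketched here. The real definition lives elsewhere (e.g. `backend.core.language_code`) and is not shown in this commit, so this is illustrative only:

from enum import Enum

class LanguageCode(Enum):
    # (ISO 639-1, ISO 639-2) pairs; this shape is inferred from
    # language.value[0] and language.name.title() in the handler above.
    ENGLISH = ("en", "eng")
    JAPANESE = ("ja", "jpn")

language, confidence = LanguageCode.JAPANESE, 97.3  # what detect_language() might return
lang_code = language.value[0]                       # "ja", the ISO 639-1 code stored on the job
print(f"Language detected: {lang_code} ({language.name.title()})\nConfidence: {confidence}%")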
def _process_transcription(self, job: Job, queue_mgr: QueueManager):
"""
Process a transcription/translation job using Whisper.
Args:
job: Transcription job
queue_mgr: Queue manager for updating progress
"""
from backend.transcription import WhisperTranscriber
from backend.transcription.audio_utils import handle_multiple_audio_tracks
from backend.core.language_code import LanguageCode
transcriber = None
start_time = time.time()
try:
logger.info(f"Worker {self.worker_id} processing TRANSCRIPTION job {job.id}: {job.file_name}")
# Stage 1: Loading model
queue_mgr.update_job_progress(
job.id, progress=5.0, stage=JobStage.LOADING_MODEL, eta_seconds=None
)
# Determine device for transcriber
if self.worker_type == WorkerType.GPU:
device = f"cuda:{self.device_id}" if self.device_id is not None else "cuda"
else:
device = "cpu"
transcriber = WhisperTranscriber(device=device)
transcriber.load_model()
# Stage 2: Preparing audio
queue_mgr.update_job_progress(
job.id, progress=10.0, stage=JobStage.EXTRACTING_AUDIO, eta_seconds=None
)
# Handle multiple audio tracks if needed
source_lang = (
LanguageCode.from_string(job.source_lang) if job.source_lang else None
)
audio_data = handle_multiple_audio_tracks(job.file_path, source_lang)
# Stage 3: Transcribing
queue_mgr.update_job_progress(
job.id, progress=15.0, stage=JobStage.TRANSCRIBING, eta_seconds=None
)
# Progress callback for real-time updates
def progress_callback(seek, total):
# Reserve 15%-75% for Whisper (60% range)
# If translate mode, reserve 75%-90% for translation (15% range)
whisper_progress = 15.0 + (seek / total) * 60.0
queue_mgr.update_job_progress(job.id, progress=whisper_progress, stage=JobStage.TRANSCRIBING)
# Stage 3A: Whisper transcription to English
# IMPORTANT: Both 'transcribe' and 'translate' modes use task='translate' here
# to convert audio to English subtitles
logger.info(f"Running Whisper with task='translate' to convert audio to English")
# job.source_lang is already in ISO 639-1 format (ja, en, es)
# Whisper accepts ISO 639-1, so we can use it directly
if audio_data:
result = transcriber.transcribe_audio_data(
audio_data=audio_data.read(),
language=job.source_lang, # Already ISO 639-1 (ja, en, es)
task="translate", # ALWAYS translate to English first
progress_callback=progress_callback,
)
else:
result = transcriber.transcribe_file(
file_path=job.file_path,
language=job.source_lang, # Already ISO 639-1 (ja, en, es)
task="translate", # ALWAYS translate to English first
progress_callback=progress_callback,
)
# Generate English SRT filename
file_base = os.path.splitext(job.file_path)[0]
english_srt_path = f"{file_base}.eng.srt"
# Save English SRT
result.to_srt(english_srt_path, word_level=False)
logger.info(f"English subtitles saved to {english_srt_path}")
# Stage 3B: Optional translation to target language
if job.transcribe_or_translate == "translate" and job.target_lang and job.target_lang.lower() != "eng":
queue_mgr.update_job_progress(
job.id, progress=75.0, stage=JobStage.FINALIZING, eta_seconds=10
)
logger.info(f"Translating English subtitles to {job.target_lang}")
from backend.transcription import translate_srt_file
# Generate target language SRT filename
target_srt_path = f"{file_base}.{job.target_lang}.srt"
# Translate English SRT to target language
success = translate_srt_file(
input_path=english_srt_path,
output_path=target_srt_path,
target_language=job.target_lang
)
if success:
logger.info(f"Translated subtitles saved to {target_srt_path}")
output_path = target_srt_path
else:
logger.warning(f"Translation failed, keeping English subtitles only")
output_path = english_srt_path
else:
# For 'transcribe' mode or if target is English, use English SRT
output_path = english_srt_path
# Stage 4: Finalize
queue_mgr.update_job_progress(
job.id,
progress=95.0,
stage=JobStage.FINALIZING,
eta_seconds=5
job.id, progress=90.0, stage=JobStage.FINALIZING, eta_seconds=5
)
time.sleep(1)
# Mark as completed
output_path = job.file_path.replace('.mkv', '.srt')
# Calculate processing time
processing_time = time.time() - start_time
# Get SRT content for storage
srt_content = result.get_srt_content()
# Mark job as completed
queue_mgr.mark_job_completed(
job.id,
output_path=output_path,
segments_count=100, # Simulated
srt_content="Simulated SRT content"
segments_count=result.segments_count,
srt_content=srt_content,
model_used=transcriber.model_name,
device_used=transcriber.device,
processing_time_seconds=processing_time,
)
logger.info(
f"Worker {self.worker_id} completed job {job.id}: "
f"{result.segments_count} segments in {processing_time:.1f}s"
)
except Exception as e:
logger.error(f"Worker {self.worker_id} failed job {job.id}: {e}", exc_info=True)
queue_mgr.mark_job_failed(job.id, str(e))
finally:
# Always unload model after job
if transcriber:
try:
transcriber.unload_model()
except Exception as e:
logger.error(f"Error unloading model: {e}")
def _set_status(self, status: WorkerStatus):
"""Set worker status (thread-safe)."""
self.status.value = status.value
def _set_current_job(self, job_id: str):
"""Set current job ID (thread-safe)."""
"""Set the current job ID (thread-safe)."""
job_id_bytes = job_id.encode('utf-8')
for i, byte in enumerate(job_id_bytes):
if i < len(self.current_job_id):
@@ -283,3 +464,30 @@ class Worker:
"""Clear current job ID (thread-safe)."""
for i in range(len(self.current_job_id)):
self.current_job_id[i] = b'\x00'
def _check_and_queue_transcription(self, job: Job, detected_lang_code: str):
"""
Check if detected language matches any scan rules and queue transcription job.
Args:
job: Completed language detection job
detected_lang_code: Detected language code (ISO 639-1, e.g., 'ja', 'en')
"""
try:
from backend.scanning.library_scanner import library_scanner
logger.info(
f"Language detection completed for {job.file_path}: {detected_lang_code}. "
f"Checking scan rules..."
)
# Use the scanner's method to check rules and queue transcription
library_scanner._check_and_queue_transcription_for_file(
job.file_path, detected_lang_code
)
except Exception as e:
logger.error(
f"Error checking scan rules for {job.file_path}: {e}",
exc_info=True
)

backend/core/worker_pool.py (new file, 339 lines)

@@ -0,0 +1,339 @@
"""Worker pool orchestrator for managing transcription workers."""
import logging
import time
from typing import Dict, List, Optional
from datetime import datetime, timezone
from backend.core.worker import Worker, WorkerType, WorkerStatus
from backend.core.queue_manager import queue_manager
logger = logging.getLogger(__name__)
class WorkerPool:
"""
Orchestrator for managing a pool of transcription workers.
Similar to Tdarr's worker management system, this class handles:
- Dynamic worker creation/removal (CPU and GPU)
- Worker health monitoring
- Load balancing via the queue
- Worker statistics and reporting
- Graceful shutdown
Workers are managed as separate processes that pull jobs from the
persistent queue. The pool can be controlled via WebUI to add/remove
workers on-demand.
"""
def __init__(self):
"""Initialize worker pool."""
self.workers: Dict[str, Worker] = {}
self.is_running = False
self.started_at: Optional[datetime] = None
logger.info("WorkerPool initialized")
def start(self, cpu_workers: int = 0, gpu_workers: int = 0):
"""
Start the worker pool with specified number of workers.
Args:
cpu_workers: Number of CPU workers to start
gpu_workers: Number of GPU workers to start
"""
if self.is_running:
logger.warning("WorkerPool is already running")
return
self.is_running = True
self.started_at = datetime.now(timezone.utc)
# Start CPU workers
for i in range(cpu_workers):
self.add_worker(WorkerType.CPU)
# Start GPU workers
for i in range(gpu_workers):
self.add_worker(WorkerType.GPU, device_id=i % self._get_gpu_count())
logger.info(
f"WorkerPool started: {cpu_workers} CPU workers, {gpu_workers} GPU workers"
)
def stop(self, timeout: float = 30.0):
"""
Stop all workers gracefully.
Args:
timeout: Maximum time to wait for each worker to stop
"""
if not self.is_running:
logger.warning("WorkerPool is not running")
return
logger.info(f"Stopping WorkerPool with {len(self.workers)} workers...")
# Stop all workers
for worker_id, worker in list(self.workers.items()):
logger.info(f"Stopping worker {worker_id}")
worker.stop(timeout=timeout)
self.workers.clear()
self.is_running = False
logger.info("WorkerPool stopped")
def add_worker(
self,
worker_type: WorkerType,
device_id: Optional[int] = None
) -> str:
"""
Add a new worker to the pool.
Args:
worker_type: CPU or GPU
device_id: GPU device ID (only for GPU workers)
Returns:
Worker ID
"""
# Generate unique worker ID
worker_id = self._generate_worker_id(worker_type, device_id)
if worker_id in self.workers:
logger.warning(f"Worker {worker_id} already exists")
return worker_id
# Create and start worker
worker = Worker(worker_id, worker_type, device_id)
worker.start()
self.workers[worker_id] = worker
logger.info(f"Added worker {worker_id} ({worker_type.value})")
return worker_id
def remove_worker(self, worker_id: str, timeout: float = 30.0) -> bool:
"""
Remove a worker from the pool.
Args:
worker_id: Worker ID to remove
timeout: Maximum time to wait for worker to stop
Returns:
True if worker was removed, False otherwise
"""
worker = self.workers.get(worker_id)
if not worker:
logger.warning(f"Worker {worker_id} not found")
return False
logger.info(f"Removing worker {worker_id}")
worker.stop(timeout=timeout)
del self.workers[worker_id]
logger.info(f"Worker {worker_id} removed")
return True
def get_worker_status(self, worker_id: str) -> Optional[dict]:
"""
Get status of a specific worker.
Args:
worker_id: Worker ID
Returns:
Worker status dict or None if not found
"""
worker = self.workers.get(worker_id)
if not worker:
return None
return worker.get_status()
def get_all_workers_status(self) -> List[dict]:
"""
Get status of all workers.
Returns:
List of worker status dicts
"""
return [worker.get_status() for worker in self.workers.values()]
def get_pool_stats(self) -> dict:
"""
Get overall pool statistics.
Returns:
Dictionary with pool statistics
"""
total_workers = len(self.workers)
cpu_workers = sum(1 for w in self.workers.values() if w.worker_type == WorkerType.CPU)
gpu_workers = sum(1 for w in self.workers.values() if w.worker_type == WorkerType.GPU)
# Count workers by status
idle_workers = 0
busy_workers = 0
stopped_workers = 0
error_workers = 0
for worker in self.workers.values():
status_dict = worker.get_status()
status = status_dict["status"] # This is a string like "idle", "busy", etc.
if status == "idle":
idle_workers += 1
elif status == "busy":
busy_workers += 1
elif status == "stopped":
stopped_workers += 1
elif status == "error":
error_workers += 1
# Get total jobs processed
total_completed = sum(w.jobs_completed.value for w in self.workers.values())
total_failed = sum(w.jobs_failed.value for w in self.workers.values())
# Get queue stats
queue_stats = queue_manager.get_queue_stats()
return {
"pool": {
"is_running": self.is_running,
"started_at": self.started_at.isoformat() if self.started_at else None,
"total_workers": total_workers,
"cpu_workers": cpu_workers,
"gpu_workers": gpu_workers,
"idle_workers": idle_workers,
"busy_workers": busy_workers,
"stopped_workers": stopped_workers,
"error_workers": error_workers,
},
"jobs": {
"completed": total_completed,
"failed": total_failed,
"success_rate": (
total_completed / (total_completed + total_failed) * 100
if (total_completed + total_failed) > 0
else 0
),
},
"queue": queue_stats,
}
def health_check(self) -> dict:
"""
Perform health check on all workers.
Restarts dead workers automatically.
Returns:
Health check results
"""
dead_workers = []
restarted_workers = []
for worker_id, worker in list(self.workers.items()):
if not worker.is_alive():
logger.warning(f"Worker {worker_id} is dead, restarting...")
dead_workers.append(worker_id)
# Try to restart
try:
worker.start()
restarted_workers.append(worker_id)
logger.info(f"Worker {worker_id} restarted successfully")
except Exception as e:
logger.error(f"Failed to restart worker {worker_id}: {e}")
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"total_workers": len(self.workers),
"dead_workers": dead_workers,
"restarted_workers": restarted_workers,
"healthy": len(dead_workers) == 0,
}
def auto_scale(self, target_workers: int):
"""
Auto-scale workers based on queue size.
This is a placeholder for future auto-scaling logic.
Args:
target_workers: Target number of workers
"""
current_workers = len(self.workers)
if current_workers < target_workers:
# Add workers
workers_to_add = target_workers - current_workers
logger.info(f"Auto-scaling: adding {workers_to_add} workers")
for _ in range(workers_to_add):
# Default to CPU workers for auto-scaling
self.add_worker(WorkerType.CPU)
elif current_workers > target_workers:
# Remove idle workers
workers_to_remove = current_workers - target_workers
logger.info(f"Auto-scaling: removing {workers_to_remove} workers")
# Find idle workers to remove
idle_workers = [
worker_id for worker_id, worker in self.workers.items()
if worker.get_status()["status"] == WorkerStatus.IDLE.to_string()
]
for worker_id in idle_workers[:workers_to_remove]:
self.remove_worker(worker_id)
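
`auto_scale` is a placeholder, so nothing here decides the target count. One plausible caller, sketched under the assumption that `get_queue_stats()` exposes a pending-job count under a `pending` key (that method is not part of this commit):

def rescale_from_queue(pool: "WorkerPool", max_workers: int = 4) -> None:
    """Derive a worker target from the queue backlog and let the pool converge on it."""
    stats = queue_manager.get_queue_stats()
    pending = stats.get("pending", 0)            # assumed key; not defined in this commit
    target = min(max_workers, max(1, pending))   # one worker per pending job, capped
    pool.auto_scale(target)

rescale_from_queue(worker_pool)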
def _generate_worker_id(
self,
worker_type: WorkerType,
device_id: Optional[int] = None
) -> str:
"""
Generate unique worker ID.
Args:
worker_type: CPU or GPU
device_id: GPU device ID
Returns:
Worker ID string
"""
prefix = "cpu" if worker_type == WorkerType.CPU else f"gpu{device_id}"
# Count existing workers of this type
existing_count = sum(
1 for wid in self.workers.keys()
if wid.startswith(prefix)
)
return f"{prefix}-{existing_count + 1}"
def _get_gpu_count(self) -> int:
"""
Get number of available GPUs.
Returns:
Number of GPUs (defaults to 1 if detection fails)
"""
try:
import torch
if torch.cuda.is_available():
return torch.cuda.device_count()
except ImportError:
pass
return 1 # Default to 1 GPU
# Global worker pool instance
worker_pool = WorkerPool()
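
Nothing in this commit actually schedules `health_check`; a small background loop along these lines would cover the health-monitoring bullet in the commit message (the interval and the threading choice are assumptions):

import logging
import threading
import time

from backend.core.worker_pool import worker_pool

logger = logging.getLogger(__name__)

def monitor_pool(interval: float = 30.0) -> None:
    """Periodically restart dead workers while the pool is running."""
    while worker_pool.is_running:
        report = worker_pool.health_check()
        if not report["healthy"]:
            logger.warning("Restarted dead workers: %s", report["restarted_workers"])
        time.sleep(interval)

threading.Thread(target=monitor_pool, daemon=True).start()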