feat(workers): add multiprocessing worker pool system
- Add Worker class with CPU/GPU support - Add WorkerPool for orchestrating multiple workers - Support dynamic add/remove workers at runtime - Add health monitoring with graceful shutdown
This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
"""Individual worker for processing transcription jobs."""
|
"""Individual worker for processing transcription jobs."""
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from enum import Enum
|
from enum import IntEnum, Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from backend.core.database import Database
|
from backend.core.database import Database
|
||||||
@@ -20,13 +21,23 @@ class WorkerType(str, Enum):
|
|||||||
GPU = "gpu"
|
GPU = "gpu"
|
||||||
|
|
||||||
|
|
||||||
class WorkerStatus(str, Enum):
|
class WorkerStatus(IntEnum):
|
||||||
"""Worker status states."""
|
"""Worker status states."""
|
||||||
IDLE = "idle"
|
IDLE = 0
|
||||||
BUSY = "busy"
|
BUSY = 1
|
||||||
STOPPING = "stopping"
|
STOPPING = 2
|
||||||
STOPPED = "stopped"
|
STOPPED = 3
|
||||||
ERROR = "error"
|
ERROR = 4
|
||||||
|
|
||||||
|
def to_string(self) -> str:
|
||||||
|
"""Convert to string representation."""
|
||||||
|
return {
|
||||||
|
0: "idle",
|
||||||
|
1: "busy",
|
||||||
|
2: "stopping",
|
||||||
|
3: "stopped",
|
||||||
|
4: "error"
|
||||||
|
}.get(self.value, "unknown")
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
@@ -79,13 +90,13 @@ class Worker:
|
|||||||
daemon=True
|
daemon=True
|
||||||
)
|
)
|
||||||
self.process.start()
|
self.process.start()
|
||||||
self.started_at = datetime.utcnow()
|
self.started_at = datetime.now(timezone.utc)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Worker {self.worker_id} started (PID: {self.process.pid}, "
|
f"Worker {self.worker_id} started (PID: {self.process.pid}, "
|
||||||
f"Type: {self.worker_type.value})"
|
f"Type: {self.worker_type.value})"
|
||||||
)
|
)
|
||||||
|
|
||||||
def stop(self, timeout: float = 30.0):
|
def stop(self, timeout: float = 5.0):
|
||||||
"""
|
"""
|
||||||
Stop the worker process gracefully.
|
Stop the worker process gracefully.
|
||||||
|
|
||||||
@@ -93,7 +104,7 @@ class Worker:
|
|||||||
timeout: Maximum time to wait for worker to stop
|
timeout: Maximum time to wait for worker to stop
|
||||||
"""
|
"""
|
||||||
if not self.process or not self.process.is_alive():
|
if not self.process or not self.process.is_alive():
|
||||||
logger.warning(f"Worker {self.worker_id} is not running")
|
logger.debug(f"Worker {self.worker_id} is not running")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Stopping worker {self.worker_id}...")
|
logger.info(f"Stopping worker {self.worker_id}...")
|
||||||
@@ -103,11 +114,12 @@ class Worker:
|
|||||||
if self.process.is_alive():
|
if self.process.is_alive():
|
||||||
logger.warning(f"Worker {self.worker_id} did not stop gracefully, terminating...")
|
logger.warning(f"Worker {self.worker_id} did not stop gracefully, terminating...")
|
||||||
self.process.terminate()
|
self.process.terminate()
|
||||||
self.process.join(timeout=5.0)
|
self.process.join(timeout=2.0)
|
||||||
|
|
||||||
if self.process.is_alive():
|
if self.process.is_alive():
|
||||||
logger.error(f"Worker {self.worker_id} did not terminate, killing...")
|
logger.error(f"Worker {self.worker_id} did not terminate, killing...")
|
||||||
self.process.kill()
|
self.process.kill()
|
||||||
|
self.process.join(timeout=1.0)
|
||||||
|
|
||||||
logger.info(f"Worker {self.worker_id} stopped")
|
logger.info(f"Worker {self.worker_id} stopped")
|
||||||
|
|
||||||
@@ -130,7 +142,7 @@ class Worker:
|
|||||||
"worker_id": self.worker_id,
|
"worker_id": self.worker_id,
|
||||||
"type": self.worker_type.value,
|
"type": self.worker_type.value,
|
||||||
"device_id": self.device_id,
|
"device_id": self.device_id,
|
||||||
"status": status_enum.value,
|
"status": status_enum.to_string(), # Convert to string
|
||||||
"current_job_id": current_job if current_job else None,
|
"current_job_id": current_job if current_job else None,
|
||||||
"jobs_completed": self.jobs_completed.value,
|
"jobs_completed": self.jobs_completed.value,
|
||||||
"jobs_failed": self.jobs_failed.value,
|
"jobs_failed": self.jobs_failed.value,
|
||||||
@@ -205,75 +217,244 @@ class Worker:
|
|||||||
|
|
||||||
def _process_job(self, job: Job, queue_mgr: QueueManager):
|
def _process_job(self, job: Job, queue_mgr: QueueManager):
|
||||||
"""
|
"""
|
||||||
Process a single transcription job.
|
Process a job (transcription or language detection).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
job: Job to process
|
job: Job to process
|
||||||
queue_mgr: Queue manager for updating progress
|
queue_mgr: Queue manager for updating progress
|
||||||
"""
|
"""
|
||||||
# TODO: This will be implemented when we add the transcriber module
|
from backend.core.models import JobType
|
||||||
# For now, simulate work
|
|
||||||
|
|
||||||
# Stage 1: Detect language
|
# Route to appropriate handler based on job type
|
||||||
queue_mgr.update_job_progress(
|
if job.job_type == JobType.LANGUAGE_DETECTION:
|
||||||
job.id,
|
self._process_language_detection(job, queue_mgr)
|
||||||
progress=10.0,
|
else:
|
||||||
stage=JobStage.DETECTING_LANGUAGE,
|
self._process_transcription(job, queue_mgr)
|
||||||
eta_seconds=60
|
|
||||||
)
|
|
||||||
time.sleep(2) # Simulate work
|
|
||||||
|
|
||||||
# Stage 2: Extract audio
|
def _process_language_detection(self, job: Job, queue_mgr: QueueManager):
|
||||||
queue_mgr.update_job_progress(
|
"""
|
||||||
job.id,
|
Process a language detection job using fast Whisper model.
|
||||||
progress=20.0,
|
|
||||||
stage=JobStage.EXTRACTING_AUDIO,
|
|
||||||
eta_seconds=50
|
|
||||||
)
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
# Stage 3: Transcribe
|
Args:
|
||||||
queue_mgr.update_job_progress(
|
job: Language detection job
|
||||||
job.id,
|
queue_mgr: Queue manager for updating progress
|
||||||
progress=30.0,
|
"""
|
||||||
stage=JobStage.TRANSCRIBING,
|
start_time = time.time()
|
||||||
eta_seconds=40
|
|
||||||
)
|
|
||||||
|
|
||||||
# Simulate progressive transcription
|
try:
|
||||||
for i in range(30, 90, 10):
|
logger.info(f"Worker {self.worker_id} processing LANGUAGE DETECTION job {job.id}: {job.file_name}")
|
||||||
time.sleep(1)
|
|
||||||
|
# Stage 1: Detecting language (20% progress)
|
||||||
queue_mgr.update_job_progress(
|
queue_mgr.update_job_progress(
|
||||||
job.id,
|
job.id, progress=20.0, stage=JobStage.DETECTING_LANGUAGE, eta_seconds=10
|
||||||
progress=float(i),
|
|
||||||
stage=JobStage.TRANSCRIBING,
|
|
||||||
eta_seconds=int((100 - i) / 2)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stage 4: Finalize
|
# Use language detector with tiny model
|
||||||
queue_mgr.update_job_progress(
|
from backend.scanning.language_detector import LanguageDetector
|
||||||
job.id,
|
|
||||||
progress=95.0,
|
|
||||||
stage=JobStage.FINALIZING,
|
|
||||||
eta_seconds=5
|
|
||||||
)
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# Mark as completed
|
language, confidence = LanguageDetector.detect_language(
|
||||||
output_path = job.file_path.replace('.mkv', '.srt')
|
file_path=job.file_path,
|
||||||
queue_mgr.mark_job_completed(
|
sample_duration=30
|
||||||
job.id,
|
)
|
||||||
output_path=output_path,
|
|
||||||
segments_count=100, # Simulated
|
# Stage 2: Finalizing (80% progress)
|
||||||
srt_content="Simulated SRT content"
|
queue_mgr.update_job_progress(
|
||||||
)
|
job.id, progress=80.0, stage=JobStage.FINALIZING, eta_seconds=2
|
||||||
|
)
|
||||||
|
|
||||||
|
if language:
|
||||||
|
# Calculate processing time
|
||||||
|
processing_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Use ISO 639-1 format (ja, en, es) throughout the system
|
||||||
|
lang_code = language.value[0] if language else "unknown"
|
||||||
|
|
||||||
|
result_text = f"Language detected: {lang_code} ({language.name.title() if language else 'Unknown'})\nConfidence: {confidence}%"
|
||||||
|
|
||||||
|
# Store in ISO 639-1 format (ja, en, es) for consistency
|
||||||
|
queue_mgr.mark_job_completed(
|
||||||
|
job.id,
|
||||||
|
output_path=None,
|
||||||
|
segments_count=0,
|
||||||
|
srt_content=result_text,
|
||||||
|
detected_language=lang_code # Use ISO 639-1 (ja, en, es)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Worker {self.worker_id} completed detection job {job.id}: "
|
||||||
|
f"{lang_code} (confidence: {confidence}%) in {processing_time:.1f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if file matches any scan rules and queue transcription job
|
||||||
|
self._check_and_queue_transcription(job, lang_code)
|
||||||
|
else:
|
||||||
|
# Detection failed
|
||||||
|
queue_mgr.mark_job_failed(job.id, "Language detection failed - could not detect language")
|
||||||
|
logger.error(f"Worker {self.worker_id} failed detection job {job.id}: No language detected")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Worker {self.worker_id} failed detection job {job.id}: {e}", exc_info=True)
|
||||||
|
queue_mgr.mark_job_failed(job.id, str(e))
|
||||||
|
|
||||||
|
def _process_transcription(self, job: Job, queue_mgr: QueueManager):
|
||||||
|
"""
|
||||||
|
Process a transcription/translation job using Whisper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job: Transcription job
|
||||||
|
queue_mgr: Queue manager for updating progress
|
||||||
|
"""
|
||||||
|
from backend.transcription import WhisperTranscriber
|
||||||
|
from backend.transcription.audio_utils import handle_multiple_audio_tracks
|
||||||
|
from backend.core.language_code import LanguageCode
|
||||||
|
|
||||||
|
transcriber = None
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Worker {self.worker_id} processing TRANSCRIPTION job {job.id}: {job.file_name}")
|
||||||
|
|
||||||
|
# Stage 1: Loading model
|
||||||
|
queue_mgr.update_job_progress(
|
||||||
|
job.id, progress=5.0, stage=JobStage.LOADING_MODEL, eta_seconds=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine device for transcriber
|
||||||
|
if self.worker_type == WorkerType.GPU:
|
||||||
|
device = f"cuda:{self.device_id}" if self.device_id is not None else "cuda"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
transcriber = WhisperTranscriber(device=device)
|
||||||
|
transcriber.load_model()
|
||||||
|
|
||||||
|
# Stage 2: Preparing audio
|
||||||
|
queue_mgr.update_job_progress(
|
||||||
|
job.id, progress=10.0, stage=JobStage.EXTRACTING_AUDIO, eta_seconds=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle multiple audio tracks if needed
|
||||||
|
source_lang = (
|
||||||
|
LanguageCode.from_string(job.source_lang) if job.source_lang else None
|
||||||
|
)
|
||||||
|
audio_data = handle_multiple_audio_tracks(job.file_path, source_lang)
|
||||||
|
|
||||||
|
# Stage 3: Transcribing
|
||||||
|
queue_mgr.update_job_progress(
|
||||||
|
job.id, progress=15.0, stage=JobStage.TRANSCRIBING, eta_seconds=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Progress callback for real-time updates
|
||||||
|
def progress_callback(seek, total):
|
||||||
|
# Reserve 15%-75% for Whisper (60% range)
|
||||||
|
# If translate mode, reserve 75%-90% for translation (15% range)
|
||||||
|
whisper_progress = 15.0 + (seek / total) * 60.0
|
||||||
|
queue_mgr.update_job_progress(job.id, progress=whisper_progress, stage=JobStage.TRANSCRIBING)
|
||||||
|
|
||||||
|
# Stage 3A: Whisper transcription to English
|
||||||
|
# IMPORTANT: Both 'transcribe' and 'translate' modes use task='translate' here
|
||||||
|
# to convert audio to English subtitles
|
||||||
|
logger.info(f"Running Whisper with task='translate' to convert audio to English")
|
||||||
|
|
||||||
|
# job.source_lang is already in ISO 639-1 format (ja, en, es)
|
||||||
|
# Whisper accepts ISO 639-1, so we can use it directly
|
||||||
|
if audio_data:
|
||||||
|
result = transcriber.transcribe_audio_data(
|
||||||
|
audio_data=audio_data.read(),
|
||||||
|
language=job.source_lang, # Already ISO 639-1 (ja, en, es)
|
||||||
|
task="translate", # ALWAYS translate to English first
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = transcriber.transcribe_file(
|
||||||
|
file_path=job.file_path,
|
||||||
|
language=job.source_lang, # Already ISO 639-1 (ja, en, es)
|
||||||
|
task="translate", # ALWAYS translate to English first
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate English SRT filename
|
||||||
|
file_base = os.path.splitext(job.file_path)[0]
|
||||||
|
english_srt_path = f"{file_base}.eng.srt"
|
||||||
|
|
||||||
|
# Save English SRT
|
||||||
|
result.to_srt(english_srt_path, word_level=False)
|
||||||
|
logger.info(f"English subtitles saved to {english_srt_path}")
|
||||||
|
|
||||||
|
# Stage 3B: Optional translation to target language
|
||||||
|
if job.transcribe_or_translate == "translate" and job.target_lang and job.target_lang.lower() != "eng":
|
||||||
|
queue_mgr.update_job_progress(
|
||||||
|
job.id, progress=75.0, stage=JobStage.FINALIZING, eta_seconds=10
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Translating English subtitles to {job.target_lang}")
|
||||||
|
|
||||||
|
from backend.transcription import translate_srt_file
|
||||||
|
|
||||||
|
# Generate target language SRT filename
|
||||||
|
target_srt_path = f"{file_base}.{job.target_lang}.srt"
|
||||||
|
|
||||||
|
# Translate English SRT to target language
|
||||||
|
success = translate_srt_file(
|
||||||
|
input_path=english_srt_path,
|
||||||
|
output_path=target_srt_path,
|
||||||
|
target_language=job.target_lang
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info(f"Translated subtitles saved to {target_srt_path}")
|
||||||
|
output_path = target_srt_path
|
||||||
|
else:
|
||||||
|
logger.warning(f"Translation failed, keeping English subtitles only")
|
||||||
|
output_path = english_srt_path
|
||||||
|
else:
|
||||||
|
# For 'transcribe' mode or if target is English, use English SRT
|
||||||
|
output_path = english_srt_path
|
||||||
|
|
||||||
|
# Stage 4: Finalize
|
||||||
|
queue_mgr.update_job_progress(
|
||||||
|
job.id, progress=90.0, stage=JobStage.FINALIZING, eta_seconds=5
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate processing time
|
||||||
|
processing_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Get SRT content for storage
|
||||||
|
srt_content = result.get_srt_content()
|
||||||
|
|
||||||
|
# Mark job as completed
|
||||||
|
queue_mgr.mark_job_completed(
|
||||||
|
job.id,
|
||||||
|
output_path=output_path,
|
||||||
|
segments_count=result.segments_count,
|
||||||
|
srt_content=srt_content,
|
||||||
|
model_used=transcriber.model_name,
|
||||||
|
device_used=transcriber.device,
|
||||||
|
processing_time_seconds=processing_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Worker {self.worker_id} completed job {job.id}: "
|
||||||
|
f"{result.segments_count} segments in {processing_time:.1f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Worker {self.worker_id} failed job {job.id}: {e}", exc_info=True)
|
||||||
|
queue_mgr.mark_job_failed(job.id, str(e))
|
||||||
|
finally:
|
||||||
|
# Always unload model after job
|
||||||
|
if transcriber:
|
||||||
|
try:
|
||||||
|
transcriber.unload_model()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error unloading model: {e}")
|
||||||
|
|
||||||
def _set_status(self, status: WorkerStatus):
|
def _set_status(self, status: WorkerStatus):
|
||||||
"""Set worker status (thread-safe)."""
|
"""Set worker status (thread-safe)."""
|
||||||
self.status.value = status.value
|
self.status.value = status.value
|
||||||
|
|
||||||
def _set_current_job(self, job_id: str):
|
def _set_current_job(self, job_id: str):
|
||||||
"""Set current job ID (thread-safe)."""
|
"""Set the current job ID (thread-safe)."""
|
||||||
job_id_bytes = job_id.encode('utf-8')
|
job_id_bytes = job_id.encode('utf-8')
|
||||||
for i, byte in enumerate(job_id_bytes):
|
for i, byte in enumerate(job_id_bytes):
|
||||||
if i < len(self.current_job_id):
|
if i < len(self.current_job_id):
|
||||||
@@ -282,4 +463,31 @@ class Worker:
|
|||||||
def _clear_current_job(self):
|
def _clear_current_job(self):
|
||||||
"""Clear current job ID (thread-safe)."""
|
"""Clear current job ID (thread-safe)."""
|
||||||
for i in range(len(self.current_job_id)):
|
for i in range(len(self.current_job_id)):
|
||||||
self.current_job_id[i] = b'\x00'
|
self.current_job_id[i] = b'\x00'
|
||||||
|
|
||||||
|
def _check_and_queue_transcription(self, job: Job, detected_lang_code: str):
|
||||||
|
"""
|
||||||
|
Check if detected language matches any scan rules and queue transcription job.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job: Completed language detection job
|
||||||
|
detected_lang_code: Detected language code (ISO 639-1, e.g., 'ja', 'en')
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from backend.scanning.library_scanner import library_scanner
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Language detection completed for {job.file_path}: {detected_lang_code}. "
|
||||||
|
f"Checking scan rules..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the scanner's method to check rules and queue transcription
|
||||||
|
library_scanner._check_and_queue_transcription_for_file(
|
||||||
|
job.file_path, detected_lang_code
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error checking scan rules for {job.file_path}: {e}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
|||||||
339
backend/core/worker_pool.py
Normal file
339
backend/core/worker_pool.py
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
"""Worker pool orchestrator for managing transcription workers."""
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from backend.core.worker import Worker, WorkerType, WorkerStatus
|
||||||
|
from backend.core.queue_manager import queue_manager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerPool:
|
||||||
|
"""
|
||||||
|
Orchestrator for managing a pool of transcription workers.
|
||||||
|
|
||||||
|
Similar to Tdarr's worker management system, this class handles:
|
||||||
|
- Dynamic worker creation/removal (CPU and GPU)
|
||||||
|
- Worker health monitoring
|
||||||
|
- Load balancing via the queue
|
||||||
|
- Worker statistics and reporting
|
||||||
|
- Graceful shutdown
|
||||||
|
|
||||||
|
Workers are managed as separate processes that pull jobs from the
|
||||||
|
persistent queue. The pool can be controlled via WebUI to add/remove
|
||||||
|
workers on-demand.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize worker pool."""
|
||||||
|
self.workers: Dict[str, Worker] = {}
|
||||||
|
self.is_running = False
|
||||||
|
self.started_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
logger.info("WorkerPool initialized")
|
||||||
|
|
||||||
|
def start(self, cpu_workers: int = 0, gpu_workers: int = 0):
|
||||||
|
"""
|
||||||
|
Start the worker pool with specified number of workers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cpu_workers: Number of CPU workers to start
|
||||||
|
gpu_workers: Number of GPU workers to start
|
||||||
|
"""
|
||||||
|
if self.is_running:
|
||||||
|
logger.warning("WorkerPool is already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_running = True
|
||||||
|
self.started_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Start CPU workers
|
||||||
|
for i in range(cpu_workers):
|
||||||
|
self.add_worker(WorkerType.CPU)
|
||||||
|
|
||||||
|
# Start GPU workers
|
||||||
|
for i in range(gpu_workers):
|
||||||
|
self.add_worker(WorkerType.GPU, device_id=i % self._get_gpu_count())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"WorkerPool started: {cpu_workers} CPU workers, {gpu_workers} GPU workers"
|
||||||
|
)
|
||||||
|
|
||||||
|
def stop(self, timeout: float = 30.0):
|
||||||
|
"""
|
||||||
|
Stop all workers gracefully.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum time to wait for each worker to stop
|
||||||
|
"""
|
||||||
|
if not self.is_running:
|
||||||
|
logger.warning("WorkerPool is not running")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Stopping WorkerPool with {len(self.workers)} workers...")
|
||||||
|
|
||||||
|
# Stop all workers
|
||||||
|
for worker_id, worker in list(self.workers.items()):
|
||||||
|
logger.info(f"Stopping worker {worker_id}")
|
||||||
|
worker.stop(timeout=timeout)
|
||||||
|
|
||||||
|
self.workers.clear()
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
logger.info("WorkerPool stopped")
|
||||||
|
|
||||||
|
def add_worker(
|
||||||
|
self,
|
||||||
|
worker_type: WorkerType,
|
||||||
|
device_id: Optional[int] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Add a new worker to the pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
worker_type: CPU or GPU
|
||||||
|
device_id: GPU device ID (only for GPU workers)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Worker ID
|
||||||
|
"""
|
||||||
|
# Generate unique worker ID
|
||||||
|
worker_id = self._generate_worker_id(worker_type, device_id)
|
||||||
|
|
||||||
|
if worker_id in self.workers:
|
||||||
|
logger.warning(f"Worker {worker_id} already exists")
|
||||||
|
return worker_id
|
||||||
|
|
||||||
|
# Create and start worker
|
||||||
|
worker = Worker(worker_id, worker_type, device_id)
|
||||||
|
worker.start()
|
||||||
|
|
||||||
|
self.workers[worker_id] = worker
|
||||||
|
|
||||||
|
logger.info(f"Added worker {worker_id} ({worker_type.value})")
|
||||||
|
return worker_id
|
||||||
|
|
||||||
|
def remove_worker(self, worker_id: str, timeout: float = 30.0) -> bool:
|
||||||
|
"""
|
||||||
|
Remove a worker from the pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
worker_id: Worker ID to remove
|
||||||
|
timeout: Maximum time to wait for worker to stop
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if worker was removed, False otherwise
|
||||||
|
"""
|
||||||
|
worker = self.workers.get(worker_id)
|
||||||
|
|
||||||
|
if not worker:
|
||||||
|
logger.warning(f"Worker {worker_id} not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"Removing worker {worker_id}")
|
||||||
|
worker.stop(timeout=timeout)
|
||||||
|
|
||||||
|
del self.workers[worker_id]
|
||||||
|
|
||||||
|
logger.info(f"Worker {worker_id} removed")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_worker_status(self, worker_id: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Get status of a specific worker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
worker_id: Worker ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Worker status dict or None if not found
|
||||||
|
"""
|
||||||
|
worker = self.workers.get(worker_id)
|
||||||
|
if not worker:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return worker.get_status()
|
||||||
|
|
||||||
|
def get_all_workers_status(self) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Get status of all workers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of worker status dicts
|
||||||
|
"""
|
||||||
|
return [worker.get_status() for worker in self.workers.values()]
|
||||||
|
|
||||||
|
def get_pool_stats(self) -> dict:
|
||||||
|
"""
|
||||||
|
Get overall pool statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with pool statistics
|
||||||
|
"""
|
||||||
|
total_workers = len(self.workers)
|
||||||
|
cpu_workers = sum(1 for w in self.workers.values() if w.worker_type == WorkerType.CPU)
|
||||||
|
gpu_workers = sum(1 for w in self.workers.values() if w.worker_type == WorkerType.GPU)
|
||||||
|
|
||||||
|
# Count workers by status
|
||||||
|
idle_workers = 0
|
||||||
|
busy_workers = 0
|
||||||
|
stopped_workers = 0
|
||||||
|
error_workers = 0
|
||||||
|
|
||||||
|
for worker in self.workers.values():
|
||||||
|
status_dict = worker.get_status()
|
||||||
|
status = status_dict["status"] # This is a string like "idle", "busy", etc.
|
||||||
|
|
||||||
|
if status == "idle":
|
||||||
|
idle_workers += 1
|
||||||
|
elif status == "busy":
|
||||||
|
busy_workers += 1
|
||||||
|
elif status == "stopped":
|
||||||
|
stopped_workers += 1
|
||||||
|
elif status == "error":
|
||||||
|
error_workers += 1
|
||||||
|
|
||||||
|
# Get total jobs processed
|
||||||
|
total_completed = sum(w.jobs_completed.value for w in self.workers.values())
|
||||||
|
total_failed = sum(w.jobs_failed.value for w in self.workers.values())
|
||||||
|
|
||||||
|
# Get queue stats
|
||||||
|
queue_stats = queue_manager.get_queue_stats()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"pool": {
|
||||||
|
"is_running": self.is_running,
|
||||||
|
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||||
|
"total_workers": total_workers,
|
||||||
|
"cpu_workers": cpu_workers,
|
||||||
|
"gpu_workers": gpu_workers,
|
||||||
|
"idle_workers": idle_workers,
|
||||||
|
"busy_workers": busy_workers,
|
||||||
|
"stopped_workers": stopped_workers,
|
||||||
|
"error_workers": error_workers,
|
||||||
|
},
|
||||||
|
"jobs": {
|
||||||
|
"completed": total_completed,
|
||||||
|
"failed": total_failed,
|
||||||
|
"success_rate": (
|
||||||
|
total_completed / (total_completed + total_failed) * 100
|
||||||
|
if (total_completed + total_failed) > 0
|
||||||
|
else 0
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"queue": queue_stats,
|
||||||
|
}
|
||||||
|
|
||||||
|
def health_check(self) -> dict:
|
||||||
|
"""
|
||||||
|
Perform health check on all workers.
|
||||||
|
|
||||||
|
Restarts dead workers automatically.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Health check results
|
||||||
|
"""
|
||||||
|
dead_workers = []
|
||||||
|
restarted_workers = []
|
||||||
|
|
||||||
|
for worker_id, worker in list(self.workers.items()):
|
||||||
|
if not worker.is_alive():
|
||||||
|
logger.warning(f"Worker {worker_id} is dead, restarting...")
|
||||||
|
dead_workers.append(worker_id)
|
||||||
|
|
||||||
|
# Try to restart
|
||||||
|
try:
|
||||||
|
worker.start()
|
||||||
|
restarted_workers.append(worker_id)
|
||||||
|
logger.info(f"Worker {worker_id} restarted successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to restart worker {worker_id}: {e}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"total_workers": len(self.workers),
|
||||||
|
"dead_workers": dead_workers,
|
||||||
|
"restarted_workers": restarted_workers,
|
||||||
|
"healthy": len(dead_workers) == 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def auto_scale(self, target_workers: int):
|
||||||
|
"""
|
||||||
|
Auto-scale workers based on queue size.
|
||||||
|
|
||||||
|
This is a placeholder for future auto-scaling logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_workers: Target number of workers
|
||||||
|
"""
|
||||||
|
current_workers = len(self.workers)
|
||||||
|
|
||||||
|
if current_workers < target_workers:
|
||||||
|
# Add workers
|
||||||
|
workers_to_add = target_workers - current_workers
|
||||||
|
logger.info(f"Auto-scaling: adding {workers_to_add} workers")
|
||||||
|
|
||||||
|
for _ in range(workers_to_add):
|
||||||
|
# Default to CPU workers for auto-scaling
|
||||||
|
self.add_worker(WorkerType.CPU)
|
||||||
|
|
||||||
|
elif current_workers > target_workers:
|
||||||
|
# Remove idle workers
|
||||||
|
workers_to_remove = current_workers - target_workers
|
||||||
|
logger.info(f"Auto-scaling: removing {workers_to_remove} workers")
|
||||||
|
|
||||||
|
# Find idle workers to remove
|
||||||
|
idle_workers = [
|
||||||
|
worker_id for worker_id, worker in self.workers.items()
|
||||||
|
if worker.get_status()["status"] == WorkerStatus.IDLE.value
|
||||||
|
]
|
||||||
|
|
||||||
|
for worker_id in idle_workers[:workers_to_remove]:
|
||||||
|
self.remove_worker(worker_id)
|
||||||
|
|
||||||
|
def _generate_worker_id(
|
||||||
|
self,
|
||||||
|
worker_type: WorkerType,
|
||||||
|
device_id: Optional[int] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate unique worker ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
worker_type: CPU or GPU
|
||||||
|
device_id: GPU device ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Worker ID string
|
||||||
|
"""
|
||||||
|
prefix = "cpu" if worker_type == WorkerType.CPU else f"gpu{device_id}"
|
||||||
|
|
||||||
|
# Count existing workers of this type
|
||||||
|
existing_count = sum(
|
||||||
|
1 for wid in self.workers.keys()
|
||||||
|
if wid.startswith(prefix)
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"{prefix}-{existing_count + 1}"
|
||||||
|
|
||||||
|
def _get_gpu_count(self) -> int:
|
||||||
|
"""
|
||||||
|
Get number of available GPUs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of GPUs (defaults to 1 if detection fails)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return torch.cuda.device_count()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return 1 # Default to 1 GPU
|
||||||
|
|
||||||
|
|
||||||
|
# Global worker pool instance
|
||||||
|
worker_pool = WorkerPool()
|
||||||
Reference in New Issue
Block a user