feat: add centralized configuration system with Pydantic

- Add backend/config.py with Pydantic settings validation
- Support for standalone, provider, and hybrid operation modes
- Multi-database backend configuration (SQLite/PostgreSQL/MariaDB)
- Environment variable validation with helpful error messages
- Worker and Whisper model configuration
2026-01-11 21:23:45 +01:00
parent ad0bdba03d
commit 7959210724
10 changed files with 1318 additions and 0 deletions
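A minimal usage sketch of the settings added in this commit (class and field names come from backend/config.py below; the values and the `_env_file=None` override are illustrative):

# Hedged sketch: exercising the three operation modes
from backend.config import Settings

for mode in ("standalone", "provider", "standalone,provider"):
    s = Settings(transcriptarr_mode=mode, _env_file=None)  # skip .env for the demo
    print(mode, s.is_standalone_mode, s.is_provider_mode)
# standalone          -> True  False
# provider            -> False True
# standalone,provider -> True  True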

backend/__init__.py Normal file

@@ -0,0 +1 @@
"""TranscriptorIO Backend Package."""

backend/api/__init__.py Normal file

@@ -0,0 +1 @@
"""TranscriptorIO API Module."""

backend/config.py Normal file

@@ -0,0 +1,214 @@
"""Configuration management for TranscriptorIO."""
from enum import Enum
from typing import Optional, List
from pydantic_settings import BaseSettings
from pydantic import Field, field_validator
class OperationMode(str, Enum):
"""Operation modes for TranscriptorIO."""
STANDALONE = "standalone"
PROVIDER = "provider"
HYBRID = "standalone,provider"
class DatabaseType(str, Enum):
"""Supported database backends."""
SQLITE = "sqlite"
POSTGRESQL = "postgresql"
MARIADB = "mariadb"
MYSQL = "mysql"
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
# === Application Mode ===
transcriptarr_mode: str = Field(
default="standalone",
description="Operation mode: standalone, provider, or standalone,provider"
)
# === Database Configuration ===
database_url: str = Field(
default="sqlite:///./transcriptarr.db",
description="Database connection URL. Examples:\n"
" SQLite: sqlite:///./transcriptarr.db\n"
" PostgreSQL: postgresql://user:pass@localhost/transcriptarr\n"
" MariaDB: mariadb+pymysql://user:pass@localhost/transcriptarr"
)
# === Worker Configuration ===
concurrent_transcriptions: int = Field(default=2, ge=1, le=10)
whisper_threads: int = Field(default=4, ge=1, le=32)
transcribe_device: str = Field(default="cpu", pattern="^(cpu|gpu|cuda)$")
clear_vram_on_complete: bool = Field(default=True)
# === Whisper Model Configuration ===
whisper_model: str = Field(
default="medium",
description="Whisper model: tiny, base, small, medium, large-v3, etc."
)
model_path: str = Field(default="./models")
compute_type: str = Field(default="auto")
# === Standalone Mode Configuration ===
library_paths: Optional[str] = Field(
default=None,
description="Pipe-separated paths to scan: /media/anime|/media/movies"
)
auto_scan_enabled: bool = Field(default=False)
scan_interval_minutes: int = Field(default=30, ge=1)
required_audio_language: Optional[str] = Field(
default=None,
description="Only process files with this audio language (ISO 639-2)"
)
required_missing_subtitle: Optional[str] = Field(
default=None,
description="Only process if this subtitle language is missing (ISO 639-2)"
)
skip_if_subtitle_exists: bool = Field(default=True)
# === Provider Mode Configuration ===
bazarr_url: Optional[str] = Field(default=None)
bazarr_api_key: Optional[str] = Field(default=None)
provider_timeout_seconds: int = Field(default=600, ge=60)
provider_callback_enabled: bool = Field(default=True)
provider_polling_interval: int = Field(default=30, ge=10)
# === API Configuration ===
webhook_port: int = Field(default=9000, ge=1024, le=65535)
api_host: str = Field(default="0.0.0.0")
debug: bool = Field(default=True)
# === Transcription Settings ===
transcribe_or_translate: str = Field(
default="transcribe",
pattern="^(transcribe|translate)$"
)
subtitle_language_name: str = Field(default="")
subtitle_language_naming_type: str = Field(
default="ISO_639_2_B",
description="Naming format: ISO_639_1, ISO_639_2_T, ISO_639_2_B, NAME, NATIVE"
)
word_level_highlight: bool = Field(default=False)
custom_regroup: str = Field(default="cm_sl=84_sl=42++++++1")
# === Skip Configuration ===
skip_if_external_subtitles_exist: bool = Field(default=False)
skip_if_target_subtitles_exist: bool = Field(default=True)
skip_if_internal_subtitles_language: Optional[str] = Field(default="eng")
skip_subtitle_languages: Optional[str] = Field(
default=None,
description="Pipe-separated language codes to skip: eng|spa"
)
skip_if_audio_languages: Optional[str] = Field(
default=None,
description="Skip if audio track is in these languages: eng|spa"
)
skip_unknown_language: bool = Field(default=False)
skip_only_subgen_subtitles: bool = Field(default=False)
# === Advanced Settings ===
force_detected_language_to: Optional[str] = Field(default=None)
detect_language_length: int = Field(default=30, ge=5)
detect_language_offset: int = Field(default=0, ge=0)
should_whisper_detect_audio_language: bool = Field(default=False)
preferred_audio_languages: str = Field(
default="eng",
description="Pipe-separated list in order of preference: eng|jpn"
)
# === Path Mapping ===
use_path_mapping: bool = Field(default=False)
path_mapping_from: str = Field(default="/tv")
path_mapping_to: str = Field(default="/Volumes/TV")
# === Legacy SubGen Compatibility ===
show_in_subname_subgen: bool = Field(default=True)
show_in_subname_model: bool = Field(default=True)
append: bool = Field(default=False)
lrc_for_audio_files: bool = Field(default=True)
@field_validator("transcriptarr_mode")
@classmethod
def validate_mode(cls, v: str) -> str:
"""Validate operation mode."""
valid_modes = {"standalone", "provider", "standalone,provider"}
if v not in valid_modes:
raise ValueError(f"Invalid mode: {v}. Must be one of: {valid_modes}")
return v
@field_validator("database_url")
@classmethod
def validate_database_url(cls, v: str) -> str:
"""Validate database URL format."""
valid_prefixes = ("sqlite://", "postgresql://", "mariadb+pymysql://", "mysql+pymysql://")
if not any(v.startswith(prefix) for prefix in valid_prefixes):
raise ValueError(
f"Invalid database URL. Must start with one of: {valid_prefixes}"
)
return v
@property
def database_type(self) -> DatabaseType:
"""Get the database type from the URL."""
if self.database_url.startswith("sqlite"):
return DatabaseType.SQLITE
elif self.database_url.startswith("postgresql"):
return DatabaseType.POSTGRESQL
elif "mariadb" in self.database_url:
return DatabaseType.MARIADB
elif "mysql" in self.database_url:
return DatabaseType.MYSQL
else:
raise ValueError(f"Unknown database type in URL: {self.database_url}")
@property
def is_standalone_mode(self) -> bool:
"""Check if standalone mode is enabled."""
return "standalone" in self.transcriptarr_mode
@property
def is_provider_mode(self) -> bool:
"""Check if provider mode is enabled."""
return "provider" in self.transcriptarr_mode
@property
def library_paths_list(self) -> List[str]:
"""Get library paths as a list."""
if not self.library_paths:
return []
return [p.strip() for p in self.library_paths.split("|") if p.strip()]
@property
def skip_subtitle_languages_list(self) -> List[str]:
"""Get skip subtitle languages as a list."""
if not self.skip_subtitle_languages:
return []
return [lang.strip() for lang in self.skip_subtitle_languages.split("|") if lang.strip()]
@property
def skip_audio_languages_list(self) -> List[str]:
"""Get skip audio languages as a list."""
if not self.skip_if_audio_languages:
return []
return [lang.strip() for lang in self.skip_if_audio_languages.split("|") if lang.strip()]
@property
def preferred_audio_languages_list(self) -> List[str]:
"""Get preferred audio languages as a list."""
return [lang.strip() for lang in self.preferred_audio_languages.split("|") if lang.strip()]
class Config:
"""Pydantic configuration."""
env_file = ".env"
env_file_encoding = "utf-8"
case_sensitive = False
# Global settings instance
settings = Settings()
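A quick sketch of how the validators and derived properties above behave (illustrative values; assumes the module is importable):

# Hedged example: invalid values are rejected with explicit messages,
# and database_type is derived from the URL prefix
from pydantic import ValidationError
from backend.config import Settings

try:
    Settings(transcriptarr_mode="cluster", _env_file=None)
except ValidationError as exc:
    print(exc)  # includes "Invalid mode: cluster. Must be one of: ..."

s = Settings(database_url="mariadb+pymysql://user:pass@db/transcriptarr", _env_file=None)
print(s.database_type)        # DatabaseType.MARIADB
print(s.library_paths_list)   # [] when LIBRARY_PATHS is unset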

backend/core/__init__.py Normal file

@@ -0,0 +1 @@
"""TranscriptorIO Core Module."""

backend/core/database.py Normal file

@@ -0,0 +1,219 @@
"""Database configuration and session management."""
import logging
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine, event, Engine
from sqlalchemy.orm import declarative_base, sessionmaker, Session
from sqlalchemy.pool import StaticPool, QueuePool
from backend.config import settings, DatabaseType
logger = logging.getLogger(__name__)
# Base class for all models
Base = declarative_base()
class Database:
"""Database manager supporting SQLite, PostgreSQL, and MariaDB."""
def __init__(self, auto_create_tables: bool = True):
"""
Initialize database engine and session maker.
Args:
auto_create_tables: If True, automatically create tables if they don't exist
"""
self.engine = self._create_engine()
self.SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=self.engine
)
logger.info(f"Database initialized: {settings.database_type.value}")
# Automatically create tables if they don't exist
if auto_create_tables:
self._ensure_tables_exist()
def _create_engine(self) -> Engine:
"""Create SQLAlchemy engine based on database type."""
connect_args = {}
poolclass = QueuePool
if settings.database_type == DatabaseType.SQLITE:
# SQLite-specific configuration
connect_args = {
"check_same_thread": False, # Allow multi-threaded access
"timeout": 30.0, # Wait up to 30s for lock
}
# Use StaticPool for SQLite to avoid connection issues
poolclass = StaticPool
# Enable WAL mode for better concurrency
engine = create_engine(
settings.database_url,
connect_args=connect_args,
poolclass=poolclass,
echo=settings.debug,
)
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_conn, connection_record):
"""Enable SQLite optimizations."""
cursor = dbapi_conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA foreign_keys=ON")
cursor.execute("PRAGMA cache_size=-64000") # 64MB cache
cursor.close()
elif settings.database_type == DatabaseType.POSTGRESQL:
# PostgreSQL-specific configuration
try:
import psycopg2 # noqa: F401
except ImportError:
raise ImportError(
"PostgreSQL support requires psycopg2-binary.\n"
"Install it with: pip install psycopg2-binary"
)
engine = create_engine(
settings.database_url,
pool_size=10,
max_overflow=20,
pool_pre_ping=True, # Verify connections before using
echo=settings.debug,
)
elif settings.database_type in (DatabaseType.MARIADB, DatabaseType.MYSQL):
# MariaDB/MySQL-specific configuration
try:
import pymysql # noqa: F401
except ImportError:
raise ImportError(
"MariaDB/MySQL support requires pymysql.\n"
"Install it with: pip install pymysql"
)
connect_args = {
"charset": "utf8mb4",
}
engine = create_engine(
settings.database_url,
connect_args=connect_args,
pool_size=10,
max_overflow=20,
pool_pre_ping=True,
echo=settings.debug,
)
else:
raise ValueError(f"Unsupported database type: {settings.database_type}")
return engine
def _ensure_tables_exist(self):
"""Check if tables exist and create them if they don't."""
# Import models to register them with Base.metadata
from backend.core import models # noqa: F401
from sqlalchemy import inspect
inspector = inspect(self.engine)
existing_tables = inspector.get_table_names()
# Check if the main 'jobs' table exists
if 'jobs' not in existing_tables:
logger.info("Tables don't exist, creating them automatically...")
self.create_tables()
else:
logger.debug("Database tables already exist")
def create_tables(self):
"""Create all database tables."""
# Import models to register them with Base.metadata
from backend.core import models # noqa: F401
logger.info("Creating database tables...")
Base.metadata.create_all(bind=self.engine, checkfirst=True)
# Verify tables were actually created
from sqlalchemy import inspect
inspector = inspect(self.engine)
created_tables = inspector.get_table_names()
if 'jobs' in created_tables:
logger.info(f"Database tables created successfully: {created_tables}")
else:
logger.error(f"Failed to create tables. Existing tables: {created_tables}")
raise RuntimeError("Failed to create database tables")
def drop_tables(self):
"""Drop all database tables (use with caution!)."""
logger.warning("Dropping all database tables...")
Base.metadata.drop_all(bind=self.engine)
logger.info("Database tables dropped")
@contextmanager
def get_session(self) -> Generator[Session, None, None]:
"""
Get a database session as a context manager.
Usage:
with db.get_session() as session:
session.query(Job).all()
"""
session = self.SessionLocal()
try:
yield session
session.commit()
except Exception as e:
session.rollback()
logger.error(f"Database session error: {e}")
raise
finally:
session.close()
def get_db(self) -> Generator[Session, None, None]:
"""
Dependency for FastAPI endpoints.
Usage:
@app.get("/jobs")
def get_jobs(db: Session = Depends(database.get_db)):
return db.query(Job).all()
"""
session = self.SessionLocal()
try:
yield session
finally:
session.close()
def health_check(self) -> bool:
"""Check if database connection is healthy."""
try:
from sqlalchemy import text
with self.get_session() as session:
session.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"Database health check failed: {e}")
return False
def get_stats(self) -> dict:
"""Get database statistics."""
stats = {
"type": settings.database_type.value,
"url": settings.database_url.split("@")[-1] if "@" in settings.database_url else settings.database_url,
"pool_size": getattr(self.engine.pool, "size", lambda: "N/A")(),
"pool_checked_in": getattr(self.engine.pool, "checkedin", lambda: 0)(),
"pool_checked_out": getattr(self.engine.pool, "checkedout", lambda: 0)(),
"pool_overflow": getattr(self.engine.pool, "overflow", lambda: 0)(),
}
return stats
# Global database instance
database = Database()
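A hedged usage sketch for the Database manager defined above (assumes the default SQLite URL and that backend.core.models is importable):

from backend.core.database import database
from backend.core.models import Job

print(database.health_check())           # True when a connection can be made

with database.get_session() as session:  # commits on success, rolls back on error
    queued = session.query(Job).count()
    print(f"{queued} jobs stored")

print(database.get_stats())              # type, sanitized URL, pool counters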

backend/core/models.py Normal file

@@ -0,0 +1,203 @@
"""Database models for TranscriptorIO."""
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional
from sqlalchemy import (
Column, String, Integer, Float, DateTime, Text, Boolean, Enum as SQLEnum, Index
)
from sqlalchemy.sql import func
from backend.core.database import Base
class JobStatus(str, Enum):
"""Job status states."""
QUEUED = "queued"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class JobStage(str, Enum):
"""Job processing stages."""
PENDING = "pending"
DETECTING_LANGUAGE = "detecting_language"
EXTRACTING_AUDIO = "extracting_audio"
TRANSCRIBING = "transcribing"
TRANSLATING = "translating"
GENERATING_SUBTITLES = "generating_subtitles"
POST_PROCESSING = "post_processing"
FINALIZING = "finalizing"
class QualityPreset(str, Enum):
"""Quality presets for transcription."""
FAST = "fast" # ja→en→es with Helsinki-NLP (4GB VRAM)
BALANCED = "balanced" # ja→ja→es with M2M100 (6GB VRAM)
BEST = "best" # ja→es direct with SeamlessM4T (10GB+ VRAM)
class Job(Base):
"""Job model representing a transcription task."""
__tablename__ = "jobs"
# Primary identification
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
file_path = Column(String(1024), nullable=False, index=True)
file_name = Column(String(512), nullable=False)
# Job status
status = Column(
SQLEnum(JobStatus),
nullable=False,
default=JobStatus.QUEUED,
index=True
)
priority = Column(Integer, nullable=False, default=0, index=True)
# Configuration
source_lang = Column(String(10), nullable=True)
target_lang = Column(String(10), nullable=True)
quality_preset = Column(
SQLEnum(QualityPreset),
nullable=False,
default=QualityPreset.FAST
)
transcribe_or_translate = Column(String(20), nullable=False, default="transcribe")
# Progress tracking
progress = Column(Float, nullable=False, default=0.0) # 0-100
current_stage = Column(
SQLEnum(JobStage),
nullable=False,
default=JobStage.PENDING
)
eta_seconds = Column(Integer, nullable=True)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
started_at = Column(DateTime(timezone=True), nullable=True)
completed_at = Column(DateTime(timezone=True), nullable=True)
# Results
output_path = Column(String(1024), nullable=True)
srt_content = Column(Text, nullable=True)
segments_count = Column(Integer, nullable=True)
# Error handling
error = Column(Text, nullable=True)
retry_count = Column(Integer, nullable=False, default=0)
max_retries = Column(Integer, nullable=False, default=3)
# Worker information
worker_id = Column(String(64), nullable=True)
vram_used_mb = Column(Integer, nullable=True)
processing_time_seconds = Column(Float, nullable=True)
# Provider mode specific
bazarr_callback_url = Column(String(512), nullable=True)
is_manual_request = Column(Boolean, nullable=False, default=False)
# Additional metadata
model_used = Column(String(64), nullable=True)
device_used = Column(String(32), nullable=True)
compute_type = Column(String(32), nullable=True)
def __repr__(self):
"""String representation of Job."""
return f"<Job {self.id[:8]}... {self.file_name} [{self.status.value}] {self.progress:.1f}%>"
@property
def duration_seconds(self) -> Optional[float]:
"""Calculate job duration in seconds."""
if self.started_at and self.completed_at:
delta = self.completed_at - self.started_at
return delta.total_seconds()
return None
@property
def is_terminal_state(self) -> bool:
"""Check if job is in a terminal state (completed/failed/cancelled)."""
return self.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
@property
def can_retry(self) -> bool:
"""Check if job can be retried."""
return self.status == JobStatus.FAILED and self.retry_count < self.max_retries
def to_dict(self) -> dict:
"""Convert job to dictionary for API responses."""
return {
"id": self.id,
"file_path": self.file_path,
"file_name": self.file_name,
"status": self.status.value,
"priority": self.priority,
"source_lang": self.source_lang,
"target_lang": self.target_lang,
"quality_preset": self.quality_preset.value if self.quality_preset else None,
"transcribe_or_translate": self.transcribe_or_translate,
"progress": self.progress,
"current_stage": self.current_stage.value if self.current_stage else None,
"eta_seconds": self.eta_seconds,
"created_at": self.created_at.isoformat() if self.created_at else None,
"started_at": self.started_at.isoformat() if self.started_at else None,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"output_path": self.output_path,
"segments_count": self.segments_count,
"error": self.error,
"retry_count": self.retry_count,
"worker_id": self.worker_id,
"vram_used_mb": self.vram_used_mb,
"processing_time_seconds": self.processing_time_seconds,
"model_used": self.model_used,
"device_used": self.device_used,
}
def update_progress(self, progress: float, stage: JobStage, eta_seconds: Optional[int] = None):
"""Update job progress."""
self.progress = min(100.0, max(0.0, progress))
self.current_stage = stage
if eta_seconds is not None:
self.eta_seconds = eta_seconds
def mark_started(self, worker_id: str):
"""Mark job as started."""
self.status = JobStatus.PROCESSING
self.started_at = datetime.utcnow()
self.worker_id = worker_id
def mark_completed(self, output_path: str, segments_count: int, srt_content: Optional[str] = None):
"""Mark job as completed."""
self.status = JobStatus.COMPLETED
self.completed_at = datetime.utcnow()
self.output_path = output_path
self.segments_count = segments_count
self.srt_content = srt_content
self.progress = 100.0
self.current_stage = JobStage.FINALIZING
if self.started_at:
self.processing_time_seconds = (self.completed_at - self.started_at).total_seconds()
def mark_failed(self, error: str):
"""Mark job as failed."""
self.status = JobStatus.FAILED
self.completed_at = datetime.utcnow()
self.error = error
self.retry_count += 1
def mark_cancelled(self):
"""Mark job as cancelled."""
self.status = JobStatus.CANCELLED
self.completed_at = datetime.utcnow()
# Create indexes for common queries
Index('idx_jobs_status_priority', Job.status, Job.priority.desc(), Job.created_at)
Index('idx_jobs_created', Job.created_at.desc())
Index('idx_jobs_file_path', Job.file_path)
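A small sketch of the Job lifecycle helpers in isolation (no session involved; column defaults such as status=QUEUED only apply once a row is flushed, so the helpers set the fields explicitly here):

from backend.core.models import Job, JobStage, JobStatus

job = Job(file_path="/media/anime/ep01.mkv", file_name="ep01.mkv")
job.mark_started(worker_id="cpu-0")
job.update_progress(42.0, JobStage.TRANSCRIBING, eta_seconds=90)
job.mark_completed(output_path="/media/anime/ep01.srt", segments_count=350)
print(job.status is JobStatus.COMPLETED, job.progress)  # True 100.0
print(job.to_dict()["current_stage"])                   # "finalizing"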

backend/core/queue_manager.py Normal file

@@ -0,0 +1,394 @@
"""Queue manager for persistent job queuing."""
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Dict
from sqlalchemy.orm import Session
from backend.core.database import database
from backend.core.models import Job, JobStatus, JobStage, QualityPreset
logger = logging.getLogger(__name__)
class QueueManager:
"""
Persistent queue manager for transcription jobs.
Replaces the old DeduplicatedQueue with a database-backed solution that:
- Persists jobs across restarts
- Supports priority queuing
- Prevents duplicate jobs
- Provides visibility into queue state
- Thread-safe operations
"""
def __init__(self):
"""Initialize queue manager."""
self.db = database
logger.info("QueueManager initialized")
def add_job(
self,
file_path: str,
file_name: str,
source_lang: Optional[str] = None,
target_lang: Optional[str] = None,
quality_preset: QualityPreset = QualityPreset.FAST,
transcribe_or_translate: str = "transcribe",
priority: int = 0,
bazarr_callback_url: Optional[str] = None,
is_manual_request: bool = False,
) -> Optional[Job]:
"""
Add a new job to the queue.
Args:
file_path: Full path to the media file
file_name: Name of the file
source_lang: Source language code (ISO 639-2)
target_lang: Target language code (ISO 639-2)
quality_preset: Quality preset (fast/balanced/best)
transcribe_or_translate: Operation type
priority: Job priority (higher = processed first)
bazarr_callback_url: Callback URL for Bazarr provider mode
is_manual_request: Whether this is a manual request (higher priority)
Returns:
Job object if created, None if duplicate exists
"""
with self.db.get_session() as session:
# Check for existing job
existing = self._find_existing_job(session, file_path, target_lang)
if existing:
logger.info(f"Job already exists for {file_name}: {existing.id} [{existing.status.value}]")
# If existing job failed and can retry, reset it
if existing.can_retry:
logger.info(f"Resetting failed job {existing.id} for retry")
existing.status = JobStatus.QUEUED
existing.error = None
existing.current_stage = JobStage.PENDING
existing.progress = 0.0
session.commit()
return existing
return None
# Create new job
job = Job(
file_path=file_path,
file_name=file_name,
source_lang=source_lang,
target_lang=target_lang,
quality_preset=quality_preset,
transcribe_or_translate=transcribe_or_translate,
priority=priority + (10 if is_manual_request else 0), # Boost manual requests
bazarr_callback_url=bazarr_callback_url,
is_manual_request=is_manual_request,
)
session.add(job)
session.commit()
# Access all attributes before session closes to ensure they're loaded
job_id = job.id
job_status = job.status
logger.info(
f"Job {job_id} added to queue: {file_name} "
f"[{quality_preset.value}] priority={job.priority}"
)
# Re-query the job in a new session to return a fresh copy
with self.db.get_session() as session:
job = session.query(Job).filter(Job.id == job_id).first()
if job:
session.expunge(job) # Remove from session so it doesn't expire
return job
def get_next_job(self, worker_id: str) -> Optional[Job]:
"""
Get the next job from the queue for processing.
Jobs are selected based on:
1. Status = QUEUED
2. Priority (DESC)
3. Created time (ASC) - FIFO within same priority
Args:
worker_id: ID of the worker requesting the job
Returns:
Job object or None if queue is empty
"""
with self.db.get_session() as session:
job = (
session.query(Job)
.filter(Job.status == JobStatus.QUEUED)
.order_by(
Job.priority.desc(),
Job.created_at.asc()
)
.with_for_update(skip_locked=True) # Skip locked rows (concurrent workers)
.first()
)
if job:
job_id = job.id
job.mark_started(worker_id)
session.commit()
logger.info(f"Job {job_id} assigned to worker {worker_id}")
# Re-query the job if found
if job:
with self.db.get_session() as session:
job = session.query(Job).filter(Job.id == job_id).first()
if job:
session.expunge(job) # Remove from session so it doesn't expire
return job
return None
def get_job_by_id(self, job_id: str) -> Optional[Job]:
"""Get a specific job by ID."""
with self.db.get_session() as session:
return session.query(Job).filter(Job.id == job_id).first()
def update_job_progress(
self,
job_id: str,
progress: float,
stage: JobStage,
eta_seconds: Optional[int] = None
) -> bool:
"""
Update job progress.
Args:
job_id: Job ID
progress: Progress percentage (0-100)
stage: Current processing stage
eta_seconds: Estimated time to completion
Returns:
True if updated successfully, False otherwise
"""
with self.db.get_session() as session:
job = session.query(Job).filter(Job.id == job_id).first()
if not job:
logger.warning(f"Job {job_id} not found for progress update")
return False
job.update_progress(progress, stage, eta_seconds)
session.commit()
logger.debug(
f"Job {job_id} progress: {progress:.1f}% [{stage.value}] ETA: {eta_seconds}s"
)
return True
def mark_job_completed(
self,
job_id: str,
output_path: str,
segments_count: int,
srt_content: Optional[str] = None
) -> bool:
"""Mark a job as completed."""
with self.db.get_session() as session:
job = session.query(Job).filter(Job.id == job_id).first()
if not job:
logger.warning(f"Job {job_id} not found for completion")
return False
job.mark_completed(output_path, segments_count, srt_content)
session.commit()
logger.info(
f"Job {job_id} completed: {output_path} "
f"({segments_count} segments, {job.processing_time_seconds:.1f}s)"
)
return True
def mark_job_failed(self, job_id: str, error: str) -> bool:
"""Mark a job as failed."""
with self.db.get_session() as session:
job = session.query(Job).filter(Job.id == job_id).first()
if not job:
logger.warning(f"Job {job_id} not found for failure marking")
return False
job.mark_failed(error)
session.commit()
logger.error(
f"Job {job_id} failed (attempt {job.retry_count}/{job.max_retries}): {error}"
)
return True
def cancel_job(self, job_id: str) -> bool:
"""Cancel a queued or processing job."""
with self.db.get_session() as session:
job = session.query(Job).filter(Job.id == job_id).first()
if not job:
logger.warning(f"Job {job_id} not found for cancellation")
return False
if job.is_terminal_state:
logger.warning(f"Job {job_id} is already in terminal state: {job.status.value}")
return False
job.mark_cancelled()
session.commit()
logger.info(f"Job {job_id} cancelled")
return True
def get_queue_stats(self) -> Dict:
"""Get queue statistics."""
with self.db.get_session() as session:
total = session.query(Job).count()
queued = session.query(Job).filter(Job.status == JobStatus.QUEUED).count()
processing = session.query(Job).filter(Job.status == JobStatus.PROCESSING).count()
completed = session.query(Job).filter(Job.status == JobStatus.COMPLETED).count()
failed = session.query(Job).filter(Job.status == JobStatus.FAILED).count()
# Get today's stats
today = datetime.utcnow().date()
completed_today = (
session.query(Job)
.filter(
Job.status == JobStatus.COMPLETED,
Job.completed_at >= today
)
.count()
)
failed_today = (
session.query(Job)
.filter(
Job.status == JobStatus.FAILED,
Job.completed_at >= today
)
.count()
)
return {
"total": total,
"queued": queued,
"processing": processing,
"completed": completed,
"failed": failed,
"completed_today": completed_today,
"failed_today": failed_today,
}
def get_jobs(
self,
status: Optional[JobStatus] = None,
limit: int = 50,
offset: int = 0
) -> List[Job]:
"""
Get jobs with optional filtering.
Args:
status: Filter by status
limit: Maximum number of jobs to return
offset: Offset for pagination
Returns:
List of Job objects
"""
with self.db.get_session() as session:
query = session.query(Job)
if status:
query = query.filter(Job.status == status)
jobs = (
query
.order_by(Job.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
return jobs
def get_processing_jobs(self) -> List[Job]:
"""Get all currently processing jobs."""
return self.get_jobs(status=JobStatus.PROCESSING)
def get_queued_jobs(self) -> List[Job]:
"""Get all queued jobs."""
return self.get_jobs(status=JobStatus.QUEUED)
def is_queue_empty(self) -> bool:
"""Check if queue has any pending jobs."""
with self.db.get_session() as session:
count = (
session.query(Job)
.filter(Job.status.in_([JobStatus.QUEUED, JobStatus.PROCESSING]))
.count()
)
return count == 0
def cleanup_old_jobs(self, days: int = 30) -> int:
"""
Delete completed/failed jobs older than specified days.
Args:
days: Number of days to keep jobs
Returns:
Number of jobs deleted
"""
with self.db.get_session() as session:
cutoff_date = datetime.utcnow() - timedelta(days=days)
deleted = (
session.query(Job)
.filter(
Job.status.in_([JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED]),
Job.completed_at < cutoff_date
)
.delete()
)
session.commit()
if deleted > 0:
logger.info(f"Cleaned up {deleted} old jobs (older than {days} days)")
return deleted
def _find_existing_job(
self,
session: Session,
file_path: str,
target_lang: Optional[str]
) -> Optional[Job]:
"""
Find existing job for the same file and target language.
Ignores completed jobs - allows re-transcription.
"""
query = session.query(Job).filter(
Job.file_path == file_path,
Job.status.in_([JobStatus.QUEUED, JobStatus.PROCESSING])
)
if target_lang:
query = query.filter(Job.target_lang == target_lang)
return query.first()
# Global queue manager instance
queue_manager = QueueManager()
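A hedged sketch of queue usage built only from the methods above (file paths and language codes are made up):

from backend.core.queue_manager import queue_manager
from backend.core.models import QualityPreset

job = queue_manager.add_job(
    file_path="/media/anime/ep01.mkv",
    file_name="ep01.mkv",
    target_lang="spa",
    quality_preset=QualityPreset.BALANCED,
    is_manual_request=True,  # bumps priority by 10
)
print(queue_manager.get_queue_stats())

next_job = queue_manager.get_next_job(worker_id="cpu-0")
if next_job:
    queue_manager.mark_job_completed(next_job.id, "/media/anime/ep01.srt", segments_count=350)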

backend/core/worker.py Normal file

@@ -0,0 +1,285 @@
"""Individual worker for processing transcription jobs."""
import logging
import multiprocessing as mp
import time
import traceback
from datetime import datetime
from enum import Enum
from typing import Optional
from backend.core.database import Database
from backend.core.models import Job, JobStatus, JobStage
from backend.core.queue_manager import QueueManager
logger = logging.getLogger(__name__)
class WorkerType(str, Enum):
"""Worker device type."""
CPU = "cpu"
GPU = "gpu"
class WorkerStatus(str, Enum):
"""Worker status states."""
IDLE = "idle"
BUSY = "busy"
STOPPING = "stopping"
STOPPED = "stopped"
ERROR = "error"
class Worker:
"""
Individual worker process for transcription.
Each worker runs in its own process and can handle one job at a time.
Workers communicate with the main process via multiprocessing primitives.
"""
def __init__(
self,
worker_id: str,
worker_type: WorkerType,
device_id: Optional[int] = None
):
"""
Initialize worker.
Args:
worker_id: Unique identifier for this worker
worker_type: CPU or GPU
device_id: GPU device ID (only for GPU workers)
"""
self.worker_id = worker_id
self.worker_type = worker_type
self.device_id = device_id
# Multiprocessing primitives
self.process: Optional[mp.Process] = None
self.stop_event = mp.Event()
# Shared status holds the index of the WorkerStatus member, since mp.Value('i') stores ints, not strings
self.status = mp.Value('i', list(WorkerStatus).index(WorkerStatus.IDLE))
self.current_job_id = mp.Array('c', 36)  # UUID string as raw bytes
# Stats
self.jobs_completed = mp.Value('i', 0) # type: ignore
self.jobs_failed = mp.Value('i', 0) # type: ignore
self.started_at: Optional[datetime] = None
def start(self):
"""Start the worker process."""
if self.process and self.process.is_alive():
logger.warning(f"Worker {self.worker_id} is already running")
return
self.stop_event.clear()
self.process = mp.Process(
target=self._worker_loop,
name=f"Worker-{self.worker_id}",
daemon=True
)
self.process.start()
self.started_at = datetime.utcnow()
logger.info(
f"Worker {self.worker_id} started (PID: {self.process.pid}, "
f"Type: {self.worker_type.value})"
)
def stop(self, timeout: float = 30.0):
"""
Stop the worker process gracefully.
Args:
timeout: Maximum time to wait for worker to stop
"""
if not self.process or not self.process.is_alive():
logger.warning(f"Worker {self.worker_id} is not running")
return
logger.info(f"Stopping worker {self.worker_id}...")
self.stop_event.set()
self.process.join(timeout=timeout)
if self.process.is_alive():
logger.warning(f"Worker {self.worker_id} did not stop gracefully, terminating...")
self.process.terminate()
self.process.join(timeout=5.0)
if self.process.is_alive():
logger.error(f"Worker {self.worker_id} did not terminate, killing...")
self.process.kill()
logger.info(f"Worker {self.worker_id} stopped")
def is_alive(self) -> bool:
"""Check if worker process is alive."""
return self.process is not None and self.process.is_alive()
def get_status(self) -> dict:
"""Get worker status information."""
# self.status stores the index of the WorkerStatus member (see _set_status)
status_value = self.status.value
status_enum = list(WorkerStatus)[status_value]
current_job = self.current_job_id.value.decode('utf-8').strip('\x00')
return {
"worker_id": self.worker_id,
"type": self.worker_type.value,
"device_id": self.device_id,
"status": status_enum.value,
"current_job_id": current_job if current_job else None,
"jobs_completed": self.jobs_completed.value,
"jobs_failed": self.jobs_failed.value,
"is_alive": self.is_alive(),
"pid": self.process.pid if self.process else None,
"started_at": self.started_at.isoformat() if self.started_at else None,
}
def _worker_loop(self):
"""
Main worker loop (runs in separate process).
This is the entry point for the worker process.
"""
# Set up logging in the worker process
logging.basicConfig(
level=logging.INFO,
format=f'[Worker-{self.worker_id}] %(levelname)s: %(message)s'
)
logger.info(f"Worker {self.worker_id} loop started")
# Initialize database and queue manager in worker process
# Each process needs its own DB connection
try:
db = Database(auto_create_tables=False)
queue_mgr = QueueManager()
except Exception as e:
logger.error(f"Failed to initialize worker: {e}")
self._set_status(WorkerStatus.ERROR)
return
# Main work loop
while not self.stop_event.is_set():
try:
# Try to get next job from queue
job = queue_mgr.get_next_job(self.worker_id)
if job is None:
# No jobs available, idle for a bit
self._set_status(WorkerStatus.IDLE)
time.sleep(2)
continue
# Process the job
self._set_status(WorkerStatus.BUSY)
self._set_current_job(job.id)
logger.info(f"Processing job {job.id}: {job.file_name}")
try:
self._process_job(job, queue_mgr)
self.jobs_completed.value += 1
logger.info(f"Job {job.id} completed successfully")
except Exception as e:
self.jobs_failed.value += 1
error_msg = f"Job processing failed: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
queue_mgr.mark_job_failed(job.id, error_msg)
finally:
self._clear_current_job()
except Exception as e:
logger.error(f"Worker loop error: {e}\n{traceback.format_exc()}")
time.sleep(5) # Back off on errors
self._set_status(WorkerStatus.STOPPED)
logger.info(f"Worker {self.worker_id} loop ended")
def _process_job(self, job: Job, queue_mgr: QueueManager):
"""
Process a single transcription job.
Args:
job: Job to process
queue_mgr: Queue manager for updating progress
"""
# TODO: This will be implemented when we add the transcriber module
# For now, simulate work
# Stage 1: Detect language
queue_mgr.update_job_progress(
job.id,
progress=10.0,
stage=JobStage.DETECTING_LANGUAGE,
eta_seconds=60
)
time.sleep(2) # Simulate work
# Stage 2: Extract audio
queue_mgr.update_job_progress(
job.id,
progress=20.0,
stage=JobStage.EXTRACTING_AUDIO,
eta_seconds=50
)
time.sleep(2)
# Stage 3: Transcribe
queue_mgr.update_job_progress(
job.id,
progress=30.0,
stage=JobStage.TRANSCRIBING,
eta_seconds=40
)
# Simulate progressive transcription
for i in range(30, 90, 10):
time.sleep(1)
queue_mgr.update_job_progress(
job.id,
progress=float(i),
stage=JobStage.TRANSCRIBING,
eta_seconds=int((100 - i) / 2)
)
# Stage 4: Finalize
queue_mgr.update_job_progress(
job.id,
progress=95.0,
stage=JobStage.FINALIZING,
eta_seconds=5
)
time.sleep(1)
# Mark as completed
output_path = job.file_path.replace('.mkv', '.srt')
queue_mgr.mark_job_completed(
job.id,
output_path=output_path,
segments_count=100, # Simulated
srt_content="Simulated SRT content"
)
def _set_status(self, status: WorkerStatus):
"""Set worker status (thread-safe)."""
# Store the member's index; mp.Value('i') cannot hold the enum's string value
self.status.value = list(WorkerStatus).index(status)
def _set_current_job(self, job_id: str):
"""Set current job ID (thread-safe)."""
# ctypes char arrays take single-byte bytes objects, not ints, so slice instead of iterating raw bytes
job_id_bytes = job_id.encode('utf-8')
for i in range(len(self.current_job_id)):
self.current_job_id[i] = job_id_bytes[i:i + 1] if i < len(job_id_bytes) else b'\x00'
def _clear_current_job(self):
"""Clear current job ID (thread-safe)."""
for i in range(len(self.current_job_id)):
self.current_job_id[i] = b'\x00'
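A hedged sketch of driving a single worker (the job processing itself is still the simulated placeholder in _process_job):

from backend.core.worker import Worker, WorkerType

worker = Worker(worker_id="cpu-0", worker_type=WorkerType.CPU)
worker.start()              # spawns the process, which polls the queue every 2s when idle
print(worker.get_status())  # id, type, status, current job, completion counters
worker.stop(timeout=30.0)   # sets the stop event, escalates to terminate/kill if needed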