feat: add centralized configuration system with Pydantic
- Add backend/config.py with Pydantic settings validation
- Support for standalone, provider, and hybrid operation modes
- Multi-database backend configuration (SQLite/PostgreSQL/MariaDB)
- Environment variable validation with helpful error messages
- Worker and Whisper model configuration
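The settings introduced by this commit are read from the environment (or a .env file). A minimal sketch of how the new configuration is expected to be driven, assuming pydantic-settings' default case-insensitive mapping of field names to environment variables; the values shown are examples only, not part of the commit:

import os

os.environ["TRANSCRIPTARR_MODE"] = "standalone,provider"   # hybrid mode
os.environ["DATABASE_URL"] = "postgresql://user:pass@localhost/transcriptarr"
os.environ["CONCURRENT_TRANSCRIPTIONS"] = "4"

from backend.config import Settings

settings = Settings()
print(settings.is_standalone_mode, settings.is_provider_mode)  # True True
print(settings.database_type)                                  # DatabaseType.POSTGRESQL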
backend/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
"""TranscriptorIO Backend Package."""
backend/api/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
"""TranscriptorIO API Module."""
backend/config.py (Normal file, 214 lines)
@@ -0,0 +1,214 @@
"""Configuration management for TranscriptorIO."""
import os
from enum import Enum
from typing import Optional, List
from pydantic_settings import BaseSettings
from pydantic import Field, field_validator


class OperationMode(str, Enum):
    """Operation modes for TranscriptorIO."""
    STANDALONE = "standalone"
    PROVIDER = "provider"
    HYBRID = "standalone,provider"


class DatabaseType(str, Enum):
    """Supported database backends."""
    SQLITE = "sqlite"
    POSTGRESQL = "postgresql"
    MARIADB = "mariadb"
    MYSQL = "mysql"


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # === Application Mode ===
    transcriptarr_mode: str = Field(
        default="standalone",
        description="Operation mode: standalone, provider, or standalone,provider"
    )

    # === Database Configuration ===
    database_url: str = Field(
        default="sqlite:///./transcriptarr.db",
        description="Database connection URL. Examples:\n"
                    " SQLite: sqlite:///./transcriptarr.db\n"
                    " PostgreSQL: postgresql://user:pass@localhost/transcriptarr\n"
                    " MariaDB: mariadb+pymysql://user:pass@localhost/transcriptarr"
    )

    # === Worker Configuration ===
    concurrent_transcriptions: int = Field(default=2, ge=1, le=10)
    whisper_threads: int = Field(default=4, ge=1, le=32)
    transcribe_device: str = Field(default="cpu", pattern="^(cpu|gpu|cuda)$")
    clear_vram_on_complete: bool = Field(default=True)

    # === Whisper Model Configuration ===
    whisper_model: str = Field(
        default="medium",
        description="Whisper model: tiny, base, small, medium, large-v3, etc."
    )
    model_path: str = Field(default="./models")
    compute_type: str = Field(default="auto")

    # === Standalone Mode Configuration ===
    library_paths: Optional[str] = Field(
        default=None,
        description="Pipe-separated paths to scan: /media/anime|/media/movies"
    )
    auto_scan_enabled: bool = Field(default=False)
    scan_interval_minutes: int = Field(default=30, ge=1)

    required_audio_language: Optional[str] = Field(
        default=None,
        description="Only process files with this audio language (ISO 639-2)"
    )
    required_missing_subtitle: Optional[str] = Field(
        default=None,
        description="Only process if this subtitle language is missing (ISO 639-2)"
    )
    skip_if_subtitle_exists: bool = Field(default=True)

    # === Provider Mode Configuration ===
    bazarr_url: Optional[str] = Field(default=None)
    bazarr_api_key: Optional[str] = Field(default=None)
    provider_timeout_seconds: int = Field(default=600, ge=60)
    provider_callback_enabled: bool = Field(default=True)
    provider_polling_interval: int = Field(default=30, ge=10)

    # === API Configuration ===
    webhook_port: int = Field(default=9000, ge=1024, le=65535)
    api_host: str = Field(default="0.0.0.0")
    debug: bool = Field(default=True)

    # === Transcription Settings ===
    transcribe_or_translate: str = Field(
        default="transcribe",
        pattern="^(transcribe|translate)$"
    )
    subtitle_language_name: str = Field(default="")
    subtitle_language_naming_type: str = Field(
        default="ISO_639_2_B",
        description="Naming format: ISO_639_1, ISO_639_2_T, ISO_639_2_B, NAME, NATIVE"
    )
    word_level_highlight: bool = Field(default=False)
    custom_regroup: str = Field(default="cm_sl=84_sl=42++++++1")

    # === Skip Configuration ===
    skip_if_external_subtitles_exist: bool = Field(default=False)
    skip_if_target_subtitles_exist: bool = Field(default=True)
    skip_if_internal_subtitles_language: Optional[str] = Field(default="eng")
    skip_subtitle_languages: Optional[str] = Field(
        default=None,
        description="Pipe-separated language codes to skip: eng|spa"
    )
    skip_if_audio_languages: Optional[str] = Field(
        default=None,
        description="Skip if audio track is in these languages: eng|spa"
    )
    skip_unknown_language: bool = Field(default=False)
    skip_only_subgen_subtitles: bool = Field(default=False)

    # === Advanced Settings ===
    force_detected_language_to: Optional[str] = Field(default=None)
    detect_language_length: int = Field(default=30, ge=5)
    detect_language_offset: int = Field(default=0, ge=0)
    should_whisper_detect_audio_language: bool = Field(default=False)

    preferred_audio_languages: str = Field(
        default="eng",
        description="Pipe-separated list in order of preference: eng|jpn"
    )

    # === Path Mapping ===
    use_path_mapping: bool = Field(default=False)
    path_mapping_from: str = Field(default="/tv")
    path_mapping_to: str = Field(default="/Volumes/TV")

    # === Legacy SubGen Compatibility ===
    show_in_subname_subgen: bool = Field(default=True)
    show_in_subname_model: bool = Field(default=True)
    append: bool = Field(default=False)
    lrc_for_audio_files: bool = Field(default=True)

    @field_validator("transcriptarr_mode")
    @classmethod
    def validate_mode(cls, v: str) -> str:
        """Validate operation mode."""
        valid_modes = {"standalone", "provider", "standalone,provider"}
        if v not in valid_modes:
            raise ValueError(f"Invalid mode: {v}. Must be one of: {valid_modes}")
        return v

    @field_validator("database_url")
    @classmethod
    def validate_database_url(cls, v: str) -> str:
        """Validate database URL format."""
        valid_prefixes = ("sqlite://", "postgresql://", "mariadb+pymysql://", "mysql+pymysql://")
        if not any(v.startswith(prefix) for prefix in valid_prefixes):
            raise ValueError(
                f"Invalid database URL. Must start with one of: {valid_prefixes}"
            )
        return v

    @property
    def database_type(self) -> DatabaseType:
        """Get the database type from the URL."""
        if self.database_url.startswith("sqlite"):
            return DatabaseType.SQLITE
        elif self.database_url.startswith("postgresql"):
            return DatabaseType.POSTGRESQL
        elif "mariadb" in self.database_url:
            return DatabaseType.MARIADB
        elif "mysql" in self.database_url:
            return DatabaseType.MYSQL
        else:
            raise ValueError(f"Unknown database type in URL: {self.database_url}")

    @property
    def is_standalone_mode(self) -> bool:
        """Check if standalone mode is enabled."""
        return "standalone" in self.transcriptarr_mode

    @property
    def is_provider_mode(self) -> bool:
        """Check if provider mode is enabled."""
        return "provider" in self.transcriptarr_mode

    @property
    def library_paths_list(self) -> List[str]:
        """Get library paths as a list."""
        if not self.library_paths:
            return []
        return [p.strip() for p in self.library_paths.split("|") if p.strip()]

    @property
    def skip_subtitle_languages_list(self) -> List[str]:
        """Get skip subtitle languages as a list."""
        if not self.skip_subtitle_languages:
            return []
        return [lang.strip() for lang in self.skip_subtitle_languages.split("|") if lang.strip()]

    @property
    def skip_audio_languages_list(self) -> List[str]:
        """Get skip audio languages as a list."""
        if not self.skip_if_audio_languages:
            return []
        return [lang.strip() for lang in self.skip_if_audio_languages.split("|") if lang.strip()]

    @property
    def preferred_audio_languages_list(self) -> List[str]:
        """Get preferred audio languages as a list."""
        return [lang.strip() for lang in self.preferred_audio_languages.split("|") if lang.strip()]

    class Config:
        """Pydantic configuration."""
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = False


# Global settings instance
settings = Settings()
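A short illustration of the pipe-separated helper properties defined above; the paths and language codes are made up, and passing values directly to Settings() is just a convenient way to override the environment for the example:

from backend.config import Settings

s = Settings(library_paths="/media/anime|/media/movies",
             skip_subtitle_languages="eng|spa")
assert s.library_paths_list == ["/media/anime", "/media/movies"]
assert s.skip_subtitle_languages_list == ["eng", "spa"]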
backend/core/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
"""TranscriptorIO Core Module."""
backend/core/database.py (Normal file, 219 lines)
@@ -0,0 +1,219 @@
"""Database configuration and session management."""
import logging
from contextlib import contextmanager
from typing import Generator

from sqlalchemy import create_engine, event, Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import StaticPool, QueuePool

from backend.config import settings, DatabaseType

logger = logging.getLogger(__name__)

# Base class for all models
Base = declarative_base()


class Database:
    """Database manager supporting SQLite, PostgreSQL, and MariaDB."""

    def __init__(self, auto_create_tables: bool = True):
        """
        Initialize database engine and session maker.

        Args:
            auto_create_tables: If True, automatically create tables if they don't exist
        """
        self.engine = self._create_engine()
        self.SessionLocal = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=self.engine
        )
        logger.info(f"Database initialized: {settings.database_type.value}")

        # Automatically create tables if they don't exist
        if auto_create_tables:
            self._ensure_tables_exist()

    def _create_engine(self) -> Engine:
        """Create SQLAlchemy engine based on database type."""
        connect_args = {}
        poolclass = QueuePool

        if settings.database_type == DatabaseType.SQLITE:
            # SQLite-specific configuration
            connect_args = {
                "check_same_thread": False,  # Allow multi-threaded access
                "timeout": 30.0,  # Wait up to 30s for lock
            }
            # Use StaticPool for SQLite to avoid connection issues
            poolclass = StaticPool

            # Enable WAL mode for better concurrency
            engine = create_engine(
                settings.database_url,
                connect_args=connect_args,
                poolclass=poolclass,
                echo=settings.debug,
            )

            @event.listens_for(engine, "connect")
            def set_sqlite_pragma(dbapi_conn, connection_record):
                """Enable SQLite optimizations."""
                cursor = dbapi_conn.cursor()
                cursor.execute("PRAGMA journal_mode=WAL")
                cursor.execute("PRAGMA synchronous=NORMAL")
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.execute("PRAGMA cache_size=-64000")  # 64MB cache
                cursor.close()

        elif settings.database_type == DatabaseType.POSTGRESQL:
            # PostgreSQL-specific configuration
            try:
                import psycopg2  # noqa: F401
            except ImportError:
                raise ImportError(
                    "PostgreSQL support requires psycopg2-binary.\n"
                    "Install it with: pip install psycopg2-binary"
                )

            engine = create_engine(
                settings.database_url,
                pool_size=10,
                max_overflow=20,
                pool_pre_ping=True,  # Verify connections before using
                echo=settings.debug,
            )

        elif settings.database_type in (DatabaseType.MARIADB, DatabaseType.MYSQL):
            # MariaDB/MySQL-specific configuration
            try:
                import pymysql  # noqa: F401
            except ImportError:
                raise ImportError(
                    "MariaDB/MySQL support requires pymysql.\n"
                    "Install it with: pip install pymysql"
                )

            connect_args = {
                "charset": "utf8mb4",
            }
            engine = create_engine(
                settings.database_url,
                connect_args=connect_args,
                pool_size=10,
                max_overflow=20,
                pool_pre_ping=True,
                echo=settings.debug,
            )

        else:
            raise ValueError(f"Unsupported database type: {settings.database_type}")

        return engine

    def _ensure_tables_exist(self):
        """Check if tables exist and create them if they don't."""
        # Import models to register them with Base.metadata
        from backend.core import models  # noqa: F401
        from sqlalchemy import inspect

        inspector = inspect(self.engine)
        existing_tables = inspector.get_table_names()

        # Check if the main 'jobs' table exists
        if 'jobs' not in existing_tables:
            logger.info("Tables don't exist, creating them automatically...")
            self.create_tables()
        else:
            logger.debug("Database tables already exist")

    def create_tables(self):
        """Create all database tables."""
        # Import models to register them with Base.metadata
        from backend.core import models  # noqa: F401

        logger.info("Creating database tables...")
        Base.metadata.create_all(bind=self.engine, checkfirst=True)

        # Verify tables were actually created
        from sqlalchemy import inspect
        inspector = inspect(self.engine)
        created_tables = inspector.get_table_names()

        if 'jobs' in created_tables:
            logger.info(f"Database tables created successfully: {created_tables}")
        else:
            logger.error(f"Failed to create tables. Existing tables: {created_tables}")
            raise RuntimeError("Failed to create database tables")

    def drop_tables(self):
        """Drop all database tables (use with caution!)."""
        logger.warning("Dropping all database tables...")
        Base.metadata.drop_all(bind=self.engine)
        logger.info("Database tables dropped")

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """
        Get a database session as a context manager.

        Usage:
            with db.get_session() as session:
                session.query(Job).all()
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Database session error: {e}")
            raise
        finally:
            session.close()

    def get_db(self) -> Generator[Session, None, None]:
        """
        Dependency for FastAPI endpoints.

        Usage:
            @app.get("/jobs")
            def get_jobs(db: Session = Depends(database.get_db)):
                return db.query(Job).all()
        """
        session = self.SessionLocal()
        try:
            yield session
        finally:
            session.close()

    def health_check(self) -> bool:
        """Check if database connection is healthy."""
        try:
            from sqlalchemy import text
            with self.get_session() as session:
                session.execute(text("SELECT 1"))
            return True
        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            return False

    def get_stats(self) -> dict:
        """Get database statistics."""
        stats = {
            "type": settings.database_type.value,
            "url": settings.database_url.split("@")[-1] if "@" in settings.database_url else settings.database_url,
            "pool_size": getattr(self.engine.pool, "size", lambda: "N/A")(),
            "pool_checked_in": getattr(self.engine.pool, "checkedin", lambda: 0)(),
            "pool_checked_out": getattr(self.engine.pool, "checkedout", lambda: 0)(),
            "pool_overflow": getattr(self.engine.pool, "overflow", lambda: 0)(),
        }
        return stats


# Global database instance
database = Database()
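A sketch of how the Database manager above could be wired into a FastAPI app, since get_db() is written as a FastAPI dependency (per its docstring); the app and route names here are illustrative, not part of this commit:

from fastapi import Depends, FastAPI
from sqlalchemy.orm import Session

from backend.core.database import database
from backend.core.models import Job

app = FastAPI()


@app.get("/health")
def health():
    return {"database": "ok" if database.health_check() else "down"}


@app.get("/jobs")
def list_jobs(db: Session = Depends(database.get_db)):
    return [job.to_dict() for job in db.query(Job).limit(20).all()]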
backend/core/models.py (Normal file, 203 lines)
@@ -0,0 +1,203 @@
"""Database models for TranscriptorIO."""
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional

from sqlalchemy import (
    Column, String, Integer, Float, DateTime, Text, Boolean, Enum as SQLEnum, Index
)
from sqlalchemy.sql import func

from backend.core.database import Base


class JobStatus(str, Enum):
    """Job status states."""
    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class JobStage(str, Enum):
    """Job processing stages."""
    PENDING = "pending"
    DETECTING_LANGUAGE = "detecting_language"
    EXTRACTING_AUDIO = "extracting_audio"
    TRANSCRIBING = "transcribing"
    TRANSLATING = "translating"
    GENERATING_SUBTITLES = "generating_subtitles"
    POST_PROCESSING = "post_processing"
    FINALIZING = "finalizing"


class QualityPreset(str, Enum):
    """Quality presets for transcription."""
    FAST = "fast"  # ja→en→es with Helsinki-NLP (4GB VRAM)
    BALANCED = "balanced"  # ja→ja→es with M2M100 (6GB VRAM)
    BEST = "best"  # ja→es direct with SeamlessM4T (10GB+ VRAM)


class Job(Base):
    """Job model representing a transcription task."""

    __tablename__ = "jobs"

    # Primary identification
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    file_path = Column(String(1024), nullable=False, index=True)
    file_name = Column(String(512), nullable=False)

    # Job status
    status = Column(
        SQLEnum(JobStatus),
        nullable=False,
        default=JobStatus.QUEUED,
        index=True
    )
    priority = Column(Integer, nullable=False, default=0, index=True)

    # Configuration
    source_lang = Column(String(10), nullable=True)
    target_lang = Column(String(10), nullable=True)
    quality_preset = Column(
        SQLEnum(QualityPreset),
        nullable=False,
        default=QualityPreset.FAST
    )
    transcribe_or_translate = Column(String(20), nullable=False, default="transcribe")

    # Progress tracking
    progress = Column(Float, nullable=False, default=0.0)  # 0-100
    current_stage = Column(
        SQLEnum(JobStage),
        nullable=False,
        default=JobStage.PENDING
    )
    eta_seconds = Column(Integer, nullable=True)

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)

    # Results
    output_path = Column(String(1024), nullable=True)
    srt_content = Column(Text, nullable=True)
    segments_count = Column(Integer, nullable=True)

    # Error handling
    error = Column(Text, nullable=True)
    retry_count = Column(Integer, nullable=False, default=0)
    max_retries = Column(Integer, nullable=False, default=3)

    # Worker information
    worker_id = Column(String(64), nullable=True)
    vram_used_mb = Column(Integer, nullable=True)
    processing_time_seconds = Column(Float, nullable=True)

    # Provider mode specific
    bazarr_callback_url = Column(String(512), nullable=True)
    is_manual_request = Column(Boolean, nullable=False, default=False)

    # Additional metadata
    model_used = Column(String(64), nullable=True)
    device_used = Column(String(32), nullable=True)
    compute_type = Column(String(32), nullable=True)

    def __repr__(self):
        """String representation of Job."""
        return f"<Job {self.id[:8]}... {self.file_name} [{self.status.value}] {self.progress:.1f}%>"

    @property
    def duration_seconds(self) -> Optional[float]:
        """Calculate job duration in seconds."""
        if self.started_at and self.completed_at:
            delta = self.completed_at - self.started_at
            return delta.total_seconds()
        return None

    @property
    def is_terminal_state(self) -> bool:
        """Check if job is in a terminal state (completed/failed/cancelled)."""
        return self.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)

    @property
    def can_retry(self) -> bool:
        """Check if job can be retried."""
        return self.status == JobStatus.FAILED and self.retry_count < self.max_retries

    def to_dict(self) -> dict:
        """Convert job to dictionary for API responses."""
        return {
            "id": self.id,
            "file_path": self.file_path,
            "file_name": self.file_name,
            "status": self.status.value,
            "priority": self.priority,
            "source_lang": self.source_lang,
            "target_lang": self.target_lang,
            "quality_preset": self.quality_preset.value if self.quality_preset else None,
            "transcribe_or_translate": self.transcribe_or_translate,
            "progress": self.progress,
            "current_stage": self.current_stage.value if self.current_stage else None,
            "eta_seconds": self.eta_seconds,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "output_path": self.output_path,
            "segments_count": self.segments_count,
            "error": self.error,
            "retry_count": self.retry_count,
            "worker_id": self.worker_id,
            "vram_used_mb": self.vram_used_mb,
            "processing_time_seconds": self.processing_time_seconds,
            "model_used": self.model_used,
            "device_used": self.device_used,
        }

    def update_progress(self, progress: float, stage: JobStage, eta_seconds: Optional[int] = None):
        """Update job progress."""
        self.progress = min(100.0, max(0.0, progress))
        self.current_stage = stage
        if eta_seconds is not None:
            self.eta_seconds = eta_seconds

    def mark_started(self, worker_id: str):
        """Mark job as started."""
        self.status = JobStatus.PROCESSING
        self.started_at = datetime.utcnow()
        self.worker_id = worker_id

    def mark_completed(self, output_path: str, segments_count: int, srt_content: Optional[str] = None):
        """Mark job as completed."""
        self.status = JobStatus.COMPLETED
        self.completed_at = datetime.utcnow()
        self.output_path = output_path
        self.segments_count = segments_count
        self.srt_content = srt_content
        self.progress = 100.0
        self.current_stage = JobStage.FINALIZING

        if self.started_at:
            self.processing_time_seconds = (self.completed_at - self.started_at).total_seconds()

    def mark_failed(self, error: str):
        """Mark job as failed."""
        self.status = JobStatus.FAILED
        self.completed_at = datetime.utcnow()
        self.error = error
        self.retry_count += 1

    def mark_cancelled(self):
        """Mark job as cancelled."""
        self.status = JobStatus.CANCELLED
        self.completed_at = datetime.utcnow()


# Create indexes for common queries
Index('idx_jobs_status_priority', Job.status, Job.priority.desc(), Job.created_at)
Index('idx_jobs_created', Job.created_at.desc())
Index('idx_jobs_file_path', Job.file_path)
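An illustrative walk-through of the Job state helpers above, operating on a plain instance with no session involved (note that importing backend.core.models pulls in backend.core.database, which instantiates the global Database on import); the file names are invented:

from backend.core.models import Job, JobStage, JobStatus

job = Job(file_path="/media/anime/ep01.mkv", file_name="ep01.mkv")
job.mark_started(worker_id="worker-cpu-0")
job.update_progress(42.0, JobStage.TRANSCRIBING, eta_seconds=120)
job.mark_completed(output_path="/media/anime/ep01.srt", segments_count=250)
assert job.status is JobStatus.COMPLETED and job.progress == 100.0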
backend/core/queue_manager.py (Normal file, 394 lines)
@@ -0,0 +1,394 @@
"""Queue manager for persistent job queuing."""
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Dict
from sqlalchemy import and_, or_
from sqlalchemy.orm import Session

from backend.core.database import database
from backend.core.models import Job, JobStatus, JobStage, QualityPreset

logger = logging.getLogger(__name__)


class QueueManager:
    """
    Persistent queue manager for transcription jobs.

    Replaces the old DeduplicatedQueue with a database-backed solution that:
    - Persists jobs across restarts
    - Supports priority queuing
    - Prevents duplicate jobs
    - Provides visibility into queue state
    - Thread-safe operations
    """

    def __init__(self):
        """Initialize queue manager."""
        self.db = database
        logger.info("QueueManager initialized")

    def add_job(
        self,
        file_path: str,
        file_name: str,
        source_lang: Optional[str] = None,
        target_lang: Optional[str] = None,
        quality_preset: QualityPreset = QualityPreset.FAST,
        transcribe_or_translate: str = "transcribe",
        priority: int = 0,
        bazarr_callback_url: Optional[str] = None,
        is_manual_request: bool = False,
    ) -> Optional[Job]:
        """
        Add a new job to the queue.

        Args:
            file_path: Full path to the media file
            file_name: Name of the file
            source_lang: Source language code (ISO 639-2)
            target_lang: Target language code (ISO 639-2)
            quality_preset: Quality preset (fast/balanced/best)
            transcribe_or_translate: Operation type
            priority: Job priority (higher = processed first)
            bazarr_callback_url: Callback URL for Bazarr provider mode
            is_manual_request: Whether this is a manual request (higher priority)

        Returns:
            Job object if created, None if duplicate exists
        """
        with self.db.get_session() as session:
            # Check for existing job
            existing = self._find_existing_job(session, file_path, target_lang)

            if existing:
                logger.info(f"Job already exists for {file_name}: {existing.id} [{existing.status.value}]")

                # If existing job failed and can retry, reset it
                if existing.can_retry:
                    logger.info(f"Resetting failed job {existing.id} for retry")
                    existing.status = JobStatus.QUEUED
                    existing.error = None
                    existing.current_stage = JobStage.PENDING
                    existing.progress = 0.0
                    session.commit()
                    return existing

                return None

            # Create new job
            job = Job(
                file_path=file_path,
                file_name=file_name,
                source_lang=source_lang,
                target_lang=target_lang,
                quality_preset=quality_preset,
                transcribe_or_translate=transcribe_or_translate,
                priority=priority + (10 if is_manual_request else 0),  # Boost manual requests
                bazarr_callback_url=bazarr_callback_url,
                is_manual_request=is_manual_request,
            )

            session.add(job)
            session.commit()

            # Access all attributes before session closes to ensure they're loaded
            job_id = job.id
            job_status = job.status

            logger.info(
                f"Job {job_id} added to queue: {file_name} "
                f"[{quality_preset.value}] priority={job.priority}"
            )

        # Re-query the job in a new session to return a fresh copy
        with self.db.get_session() as session:
            job = session.query(Job).filter(Job.id == job_id).first()
            if job:
                session.expunge(job)  # Remove from session so it doesn't expire
            return job

    def get_next_job(self, worker_id: str) -> Optional[Job]:
        """
        Get the next job from the queue for processing.

        Jobs are selected based on:
        1. Status = QUEUED
        2. Priority (DESC)
        3. Created time (ASC) - FIFO within same priority

        Args:
            worker_id: ID of the worker requesting the job

        Returns:
            Job object or None if queue is empty
        """
        with self.db.get_session() as session:
            job = (
                session.query(Job)
                .filter(Job.status == JobStatus.QUEUED)
                .order_by(
                    Job.priority.desc(),
                    Job.created_at.asc()
                )
                .with_for_update(skip_locked=True)  # Skip locked rows (concurrent workers)
                .first()
            )

            if job:
                job_id = job.id
                job.mark_started(worker_id)
                session.commit()
                logger.info(f"Job {job_id} assigned to worker {worker_id}")

        # Re-query the job if found
        if job:
            with self.db.get_session() as session:
                job = session.query(Job).filter(Job.id == job_id).first()
                if job:
                    session.expunge(job)  # Remove from session so it doesn't expire
                return job

        return None

    def get_job_by_id(self, job_id: str) -> Optional[Job]:
        """Get a specific job by ID."""
        with self.db.get_session() as session:
            return session.query(Job).filter(Job.id == job_id).first()

    def update_job_progress(
        self,
        job_id: str,
        progress: float,
        stage: JobStage,
        eta_seconds: Optional[int] = None
    ) -> bool:
        """
        Update job progress.

        Args:
            job_id: Job ID
            progress: Progress percentage (0-100)
            stage: Current processing stage
            eta_seconds: Estimated time to completion

        Returns:
            True if updated successfully, False otherwise
        """
        with self.db.get_session() as session:
            job = session.query(Job).filter(Job.id == job_id).first()

            if not job:
                logger.warning(f"Job {job_id} not found for progress update")
                return False

            job.update_progress(progress, stage, eta_seconds)
            session.commit()

            logger.debug(
                f"Job {job_id} progress: {progress:.1f}% [{stage.value}] ETA: {eta_seconds}s"
            )
            return True

    def mark_job_completed(
        self,
        job_id: str,
        output_path: str,
        segments_count: int,
        srt_content: Optional[str] = None
    ) -> bool:
        """Mark a job as completed."""
        with self.db.get_session() as session:
            job = session.query(Job).filter(Job.id == job_id).first()

            if not job:
                logger.warning(f"Job {job_id} not found for completion")
                return False

            job.mark_completed(output_path, segments_count, srt_content)
            session.commit()

            logger.info(
                f"Job {job_id} completed: {output_path} "
                f"({segments_count} segments, {job.processing_time_seconds:.1f}s)"
            )
            return True

    def mark_job_failed(self, job_id: str, error: str) -> bool:
        """Mark a job as failed."""
        with self.db.get_session() as session:
            job = session.query(Job).filter(Job.id == job_id).first()

            if not job:
                logger.warning(f"Job {job_id} not found for failure marking")
                return False

            job.mark_failed(error)
            session.commit()

            logger.error(
                f"Job {job_id} failed (attempt {job.retry_count}/{job.max_retries}): {error}"
            )
            return True

    def cancel_job(self, job_id: str) -> bool:
        """Cancel a queued or processing job."""
        with self.db.get_session() as session:
            job = session.query(Job).filter(Job.id == job_id).first()

            if not job:
                logger.warning(f"Job {job_id} not found for cancellation")
                return False

            if job.is_terminal_state:
                logger.warning(f"Job {job_id} is already in terminal state: {job.status.value}")
                return False

            job.mark_cancelled()
            session.commit()

            logger.info(f"Job {job_id} cancelled")
            return True

    def get_queue_stats(self) -> Dict:
        """Get queue statistics."""
        with self.db.get_session() as session:
            total = session.query(Job).count()
            queued = session.query(Job).filter(Job.status == JobStatus.QUEUED).count()
            processing = session.query(Job).filter(Job.status == JobStatus.PROCESSING).count()
            completed = session.query(Job).filter(Job.status == JobStatus.COMPLETED).count()
            failed = session.query(Job).filter(Job.status == JobStatus.FAILED).count()

            # Get today's stats
            today = datetime.utcnow().date()
            completed_today = (
                session.query(Job)
                .filter(
                    Job.status == JobStatus.COMPLETED,
                    Job.completed_at >= today
                )
                .count()
            )
            failed_today = (
                session.query(Job)
                .filter(
                    Job.status == JobStatus.FAILED,
                    Job.completed_at >= today
                )
                .count()
            )

            return {
                "total": total,
                "queued": queued,
                "processing": processing,
                "completed": completed,
                "failed": failed,
                "completed_today": completed_today,
                "failed_today": failed_today,
            }

    def get_jobs(
        self,
        status: Optional[JobStatus] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Job]:
        """
        Get jobs with optional filtering.

        Args:
            status: Filter by status
            limit: Maximum number of jobs to return
            offset: Offset for pagination

        Returns:
            List of Job objects
        """
        with self.db.get_session() as session:
            query = session.query(Job)

            if status:
                query = query.filter(Job.status == status)

            jobs = (
                query
                .order_by(Job.created_at.desc())
                .limit(limit)
                .offset(offset)
                .all()
            )

            return jobs

    def get_processing_jobs(self) -> List[Job]:
        """Get all currently processing jobs."""
        return self.get_jobs(status=JobStatus.PROCESSING)

    def get_queued_jobs(self) -> List[Job]:
        """Get all queued jobs."""
        return self.get_jobs(status=JobStatus.QUEUED)

    def is_queue_empty(self) -> bool:
        """Check if queue has any pending jobs."""
        with self.db.get_session() as session:
            count = (
                session.query(Job)
                .filter(Job.status.in_([JobStatus.QUEUED, JobStatus.PROCESSING]))
                .count()
            )
            return count == 0

    def cleanup_old_jobs(self, days: int = 30) -> int:
        """
        Delete completed/failed jobs older than specified days.

        Args:
            days: Number of days to keep jobs

        Returns:
            Number of jobs deleted
        """
        with self.db.get_session() as session:
            cutoff_date = datetime.utcnow() - timedelta(days=days)

            deleted = (
                session.query(Job)
                .filter(
                    Job.status.in_([JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED]),
                    Job.completed_at < cutoff_date
                )
                .delete()
            )

            session.commit()

            if deleted > 0:
                logger.info(f"Cleaned up {deleted} old jobs (older than {days} days)")

            return deleted

    def _find_existing_job(
        self,
        session: Session,
        file_path: str,
        target_lang: Optional[str]
    ) -> Optional[Job]:
        """
        Find existing job for the same file and target language.

        Ignores completed jobs - allows re-transcription.
        """
        query = session.query(Job).filter(
            Job.file_path == file_path,
            Job.status.in_([JobStatus.QUEUED, JobStatus.PROCESSING])
        )

        if target_lang:
            query = query.filter(Job.target_lang == target_lang)

        return query.first()


# Global queue manager instance
queue_manager = QueueManager()
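A hypothetical producer/consumer round trip through the QueueManager above; the paths, languages, and worker IDs are placeholders:

from backend.core.models import QualityPreset
from backend.core.queue_manager import queue_manager

job = queue_manager.add_job(
    file_path="/media/anime/show/ep01.mkv",
    file_name="ep01.mkv",
    target_lang="spa",
    quality_preset=QualityPreset.BALANCED,
    is_manual_request=True,  # gets the +10 priority boost
)

next_job = queue_manager.get_next_job(worker_id="worker-0")
if next_job:
    queue_manager.mark_job_completed(
        next_job.id,
        output_path="/media/anime/show/ep01.srt",
        segments_count=180,
    )

print(queue_manager.get_queue_stats())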
backend/core/worker.py (Normal file, 285 lines)
@@ -0,0 +1,285 @@
"""Individual worker for processing transcription jobs."""
import logging
import multiprocessing as mp
import time
import traceback
from datetime import datetime
from enum import Enum
from typing import Optional

from backend.core.database import Database
from backend.core.models import Job, JobStatus, JobStage
from backend.core.queue_manager import QueueManager

logger = logging.getLogger(__name__)


class WorkerType(str, Enum):
    """Worker device type."""
    CPU = "cpu"
    GPU = "gpu"


class WorkerStatus(str, Enum):
    """Worker status states."""
    IDLE = "idle"
    BUSY = "busy"
    STOPPING = "stopping"
    STOPPED = "stopped"
    ERROR = "error"


class Worker:
    """
    Individual worker process for transcription.

    Each worker runs in its own process and can handle one job at a time.
    Workers communicate with the main process via multiprocessing primitives.
    """

    def __init__(
        self,
        worker_id: str,
        worker_type: WorkerType,
        device_id: Optional[int] = None
    ):
        """
        Initialize worker.

        Args:
            worker_id: Unique identifier for this worker
            worker_type: CPU or GPU
            device_id: GPU device ID (only for GPU workers)
        """
        self.worker_id = worker_id
        self.worker_type = worker_type
        self.device_id = device_id

        # Multiprocessing primitives
        self.process: Optional[mp.Process] = None
        self.stop_event = mp.Event()
        # WorkerStatus members are string-valued, so the shared Value stores the
        # member's index; mp.Value('i', ...) can only hold integers.
        self.status = mp.Value('i', list(WorkerStatus).index(WorkerStatus.IDLE))
        self.current_job_id = mp.Array('c', 36)  # type: ignore # UUID string

        # Stats
        self.jobs_completed = mp.Value('i', 0)  # type: ignore
        self.jobs_failed = mp.Value('i', 0)  # type: ignore
        self.started_at: Optional[datetime] = None

    def start(self):
        """Start the worker process."""
        if self.process and self.process.is_alive():
            logger.warning(f"Worker {self.worker_id} is already running")
            return

        self.stop_event.clear()
        self.process = mp.Process(
            target=self._worker_loop,
            name=f"Worker-{self.worker_id}",
            daemon=True
        )
        self.process.start()
        self.started_at = datetime.utcnow()
        logger.info(
            f"Worker {self.worker_id} started (PID: {self.process.pid}, "
            f"Type: {self.worker_type.value})"
        )

    def stop(self, timeout: float = 30.0):
        """
        Stop the worker process gracefully.

        Args:
            timeout: Maximum time to wait for worker to stop
        """
        if not self.process or not self.process.is_alive():
            logger.warning(f"Worker {self.worker_id} is not running")
            return

        logger.info(f"Stopping worker {self.worker_id}...")
        self.stop_event.set()
        self.process.join(timeout=timeout)

        if self.process.is_alive():
            logger.warning(f"Worker {self.worker_id} did not stop gracefully, terminating...")
            self.process.terminate()
            self.process.join(timeout=5.0)

        if self.process.is_alive():
            logger.error(f"Worker {self.worker_id} did not terminate, killing...")
            self.process.kill()

        logger.info(f"Worker {self.worker_id} stopped")

    def is_alive(self) -> bool:
        """Check if worker process is alive."""
        return self.process is not None and self.process.is_alive()

    def get_status(self) -> dict:
        """Get worker status information."""
        # The shared Value holds the index of the WorkerStatus member (see __init__).
        status_enum = list(WorkerStatus)[self.status.value]

        current_job = self.current_job_id.value.decode('utf-8').strip('\x00')

        return {
            "worker_id": self.worker_id,
            "type": self.worker_type.value,
            "device_id": self.device_id,
            "status": status_enum.value,
            "current_job_id": current_job if current_job else None,
            "jobs_completed": self.jobs_completed.value,
            "jobs_failed": self.jobs_failed.value,
            "is_alive": self.is_alive(),
            "pid": self.process.pid if self.process else None,
            "started_at": self.started_at.isoformat() if self.started_at else None,
        }

    def _worker_loop(self):
        """
        Main worker loop (runs in separate process).

        This is the entry point for the worker process.
        """
        # Set up logging in the worker process
        logging.basicConfig(
            level=logging.INFO,
            format=f'[Worker-{self.worker_id}] %(levelname)s: %(message)s'
        )

        logger.info(f"Worker {self.worker_id} loop started")

        # Initialize database and queue manager in worker process
        # Each process needs its own DB connection
        try:
            db = Database(auto_create_tables=False)
            queue_mgr = QueueManager()
        except Exception as e:
            logger.error(f"Failed to initialize worker: {e}")
            self._set_status(WorkerStatus.ERROR)
            return

        # Main work loop
        while not self.stop_event.is_set():
            try:
                # Try to get next job from queue
                job = queue_mgr.get_next_job(self.worker_id)

                if job is None:
                    # No jobs available, idle for a bit
                    self._set_status(WorkerStatus.IDLE)
                    time.sleep(2)
                    continue

                # Process the job
                self._set_status(WorkerStatus.BUSY)
                self._set_current_job(job.id)

                logger.info(f"Processing job {job.id}: {job.file_name}")

                try:
                    self._process_job(job, queue_mgr)
                    self.jobs_completed.value += 1
                    logger.info(f"Job {job.id} completed successfully")

                except Exception as e:
                    self.jobs_failed.value += 1
                    error_msg = f"Job processing failed: {str(e)}\n{traceback.format_exc()}"
                    logger.error(error_msg)

                    queue_mgr.mark_job_failed(job.id, error_msg)

                finally:
                    self._clear_current_job()

            except Exception as e:
                logger.error(f"Worker loop error: {e}\n{traceback.format_exc()}")
                time.sleep(5)  # Back off on errors

        self._set_status(WorkerStatus.STOPPED)
        logger.info(f"Worker {self.worker_id} loop ended")

    def _process_job(self, job: Job, queue_mgr: QueueManager):
        """
        Process a single transcription job.

        Args:
            job: Job to process
            queue_mgr: Queue manager for updating progress
        """
        # TODO: This will be implemented when we add the transcriber module
        # For now, simulate work

        # Stage 1: Detect language
        queue_mgr.update_job_progress(
            job.id,
            progress=10.0,
            stage=JobStage.DETECTING_LANGUAGE,
            eta_seconds=60
        )
        time.sleep(2)  # Simulate work

        # Stage 2: Extract audio
        queue_mgr.update_job_progress(
            job.id,
            progress=20.0,
            stage=JobStage.EXTRACTING_AUDIO,
            eta_seconds=50
        )
        time.sleep(2)

        # Stage 3: Transcribe
        queue_mgr.update_job_progress(
            job.id,
            progress=30.0,
            stage=JobStage.TRANSCRIBING,
            eta_seconds=40
        )

        # Simulate progressive transcription
        for i in range(30, 90, 10):
            time.sleep(1)
            queue_mgr.update_job_progress(
                job.id,
                progress=float(i),
                stage=JobStage.TRANSCRIBING,
                eta_seconds=int((100 - i) / 2)
            )

        # Stage 4: Finalize
        queue_mgr.update_job_progress(
            job.id,
            progress=95.0,
            stage=JobStage.FINALIZING,
            eta_seconds=5
        )
        time.sleep(1)

        # Mark as completed
        output_path = job.file_path.replace('.mkv', '.srt')
        queue_mgr.mark_job_completed(
            job.id,
            output_path=output_path,
            segments_count=100,  # Simulated
            srt_content="Simulated SRT content"
        )

    def _set_status(self, status: WorkerStatus):
        """Set worker status (thread-safe)."""
        # Store the index of the WorkerStatus member (see __init__).
        self.status.value = list(WorkerStatus).index(status)

    def _set_current_job(self, job_id: str):
        """Set current job ID (thread-safe)."""
        job_id_bytes = job_id.encode('utf-8')
        for i, byte in enumerate(job_id_bytes):
            if i < len(self.current_job_id):
                # ctypes char arrays expect one-byte bytes objects, not ints
                self.current_job_id[i] = bytes([byte])

    def _clear_current_job(self):
        """Clear current job ID (thread-safe)."""
        for i in range(len(self.current_job_id)):
            self.current_job_id[i] = b'\x00'
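A sketch of how Worker instances might be launched and supervised; a dedicated worker manager is not part of this commit, so the loop below is only illustrative:

from backend.config import settings
from backend.core.worker import Worker, WorkerType

workers = [
    Worker(worker_id=f"worker-{i}", worker_type=WorkerType.CPU)
    for i in range(settings.concurrent_transcriptions)
]

for w in workers:
    w.start()

# ... later, on shutdown
for w in workers:
    w.stop(timeout=30.0)
    print(w.get_status())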