feat: add centralized configuration system with Pydantic
- Add backend/config.py with Pydantic settings validation - Support for standalone, provider, and hybrid operation modes - Multi-database backend configuration (SQLite/PostgreSQL/MariaDB) - Environment variable validation with helpful error messages - Worker and Whisper model configuration
This commit is contained in:
1
backend/__init__.py
Normal file
1
backend/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""TranscriptorIO Backend Package."""
|
||||
1
backend/api/__init__.py
Normal file
1
backend/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""TranscriptorIO API Module."""
|
||||
214
backend/config.py
Normal file
214
backend/config.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Configuration management for TranscriptorIO."""
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Optional, List
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
|
||||
class OperationMode(str, Enum):
|
||||
"""Operation modes for TranscriptorIO."""
|
||||
STANDALONE = "standalone"
|
||||
PROVIDER = "provider"
|
||||
HYBRID = "standalone,provider"
|
||||
|
||||
|
||||
class DatabaseType(str, Enum):
|
||||
"""Supported database backends."""
|
||||
SQLITE = "sqlite"
|
||||
POSTGRESQL = "postgresql"
|
||||
MARIADB = "mariadb"
|
||||
MYSQL = "mysql"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
# === Application Mode ===
|
||||
transcriptarr_mode: str = Field(
|
||||
default="standalone",
|
||||
description="Operation mode: standalone, provider, or standalone,provider"
|
||||
)
|
||||
|
||||
# === Database Configuration ===
|
||||
database_url: str = Field(
|
||||
default="sqlite:///./transcriptarr.db",
|
||||
description="Database connection URL. Examples:\n"
|
||||
" SQLite: sqlite:///./transcriptarr.db\n"
|
||||
" PostgreSQL: postgresql://user:pass@localhost/transcriptarr\n"
|
||||
" MariaDB: mariadb+pymysql://user:pass@localhost/transcriptarr"
|
||||
)
|
||||
|
||||
# === Worker Configuration ===
|
||||
concurrent_transcriptions: int = Field(default=2, ge=1, le=10)
|
||||
whisper_threads: int = Field(default=4, ge=1, le=32)
|
||||
transcribe_device: str = Field(default="cpu", pattern="^(cpu|gpu|cuda)$")
|
||||
clear_vram_on_complete: bool = Field(default=True)
|
||||
|
||||
# === Whisper Model Configuration ===
|
||||
whisper_model: str = Field(
|
||||
default="medium",
|
||||
description="Whisper model: tiny, base, small, medium, large-v3, etc."
|
||||
)
|
||||
model_path: str = Field(default="./models")
|
||||
compute_type: str = Field(default="auto")
|
||||
|
||||
# === Standalone Mode Configuration ===
|
||||
library_paths: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Pipe-separated paths to scan: /media/anime|/media/movies"
|
||||
)
|
||||
auto_scan_enabled: bool = Field(default=False)
|
||||
scan_interval_minutes: int = Field(default=30, ge=1)
|
||||
|
||||
required_audio_language: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Only process files with this audio language (ISO 639-2)"
|
||||
)
|
||||
required_missing_subtitle: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Only process if this subtitle language is missing (ISO 639-2)"
|
||||
)
|
||||
skip_if_subtitle_exists: bool = Field(default=True)
|
||||
|
||||
# === Provider Mode Configuration ===
|
||||
bazarr_url: Optional[str] = Field(default=None)
|
||||
bazarr_api_key: Optional[str] = Field(default=None)
|
||||
provider_timeout_seconds: int = Field(default=600, ge=60)
|
||||
provider_callback_enabled: bool = Field(default=True)
|
||||
provider_polling_interval: int = Field(default=30, ge=10)
|
||||
|
||||
# === API Configuration ===
|
||||
webhook_port: int = Field(default=9000, ge=1024, le=65535)
|
||||
api_host: str = Field(default="0.0.0.0")
|
||||
debug: bool = Field(default=True)
|
||||
|
||||
# === Transcription Settings ===
|
||||
transcribe_or_translate: str = Field(
|
||||
default="transcribe",
|
||||
pattern="^(transcribe|translate)$"
|
||||
)
|
||||
subtitle_language_name: str = Field(default="")
|
||||
subtitle_language_naming_type: str = Field(
|
||||
default="ISO_639_2_B",
|
||||
description="Naming format: ISO_639_1, ISO_639_2_T, ISO_639_2_B, NAME, NATIVE"
|
||||
)
|
||||
word_level_highlight: bool = Field(default=False)
|
||||
custom_regroup: str = Field(default="cm_sl=84_sl=42++++++1")
|
||||
|
||||
# === Skip Configuration ===
|
||||
skip_if_external_subtitles_exist: bool = Field(default=False)
|
||||
skip_if_target_subtitles_exist: bool = Field(default=True)
|
||||
skip_if_internal_subtitles_language: Optional[str] = Field(default="eng")
|
||||
skip_subtitle_languages: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Pipe-separated language codes to skip: eng|spa"
|
||||
)
|
||||
skip_if_audio_languages: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Skip if audio track is in these languages: eng|spa"
|
||||
)
|
||||
skip_unknown_language: bool = Field(default=False)
|
||||
skip_only_subgen_subtitles: bool = Field(default=False)
|
||||
|
||||
# === Advanced Settings ===
|
||||
force_detected_language_to: Optional[str] = Field(default=None)
|
||||
detect_language_length: int = Field(default=30, ge=5)
|
||||
detect_language_offset: int = Field(default=0, ge=0)
|
||||
should_whisper_detect_audio_language: bool = Field(default=False)
|
||||
|
||||
preferred_audio_languages: str = Field(
|
||||
default="eng",
|
||||
description="Pipe-separated list in order of preference: eng|jpn"
|
||||
)
|
||||
|
||||
# === Path Mapping ===
|
||||
use_path_mapping: bool = Field(default=False)
|
||||
path_mapping_from: str = Field(default="/tv")
|
||||
path_mapping_to: str = Field(default="/Volumes/TV")
|
||||
|
||||
# === Legacy SubGen Compatibility ===
|
||||
show_in_subname_subgen: bool = Field(default=True)
|
||||
show_in_subname_model: bool = Field(default=True)
|
||||
append: bool = Field(default=False)
|
||||
lrc_for_audio_files: bool = Field(default=True)
|
||||
|
||||
@field_validator("transcriptarr_mode")
|
||||
@classmethod
|
||||
def validate_mode(cls, v: str) -> str:
|
||||
"""Validate operation mode."""
|
||||
valid_modes = {"standalone", "provider", "standalone,provider"}
|
||||
if v not in valid_modes:
|
||||
raise ValueError(f"Invalid mode: {v}. Must be one of: {valid_modes}")
|
||||
return v
|
||||
|
||||
@field_validator("database_url")
|
||||
@classmethod
|
||||
def validate_database_url(cls, v: str) -> str:
|
||||
"""Validate database URL format."""
|
||||
valid_prefixes = ("sqlite://", "postgresql://", "mariadb+pymysql://", "mysql+pymysql://")
|
||||
if not any(v.startswith(prefix) for prefix in valid_prefixes):
|
||||
raise ValueError(
|
||||
f"Invalid database URL. Must start with one of: {valid_prefixes}"
|
||||
)
|
||||
return v
|
||||
|
||||
@property
|
||||
def database_type(self) -> DatabaseType:
|
||||
"""Get the database type from the URL."""
|
||||
if self.database_url.startswith("sqlite"):
|
||||
return DatabaseType.SQLITE
|
||||
elif self.database_url.startswith("postgresql"):
|
||||
return DatabaseType.POSTGRESQL
|
||||
elif "mariadb" in self.database_url:
|
||||
return DatabaseType.MARIADB
|
||||
elif "mysql" in self.database_url:
|
||||
return DatabaseType.MYSQL
|
||||
else:
|
||||
raise ValueError(f"Unknown database type in URL: {self.database_url}")
|
||||
|
||||
@property
|
||||
def is_standalone_mode(self) -> bool:
|
||||
"""Check if standalone mode is enabled."""
|
||||
return "standalone" in self.transcriptarr_mode
|
||||
|
||||
@property
|
||||
def is_provider_mode(self) -> bool:
|
||||
"""Check if provider mode is enabled."""
|
||||
return "provider" in self.transcriptarr_mode
|
||||
|
||||
@property
|
||||
def library_paths_list(self) -> List[str]:
|
||||
"""Get library paths as a list."""
|
||||
if not self.library_paths:
|
||||
return []
|
||||
return [p.strip() for p in self.library_paths.split("|") if p.strip()]
|
||||
|
||||
@property
|
||||
def skip_subtitle_languages_list(self) -> List[str]:
|
||||
"""Get skip subtitle languages as a list."""
|
||||
if not self.skip_subtitle_languages:
|
||||
return []
|
||||
return [lang.strip() for lang in self.skip_subtitle_languages.split("|") if lang.strip()]
|
||||
|
||||
@property
|
||||
def skip_audio_languages_list(self) -> List[str]:
|
||||
"""Get skip audio languages as a list."""
|
||||
if not self.skip_if_audio_languages:
|
||||
return []
|
||||
return [lang.strip() for lang in self.skip_if_audio_languages.split("|") if lang.strip()]
|
||||
|
||||
@property
|
||||
def preferred_audio_languages_list(self) -> List[str]:
|
||||
"""Get preferred audio languages as a list."""
|
||||
return [lang.strip() for lang in self.preferred_audio_languages.split("|") if lang.strip()]
|
||||
|
||||
class Config:
|
||||
"""Pydantic configuration."""
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
case_sensitive = False
|
||||
|
||||
|
||||
# Global settings instance
|
||||
settings = Settings()
|
||||
1
backend/core/__init__.py
Normal file
1
backend/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""TranscriptorIO Core Module."""
|
||||
219
backend/core/database.py
Normal file
219
backend/core/database.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Database configuration and session management."""
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
from sqlalchemy import create_engine, event, Engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.pool import StaticPool, QueuePool
|
||||
|
||||
from backend.config import settings, DatabaseType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Base class for all models
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class Database:
|
||||
"""Database manager supporting SQLite, PostgreSQL, and MariaDB."""
|
||||
|
||||
def __init__(self, auto_create_tables: bool = True):
|
||||
"""
|
||||
Initialize database engine and session maker.
|
||||
|
||||
Args:
|
||||
auto_create_tables: If True, automatically create tables if they don't exist
|
||||
"""
|
||||
self.engine = self._create_engine()
|
||||
self.SessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=self.engine
|
||||
)
|
||||
logger.info(f"Database initialized: {settings.database_type.value}")
|
||||
|
||||
# Automatically create tables if they don't exist
|
||||
if auto_create_tables:
|
||||
self._ensure_tables_exist()
|
||||
|
||||
def _create_engine(self) -> Engine:
|
||||
"""Create SQLAlchemy engine based on database type."""
|
||||
connect_args = {}
|
||||
poolclass = QueuePool
|
||||
|
||||
if settings.database_type == DatabaseType.SQLITE:
|
||||
# SQLite-specific configuration
|
||||
connect_args = {
|
||||
"check_same_thread": False, # Allow multi-threaded access
|
||||
"timeout": 30.0, # Wait up to 30s for lock
|
||||
}
|
||||
# Use StaticPool for SQLite to avoid connection issues
|
||||
poolclass = StaticPool
|
||||
|
||||
# Enable WAL mode for better concurrency
|
||||
engine = create_engine(
|
||||
settings.database_url,
|
||||
connect_args=connect_args,
|
||||
poolclass=poolclass,
|
||||
echo=settings.debug,
|
||||
)
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_conn, connection_record):
|
||||
"""Enable SQLite optimizations."""
|
||||
cursor = dbapi_conn.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.execute("PRAGMA cache_size=-64000") # 64MB cache
|
||||
cursor.close()
|
||||
|
||||
elif settings.database_type == DatabaseType.POSTGRESQL:
|
||||
# PostgreSQL-specific configuration
|
||||
try:
|
||||
import psycopg2 # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"PostgreSQL support requires psycopg2-binary.\n"
|
||||
"Install it with: pip install psycopg2-binary"
|
||||
)
|
||||
|
||||
engine = create_engine(
|
||||
settings.database_url,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_pre_ping=True, # Verify connections before using
|
||||
echo=settings.debug,
|
||||
)
|
||||
|
||||
elif settings.database_type in (DatabaseType.MARIADB, DatabaseType.MYSQL):
|
||||
# MariaDB/MySQL-specific configuration
|
||||
try:
|
||||
import pymysql # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"MariaDB/MySQL support requires pymysql.\n"
|
||||
"Install it with: pip install pymysql"
|
||||
)
|
||||
|
||||
connect_args = {
|
||||
"charset": "utf8mb4",
|
||||
}
|
||||
engine = create_engine(
|
||||
settings.database_url,
|
||||
connect_args=connect_args,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_pre_ping=True,
|
||||
echo=settings.debug,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {settings.database_type}")
|
||||
|
||||
return engine
|
||||
|
||||
def _ensure_tables_exist(self):
|
||||
"""Check if tables exist and create them if they don't."""
|
||||
# Import models to register them with Base.metadata
|
||||
from backend.core import models # noqa: F401
|
||||
from sqlalchemy import inspect
|
||||
|
||||
inspector = inspect(self.engine)
|
||||
existing_tables = inspector.get_table_names()
|
||||
|
||||
# Check if the main 'jobs' table exists
|
||||
if 'jobs' not in existing_tables:
|
||||
logger.info("Tables don't exist, creating them automatically...")
|
||||
self.create_tables()
|
||||
else:
|
||||
logger.debug("Database tables already exist")
|
||||
|
||||
def create_tables(self):
|
||||
"""Create all database tables."""
|
||||
# Import models to register them with Base.metadata
|
||||
from backend.core import models # noqa: F401
|
||||
|
||||
logger.info("Creating database tables...")
|
||||
Base.metadata.create_all(bind=self.engine, checkfirst=True)
|
||||
|
||||
# Verify tables were actually created
|
||||
from sqlalchemy import inspect
|
||||
inspector = inspect(self.engine)
|
||||
created_tables = inspector.get_table_names()
|
||||
|
||||
if 'jobs' in created_tables:
|
||||
logger.info(f"Database tables created successfully: {created_tables}")
|
||||
else:
|
||||
logger.error(f"Failed to create tables. Existing tables: {created_tables}")
|
||||
raise RuntimeError("Failed to create database tables")
|
||||
|
||||
def drop_tables(self):
|
||||
"""Drop all database tables (use with caution!)."""
|
||||
logger.warning("Dropping all database tables...")
|
||||
Base.metadata.drop_all(bind=self.engine)
|
||||
logger.info("Database tables dropped")
|
||||
|
||||
@contextmanager
|
||||
def get_session(self) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Get a database session as a context manager.
|
||||
|
||||
Usage:
|
||||
with db.get_session() as session:
|
||||
session.query(Job).all()
|
||||
"""
|
||||
session = self.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"Database session error: {e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_db(self) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Dependency for FastAPI endpoints.
|
||||
|
||||
Usage:
|
||||
@app.get("/jobs")
|
||||
def get_jobs(db: Session = Depends(database.get_db)):
|
||||
return db.query(Job).all()
|
||||
"""
|
||||
session = self.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def health_check(self) -> bool:
|
||||
"""Check if database connection is healthy."""
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
with self.get_session() as session:
|
||||
session.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {e}")
|
||||
return False
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get database statistics."""
|
||||
stats = {
|
||||
"type": settings.database_type.value,
|
||||
"url": settings.database_url.split("@")[-1] if "@" in settings.database_url else settings.database_url,
|
||||
"pool_size": getattr(self.engine.pool, "size", lambda: "N/A")(),
|
||||
"pool_checked_in": getattr(self.engine.pool, "checkedin", lambda: 0)(),
|
||||
"pool_checked_out": getattr(self.engine.pool, "checkedout", lambda: 0)(),
|
||||
"pool_overflow": getattr(self.engine.pool, "overflow", lambda: 0)(),
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
# Global database instance
|
||||
database = Database()
|
||||
203
backend/core/models.py
Normal file
203
backend/core/models.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Database models for TranscriptorIO."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, String, Integer, Float, DateTime, Text, Boolean, Enum as SQLEnum, Index
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from backend.core.database import Base
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
|
||||
"""Job status states."""
|
||||
QUEUED = "queued"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class JobStage(str, Enum):
|
||||
"""Job processing stages."""
|
||||
PENDING = "pending"
|
||||
DETECTING_LANGUAGE = "detecting_language"
|
||||
EXTRACTING_AUDIO = "extracting_audio"
|
||||
TRANSCRIBING = "transcribing"
|
||||
TRANSLATING = "translating"
|
||||
GENERATING_SUBTITLES = "generating_subtitles"
|
||||
POST_PROCESSING = "post_processing"
|
||||
FINALIZING = "finalizing"
|
||||
|
||||
|
||||
class QualityPreset(str, Enum):
|
||||
"""Quality presets for transcription."""
|
||||
FAST = "fast" # ja→en→es with Helsinki-NLP (4GB VRAM)
|
||||
BALANCED = "balanced" # ja→ja→es with M2M100 (6GB VRAM)
|
||||
BEST = "best" # ja→es direct with SeamlessM4T (10GB+ VRAM)
|
||||
|
||||
|
||||
class Job(Base):
|
||||
"""Job model representing a transcription task."""
|
||||
|
||||
__tablename__ = "jobs"
|
||||
|
||||
# Primary identification
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
file_path = Column(String(1024), nullable=False, index=True)
|
||||
file_name = Column(String(512), nullable=False)
|
||||
|
||||
# Job status
|
||||
status = Column(
|
||||
SQLEnum(JobStatus),
|
||||
nullable=False,
|
||||
default=JobStatus.QUEUED,
|
||||
index=True
|
||||
)
|
||||
priority = Column(Integer, nullable=False, default=0, index=True)
|
||||
|
||||
# Configuration
|
||||
source_lang = Column(String(10), nullable=True)
|
||||
target_lang = Column(String(10), nullable=True)
|
||||
quality_preset = Column(
|
||||
SQLEnum(QualityPreset),
|
||||
nullable=False,
|
||||
default=QualityPreset.FAST
|
||||
)
|
||||
transcribe_or_translate = Column(String(20), nullable=False, default="transcribe")
|
||||
|
||||
# Progress tracking
|
||||
progress = Column(Float, nullable=False, default=0.0) # 0-100
|
||||
current_stage = Column(
|
||||
SQLEnum(JobStage),
|
||||
nullable=False,
|
||||
default=JobStage.PENDING
|
||||
)
|
||||
eta_seconds = Column(Integer, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Results
|
||||
output_path = Column(String(1024), nullable=True)
|
||||
srt_content = Column(Text, nullable=True)
|
||||
segments_count = Column(Integer, nullable=True)
|
||||
|
||||
# Error handling
|
||||
error = Column(Text, nullable=True)
|
||||
retry_count = Column(Integer, nullable=False, default=0)
|
||||
max_retries = Column(Integer, nullable=False, default=3)
|
||||
|
||||
# Worker information
|
||||
worker_id = Column(String(64), nullable=True)
|
||||
vram_used_mb = Column(Integer, nullable=True)
|
||||
processing_time_seconds = Column(Float, nullable=True)
|
||||
|
||||
# Provider mode specific
|
||||
bazarr_callback_url = Column(String(512), nullable=True)
|
||||
is_manual_request = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Additional metadata
|
||||
model_used = Column(String(64), nullable=True)
|
||||
device_used = Column(String(32), nullable=True)
|
||||
compute_type = Column(String(32), nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Job."""
|
||||
return f"<Job {self.id[:8]}... {self.file_name} [{self.status.value}] {self.progress:.1f}%>"
|
||||
|
||||
@property
|
||||
def duration_seconds(self) -> Optional[float]:
|
||||
"""Calculate job duration in seconds."""
|
||||
if self.started_at and self.completed_at:
|
||||
delta = self.completed_at - self.started_at
|
||||
return delta.total_seconds()
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_terminal_state(self) -> bool:
|
||||
"""Check if job is in a terminal state (completed/failed/cancelled)."""
|
||||
return self.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||
|
||||
@property
|
||||
def can_retry(self) -> bool:
|
||||
"""Check if job can be retried."""
|
||||
return self.status == JobStatus.FAILED and self.retry_count < self.max_retries
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert job to dictionary for API responses."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"file_path": self.file_path,
|
||||
"file_name": self.file_name,
|
||||
"status": self.status.value,
|
||||
"priority": self.priority,
|
||||
"source_lang": self.source_lang,
|
||||
"target_lang": self.target_lang,
|
||||
"quality_preset": self.quality_preset.value if self.quality_preset else None,
|
||||
"transcribe_or_translate": self.transcribe_or_translate,
|
||||
"progress": self.progress,
|
||||
"current_stage": self.current_stage.value if self.current_stage else None,
|
||||
"eta_seconds": self.eta_seconds,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
"output_path": self.output_path,
|
||||
"segments_count": self.segments_count,
|
||||
"error": self.error,
|
||||
"retry_count": self.retry_count,
|
||||
"worker_id": self.worker_id,
|
||||
"vram_used_mb": self.vram_used_mb,
|
||||
"processing_time_seconds": self.processing_time_seconds,
|
||||
"model_used": self.model_used,
|
||||
"device_used": self.device_used,
|
||||
}
|
||||
|
||||
def update_progress(self, progress: float, stage: JobStage, eta_seconds: Optional[int] = None):
|
||||
"""Update job progress."""
|
||||
self.progress = min(100.0, max(0.0, progress))
|
||||
self.current_stage = stage
|
||||
if eta_seconds is not None:
|
||||
self.eta_seconds = eta_seconds
|
||||
|
||||
def mark_started(self, worker_id: str):
|
||||
"""Mark job as started."""
|
||||
self.status = JobStatus.PROCESSING
|
||||
self.started_at = datetime.utcnow()
|
||||
self.worker_id = worker_id
|
||||
|
||||
def mark_completed(self, output_path: str, segments_count: int, srt_content: Optional[str] = None):
|
||||
"""Mark job as completed."""
|
||||
self.status = JobStatus.COMPLETED
|
||||
self.completed_at = datetime.utcnow()
|
||||
self.output_path = output_path
|
||||
self.segments_count = segments_count
|
||||
self.srt_content = srt_content
|
||||
self.progress = 100.0
|
||||
self.current_stage = JobStage.FINALIZING
|
||||
|
||||
if self.started_at:
|
||||
self.processing_time_seconds = (self.completed_at - self.started_at).total_seconds()
|
||||
|
||||
def mark_failed(self, error: str):
|
||||
"""Mark job as failed."""
|
||||
self.status = JobStatus.FAILED
|
||||
self.completed_at = datetime.utcnow()
|
||||
self.error = error
|
||||
self.retry_count += 1
|
||||
|
||||
def mark_cancelled(self):
|
||||
"""Mark job as cancelled."""
|
||||
self.status = JobStatus.CANCELLED
|
||||
self.completed_at = datetime.utcnow()
|
||||
|
||||
|
||||
# Create indexes for common queries
|
||||
Index('idx_jobs_status_priority', Job.status, Job.priority.desc(), Job.created_at)
|
||||
Index('idx_jobs_created', Job.created_at.desc())
|
||||
Index('idx_jobs_file_path', Job.file_path)
|
||||
394
backend/core/queue_manager.py
Normal file
394
backend/core/queue_manager.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""Queue manager for persistent job queuing."""
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Dict
|
||||
from sqlalchemy import and_, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.core.database import database
|
||||
from backend.core.models import Job, JobStatus, JobStage, QualityPreset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QueueManager:
|
||||
"""
|
||||
Persistent queue manager for transcription jobs.
|
||||
|
||||
Replaces the old DeduplicatedQueue with a database-backed solution that:
|
||||
- Persists jobs across restarts
|
||||
- Supports priority queuing
|
||||
- Prevents duplicate jobs
|
||||
- Provides visibility into queue state
|
||||
- Thread-safe operations
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize queue manager."""
|
||||
self.db = database
|
||||
logger.info("QueueManager initialized")
|
||||
|
||||
def add_job(
|
||||
self,
|
||||
file_path: str,
|
||||
file_name: str,
|
||||
source_lang: Optional[str] = None,
|
||||
target_lang: Optional[str] = None,
|
||||
quality_preset: QualityPreset = QualityPreset.FAST,
|
||||
transcribe_or_translate: str = "transcribe",
|
||||
priority: int = 0,
|
||||
bazarr_callback_url: Optional[str] = None,
|
||||
is_manual_request: bool = False,
|
||||
) -> Optional[Job]:
|
||||
"""
|
||||
Add a new job to the queue.
|
||||
|
||||
Args:
|
||||
file_path: Full path to the media file
|
||||
file_name: Name of the file
|
||||
source_lang: Source language code (ISO 639-2)
|
||||
target_lang: Target language code (ISO 639-2)
|
||||
quality_preset: Quality preset (fast/balanced/best)
|
||||
transcribe_or_translate: Operation type
|
||||
priority: Job priority (higher = processed first)
|
||||
bazarr_callback_url: Callback URL for Bazarr provider mode
|
||||
is_manual_request: Whether this is a manual request (higher priority)
|
||||
|
||||
Returns:
|
||||
Job object if created, None if duplicate exists
|
||||
"""
|
||||
with self.db.get_session() as session:
|
||||
# Check for existing job
|
||||
existing = self._find_existing_job(session, file_path, target_lang)
|
||||
|
||||
if existing:
|
||||
logger.info(f"Job already exists for {file_name}: {existing.id} [{existing.status.value}]")
|
||||
|
||||
# If existing job failed and can retry, reset it
|
||||
if existing.can_retry:
|
||||
logger.info(f"Resetting failed job {existing.id} for retry")
|
||||
existing.status = JobStatus.QUEUED
|
||||
existing.error = None
|
||||
existing.current_stage = JobStage.PENDING
|
||||
existing.progress = 0.0
|
||||
session.commit()
|
||||
return existing
|
||||
|
||||
return None
|
||||
|
||||
# Create new job
|
||||
job = Job(
|
||||
file_path=file_path,
|
||||
file_name=file_name,
|
||||
source_lang=source_lang,
|
||||
target_lang=target_lang,
|
||||
quality_preset=quality_preset,
|
||||
transcribe_or_translate=transcribe_or_translate,
|
||||
priority=priority + (10 if is_manual_request else 0), # Boost manual requests
|
||||
bazarr_callback_url=bazarr_callback_url,
|
||||
is_manual_request=is_manual_request,
|
||||
)
|
||||
|
||||
session.add(job)
|
||||
session.commit()
|
||||
|
||||
# Access all attributes before session closes to ensure they're loaded
|
||||
job_id = job.id
|
||||
job_status = job.status
|
||||
|
||||
logger.info(
|
||||
f"Job {job_id} added to queue: {file_name} "
|
||||
f"[{quality_preset.value}] priority={job.priority}"
|
||||
)
|
||||
|
||||
# Re-query the job in a new session to return a fresh copy
|
||||
with self.db.get_session() as session:
|
||||
job = session.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
session.expunge(job) # Remove from session so it doesn't expire
|
||||
return job
|
||||
|
||||
def get_next_job(self, worker_id: str) -> Optional[Job]:
|
||||
"""
|
||||
Get the next job from the queue for processing.
|
||||
|
||||
Jobs are selected based on:
|
||||
1. Status = QUEUED
|
||||
2. Priority (DESC)
|
||||
3. Created time (ASC) - FIFO within same priority
|
||||
|
||||
Args:
|
||||
worker_id: ID of the worker requesting the job
|
||||
|
||||
Returns:
|
||||
Job object or None if queue is empty
|
||||
"""
|
||||
with self.db.get_session() as session:
|
||||
job = (
|
||||
session.query(Job)
|
||||
.filter(Job.status == JobStatus.QUEUED)
|
||||
.order_by(
|
||||
Job.priority.desc(),
|
||||
Job.created_at.asc()
|
||||
)
|
||||
.with_for_update(skip_locked=True) # Skip locked rows (concurrent workers)
|
||||
.first()
|
||||
)
|
||||
|
||||
if job:
|
||||
job_id = job.id
|
||||
job.mark_started(worker_id)
|
||||
session.commit()
|
||||
logger.info(f"Job {job_id} assigned to worker {worker_id}")
|
||||
|
||||
# Re-query the job if found
|
||||
if job:
|
||||
with self.db.get_session() as session:
|
||||
job = session.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
session.expunge(job) # Remove from session so it doesn't expire
|
||||
return job
|
||||
|
||||
return None
|
||||
|
||||
def get_job_by_id(self, job_id: str) -> Optional[Job]:
|
||||
"""Get a specific job by ID."""
|
||||
with self.db.get_session() as session:
|
||||
return session.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
def update_job_progress(
|
||||
self,
|
||||
job_id: str,
|
||||
progress: float,
|
||||
stage: JobStage,
|
||||
eta_seconds: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Update job progress.
|
||||
|
||||
Args:
|
||||
job_id: Job ID
|
||||
progress: Progress percentage (0-100)
|
||||
stage: Current processing stage
|
||||
eta_seconds: Estimated time to completion
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise
|
||||
"""
|
||||
with self.db.get_session() as session:
|
||||
job = session.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
if not job:
|
||||
logger.warning(f"Job {job_id} not found for progress update")
|
||||
return False
|
||||
|
||||
job.update_progress(progress, stage, eta_seconds)
|
||||
session.commit()
|
||||
|
||||
logger.debug(
|
||||
f"Job {job_id} progress: {progress:.1f}% [{stage.value}] ETA: {eta_seconds}s"
|
||||
)
|
||||
return True
|
||||
|
||||
def mark_job_completed(
|
||||
self,
|
||||
job_id: str,
|
||||
output_path: str,
|
||||
segments_count: int,
|
||||
srt_content: Optional[str] = None
|
||||
) -> bool:
|
||||
"""Mark a job as completed."""
|
||||
with self.db.get_session() as session:
|
||||
job = session.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
if not job:
|
||||
logger.warning(f"Job {job_id} not found for completion")
|
||||
return False
|
||||
|
||||
job.mark_completed(output_path, segments_count, srt_content)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Job {job_id} completed: {output_path} "
|
||||
f"({segments_count} segments, {job.processing_time_seconds:.1f}s)"
|
||||
)
|
||||
return True
|
||||
|
||||
def mark_job_failed(self, job_id: str, error: str) -> bool:
|
||||
"""Mark a job as failed."""
|
||||
with self.db.get_session() as session:
|
||||
job = session.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
if not job:
|
||||
logger.warning(f"Job {job_id} not found for failure marking")
|
||||
return False
|
||||
|
||||
job.mark_failed(error)
|
||||
session.commit()
|
||||
|
||||
logger.error(
|
||||
f"Job {job_id} failed (attempt {job.retry_count}/{job.max_retries}): {error}"
|
||||
)
|
||||
return True
|
||||
|
||||
def cancel_job(self, job_id: str) -> bool:
|
||||
"""Cancel a queued or processing job."""
|
||||
with self.db.get_session() as session:
|
||||
job = session.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
if not job:
|
||||
logger.warning(f"Job {job_id} not found for cancellation")
|
||||
return False
|
||||
|
||||
if job.is_terminal_state:
|
||||
logger.warning(f"Job {job_id} is already in terminal state: {job.status.value}")
|
||||
return False
|
||||
|
||||
job.mark_cancelled()
|
||||
session.commit()
|
||||
|
||||
logger.info(f"Job {job_id} cancelled")
|
||||
return True
|
||||
|
||||
def get_queue_stats(self) -> Dict:
|
||||
"""Get queue statistics."""
|
||||
with self.db.get_session() as session:
|
||||
total = session.query(Job).count()
|
||||
queued = session.query(Job).filter(Job.status == JobStatus.QUEUED).count()
|
||||
processing = session.query(Job).filter(Job.status == JobStatus.PROCESSING).count()
|
||||
completed = session.query(Job).filter(Job.status == JobStatus.COMPLETED).count()
|
||||
failed = session.query(Job).filter(Job.status == JobStatus.FAILED).count()
|
||||
|
||||
# Get today's stats
|
||||
today = datetime.utcnow().date()
|
||||
completed_today = (
|
||||
session.query(Job)
|
||||
.filter(
|
||||
Job.status == JobStatus.COMPLETED,
|
||||
Job.completed_at >= today
|
||||
)
|
||||
.count()
|
||||
)
|
||||
failed_today = (
|
||||
session.query(Job)
|
||||
.filter(
|
||||
Job.status == JobStatus.FAILED,
|
||||
Job.completed_at >= today
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"queued": queued,
|
||||
"processing": processing,
|
||||
"completed": completed,
|
||||
"failed": failed,
|
||||
"completed_today": completed_today,
|
||||
"failed_today": failed_today,
|
||||
}
|
||||
|
||||
def get_jobs(
|
||||
self,
|
||||
status: Optional[JobStatus] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[Job]:
|
||||
"""
|
||||
Get jobs with optional filtering.
|
||||
|
||||
Args:
|
||||
status: Filter by status
|
||||
limit: Maximum number of jobs to return
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
List of Job objects
|
||||
"""
|
||||
with self.db.get_session() as session:
|
||||
query = session.query(Job)
|
||||
|
||||
if status:
|
||||
query = query.filter(Job.status == status)
|
||||
|
||||
jobs = (
|
||||
query
|
||||
.order_by(Job.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
.all()
|
||||
)
|
||||
|
||||
return jobs
|
||||
|
||||
def get_processing_jobs(self) -> List[Job]:
|
||||
"""Get all currently processing jobs."""
|
||||
return self.get_jobs(status=JobStatus.PROCESSING)
|
||||
|
||||
def get_queued_jobs(self) -> List[Job]:
|
||||
"""Get all queued jobs."""
|
||||
return self.get_jobs(status=JobStatus.QUEUED)
|
||||
|
||||
def is_queue_empty(self) -> bool:
|
||||
"""Check if queue has any pending jobs."""
|
||||
with self.db.get_session() as session:
|
||||
count = (
|
||||
session.query(Job)
|
||||
.filter(Job.status.in_([JobStatus.QUEUED, JobStatus.PROCESSING]))
|
||||
.count()
|
||||
)
|
||||
return count == 0
|
||||
|
||||
def cleanup_old_jobs(self, days: int = 30) -> int:
|
||||
"""
|
||||
Delete completed/failed jobs older than specified days.
|
||||
|
||||
Args:
|
||||
days: Number of days to keep jobs
|
||||
|
||||
Returns:
|
||||
Number of jobs deleted
|
||||
"""
|
||||
with self.db.get_session() as session:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
deleted = (
|
||||
session.query(Job)
|
||||
.filter(
|
||||
Job.status.in_([JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED]),
|
||||
Job.completed_at < cutoff_date
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
if deleted > 0:
|
||||
logger.info(f"Cleaned up {deleted} old jobs (older than {days} days)")
|
||||
|
||||
return deleted
|
||||
|
||||
def _find_existing_job(
|
||||
self,
|
||||
session: Session,
|
||||
file_path: str,
|
||||
target_lang: Optional[str]
|
||||
) -> Optional[Job]:
|
||||
"""
|
||||
Find existing job for the same file and target language.
|
||||
|
||||
Ignores completed jobs - allows re-transcription.
|
||||
"""
|
||||
query = session.query(Job).filter(
|
||||
Job.file_path == file_path,
|
||||
Job.status.in_([JobStatus.QUEUED, JobStatus.PROCESSING])
|
||||
)
|
||||
|
||||
if target_lang:
|
||||
query = query.filter(Job.target_lang == target_lang)
|
||||
|
||||
return query.first()
|
||||
|
||||
|
||||
# Global queue manager instance
|
||||
queue_manager = QueueManager()
|
||||
285
backend/core/worker.py
Normal file
285
backend/core/worker.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""Individual worker for processing transcription jobs."""
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from backend.core.database import Database
|
||||
from backend.core.models import Job, JobStatus, JobStage
|
||||
from backend.core.queue_manager import QueueManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkerType(str, Enum):
|
||||
"""Worker device type."""
|
||||
CPU = "cpu"
|
||||
GPU = "gpu"
|
||||
|
||||
|
||||
class WorkerStatus(str, Enum):
|
||||
"""Worker status states."""
|
||||
IDLE = "idle"
|
||||
BUSY = "busy"
|
||||
STOPPING = "stopping"
|
||||
STOPPED = "stopped"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class Worker:
|
||||
"""
|
||||
Individual worker process for transcription.
|
||||
|
||||
Each worker runs in its own process and can handle one job at a time.
|
||||
Workers communicate with the main process via multiprocessing primitives.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_id: str,
|
||||
worker_type: WorkerType,
|
||||
device_id: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
Initialize worker.
|
||||
|
||||
Args:
|
||||
worker_id: Unique identifier for this worker
|
||||
worker_type: CPU or GPU
|
||||
device_id: GPU device ID (only for GPU workers)
|
||||
"""
|
||||
self.worker_id = worker_id
|
||||
self.worker_type = worker_type
|
||||
self.device_id = device_id
|
||||
|
||||
# Multiprocessing primitives
|
||||
self.process: Optional[mp.Process] = None
|
||||
self.stop_event = mp.Event()
|
||||
self.status = mp.Value('i', WorkerStatus.IDLE.value) # type: ignore
|
||||
self.current_job_id = mp.Array('c', 36) # type: ignore # UUID string
|
||||
|
||||
# Stats
|
||||
self.jobs_completed = mp.Value('i', 0) # type: ignore
|
||||
self.jobs_failed = mp.Value('i', 0) # type: ignore
|
||||
self.started_at: Optional[datetime] = None
|
||||
|
||||
def start(self):
|
||||
"""Start the worker process."""
|
||||
if self.process and self.process.is_alive():
|
||||
logger.warning(f"Worker {self.worker_id} is already running")
|
||||
return
|
||||
|
||||
self.stop_event.clear()
|
||||
self.process = mp.Process(
|
||||
target=self._worker_loop,
|
||||
name=f"Worker-{self.worker_id}",
|
||||
daemon=True
|
||||
)
|
||||
self.process.start()
|
||||
self.started_at = datetime.utcnow()
|
||||
logger.info(
|
||||
f"Worker {self.worker_id} started (PID: {self.process.pid}, "
|
||||
f"Type: {self.worker_type.value})"
|
||||
)
|
||||
|
||||
def stop(self, timeout: float = 30.0):
|
||||
"""
|
||||
Stop the worker process gracefully.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for worker to stop
|
||||
"""
|
||||
if not self.process or not self.process.is_alive():
|
||||
logger.warning(f"Worker {self.worker_id} is not running")
|
||||
return
|
||||
|
||||
logger.info(f"Stopping worker {self.worker_id}...")
|
||||
self.stop_event.set()
|
||||
self.process.join(timeout=timeout)
|
||||
|
||||
if self.process.is_alive():
|
||||
logger.warning(f"Worker {self.worker_id} did not stop gracefully, terminating...")
|
||||
self.process.terminate()
|
||||
self.process.join(timeout=5.0)
|
||||
|
||||
if self.process.is_alive():
|
||||
logger.error(f"Worker {self.worker_id} did not terminate, killing...")
|
||||
self.process.kill()
|
||||
|
||||
logger.info(f"Worker {self.worker_id} stopped")
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
"""Check if worker process is alive."""
|
||||
return self.process is not None and self.process.is_alive()
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Get worker status information."""
|
||||
status_value = self.status.value
|
||||
status_enum = WorkerStatus.IDLE
|
||||
for s in WorkerStatus:
|
||||
if s.value == status_value:
|
||||
status_enum = s
|
||||
break
|
||||
|
||||
current_job = self.current_job_id.value.decode('utf-8').strip('\x00')
|
||||
|
||||
return {
|
||||
"worker_id": self.worker_id,
|
||||
"type": self.worker_type.value,
|
||||
"device_id": self.device_id,
|
||||
"status": status_enum.value,
|
||||
"current_job_id": current_job if current_job else None,
|
||||
"jobs_completed": self.jobs_completed.value,
|
||||
"jobs_failed": self.jobs_failed.value,
|
||||
"is_alive": self.is_alive(),
|
||||
"pid": self.process.pid if self.process else None,
|
||||
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||
}
|
||||
|
||||
def _worker_loop(self):
|
||||
"""
|
||||
Main worker loop (runs in separate process).
|
||||
|
||||
This is the entry point for the worker process.
|
||||
"""
|
||||
# Set up logging in the worker process
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format=f'[Worker-{self.worker_id}] %(levelname)s: %(message)s'
|
||||
)
|
||||
|
||||
logger.info(f"Worker {self.worker_id} loop started")
|
||||
|
||||
# Initialize database and queue manager in worker process
|
||||
# Each process needs its own DB connection
|
||||
try:
|
||||
db = Database(auto_create_tables=False)
|
||||
queue_mgr = QueueManager()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize worker: {e}")
|
||||
self._set_status(WorkerStatus.ERROR)
|
||||
return
|
||||
|
||||
# Main work loop
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
# Try to get next job from queue
|
||||
job = queue_mgr.get_next_job(self.worker_id)
|
||||
|
||||
if job is None:
|
||||
# No jobs available, idle for a bit
|
||||
self._set_status(WorkerStatus.IDLE)
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
# Process the job
|
||||
self._set_status(WorkerStatus.BUSY)
|
||||
self._set_current_job(job.id)
|
||||
|
||||
logger.info(f"Processing job {job.id}: {job.file_name}")
|
||||
|
||||
try:
|
||||
self._process_job(job, queue_mgr)
|
||||
self.jobs_completed.value += 1
|
||||
logger.info(f"Job {job.id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.jobs_failed.value += 1
|
||||
error_msg = f"Job processing failed: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
|
||||
queue_mgr.mark_job_failed(job.id, error_msg)
|
||||
|
||||
finally:
|
||||
self._clear_current_job()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Worker loop error: {e}\n{traceback.format_exc()}")
|
||||
time.sleep(5) # Back off on errors
|
||||
|
||||
self._set_status(WorkerStatus.STOPPED)
|
||||
logger.info(f"Worker {self.worker_id} loop ended")
|
||||
|
||||
def _process_job(self, job: Job, queue_mgr: QueueManager):
|
||||
"""
|
||||
Process a single transcription job.
|
||||
|
||||
Args:
|
||||
job: Job to process
|
||||
queue_mgr: Queue manager for updating progress
|
||||
"""
|
||||
# TODO: This will be implemented when we add the transcriber module
|
||||
# For now, simulate work
|
||||
|
||||
# Stage 1: Detect language
|
||||
queue_mgr.update_job_progress(
|
||||
job.id,
|
||||
progress=10.0,
|
||||
stage=JobStage.DETECTING_LANGUAGE,
|
||||
eta_seconds=60
|
||||
)
|
||||
time.sleep(2) # Simulate work
|
||||
|
||||
# Stage 2: Extract audio
|
||||
queue_mgr.update_job_progress(
|
||||
job.id,
|
||||
progress=20.0,
|
||||
stage=JobStage.EXTRACTING_AUDIO,
|
||||
eta_seconds=50
|
||||
)
|
||||
time.sleep(2)
|
||||
|
||||
# Stage 3: Transcribe
|
||||
queue_mgr.update_job_progress(
|
||||
job.id,
|
||||
progress=30.0,
|
||||
stage=JobStage.TRANSCRIBING,
|
||||
eta_seconds=40
|
||||
)
|
||||
|
||||
# Simulate progressive transcription
|
||||
for i in range(30, 90, 10):
|
||||
time.sleep(1)
|
||||
queue_mgr.update_job_progress(
|
||||
job.id,
|
||||
progress=float(i),
|
||||
stage=JobStage.TRANSCRIBING,
|
||||
eta_seconds=int((100 - i) / 2)
|
||||
)
|
||||
|
||||
# Stage 4: Finalize
|
||||
queue_mgr.update_job_progress(
|
||||
job.id,
|
||||
progress=95.0,
|
||||
stage=JobStage.FINALIZING,
|
||||
eta_seconds=5
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
# Mark as completed
|
||||
output_path = job.file_path.replace('.mkv', '.srt')
|
||||
queue_mgr.mark_job_completed(
|
||||
job.id,
|
||||
output_path=output_path,
|
||||
segments_count=100, # Simulated
|
||||
srt_content="Simulated SRT content"
|
||||
)
|
||||
|
||||
def _set_status(self, status: WorkerStatus):
|
||||
"""Set worker status (thread-safe)."""
|
||||
self.status.value = status.value
|
||||
|
||||
def _set_current_job(self, job_id: str):
|
||||
"""Set current job ID (thread-safe)."""
|
||||
job_id_bytes = job_id.encode('utf-8')
|
||||
for i, byte in enumerate(job_id_bytes):
|
||||
if i < len(self.current_job_id):
|
||||
self.current_job_id[i] = byte
|
||||
|
||||
def _clear_current_job(self):
|
||||
"""Clear current job ID (thread-safe)."""
|
||||
for i in range(len(self.current_job_id)):
|
||||
self.current_job_id[i] = b'\x00'
|
||||
Reference in New Issue
Block a user