diff --git a/subgen.env b/.env.example similarity index 100% rename from subgen.env rename to .env.example diff --git a/backend/__init__.py b/backend/__init__.py new file mode 100644 index 0000000..b9a95b0 --- /dev/null +++ b/backend/__init__.py @@ -0,0 +1 @@ +"""TranscriptorIO Backend Package.""" diff --git a/backend/api/__init__.py b/backend/api/__init__.py new file mode 100644 index 0000000..89b73bd --- /dev/null +++ b/backend/api/__init__.py @@ -0,0 +1 @@ +"""TranscriptorIO API Module.""" \ No newline at end of file diff --git a/backend/config.py b/backend/config.py new file mode 100644 index 0000000..5b8017b --- /dev/null +++ b/backend/config.py @@ -0,0 +1,214 @@ +"""Configuration management for TranscriptorIO.""" +import os +from enum import Enum +from typing import Optional, List +from pydantic_settings import BaseSettings +from pydantic import Field, field_validator + + +class OperationMode(str, Enum): + """Operation modes for TranscriptorIO.""" + STANDALONE = "standalone" + PROVIDER = "provider" + HYBRID = "standalone,provider" + + +class DatabaseType(str, Enum): + """Supported database backends.""" + SQLITE = "sqlite" + POSTGRESQL = "postgresql" + MARIADB = "mariadb" + MYSQL = "mysql" + + +class Settings(BaseSettings): + """Application settings loaded from environment variables.""" + + # === Application Mode === + transcriptarr_mode: str = Field( + default="standalone", + description="Operation mode: standalone, provider, or standalone,provider" + ) + + # === Database Configuration === + database_url: str = Field( + default="sqlite:///./transcriptarr.db", + description="Database connection URL. Examples:\n" + " SQLite: sqlite:///./transcriptarr.db\n" + " PostgreSQL: postgresql://user:pass@localhost/transcriptarr\n" + " MariaDB: mariadb+pymysql://user:pass@localhost/transcriptarr" + ) + + # === Worker Configuration === + concurrent_transcriptions: int = Field(default=2, ge=1, le=10) + whisper_threads: int = Field(default=4, ge=1, le=32) + transcribe_device: str = Field(default="cpu", pattern="^(cpu|gpu|cuda)$") + clear_vram_on_complete: bool = Field(default=True) + + # === Whisper Model Configuration === + whisper_model: str = Field( + default="medium", + description="Whisper model: tiny, base, small, medium, large-v3, etc." 
+ ) + model_path: str = Field(default="./models") + compute_type: str = Field(default="auto") + + # === Standalone Mode Configuration === + library_paths: Optional[str] = Field( + default=None, + description="Pipe-separated paths to scan: /media/anime|/media/movies" + ) + auto_scan_enabled: bool = Field(default=False) + scan_interval_minutes: int = Field(default=30, ge=1) + + required_audio_language: Optional[str] = Field( + default=None, + description="Only process files with this audio language (ISO 639-2)" + ) + required_missing_subtitle: Optional[str] = Field( + default=None, + description="Only process if this subtitle language is missing (ISO 639-2)" + ) + skip_if_subtitle_exists: bool = Field(default=True) + + # === Provider Mode Configuration === + bazarr_url: Optional[str] = Field(default=None) + bazarr_api_key: Optional[str] = Field(default=None) + provider_timeout_seconds: int = Field(default=600, ge=60) + provider_callback_enabled: bool = Field(default=True) + provider_polling_interval: int = Field(default=30, ge=10) + + # === API Configuration === + webhook_port: int = Field(default=9000, ge=1024, le=65535) + api_host: str = Field(default="0.0.0.0") + debug: bool = Field(default=True) + + # === Transcription Settings === + transcribe_or_translate: str = Field( + default="transcribe", + pattern="^(transcribe|translate)$" + ) + subtitle_language_name: str = Field(default="") + subtitle_language_naming_type: str = Field( + default="ISO_639_2_B", + description="Naming format: ISO_639_1, ISO_639_2_T, ISO_639_2_B, NAME, NATIVE" + ) + word_level_highlight: bool = Field(default=False) + custom_regroup: str = Field(default="cm_sl=84_sl=42++++++1") + + # === Skip Configuration === + skip_if_external_subtitles_exist: bool = Field(default=False) + skip_if_target_subtitles_exist: bool = Field(default=True) + skip_if_internal_subtitles_language: Optional[str] = Field(default="eng") + skip_subtitle_languages: Optional[str] = Field( + default=None, + description="Pipe-separated language codes to skip: eng|spa" + ) + skip_if_audio_languages: Optional[str] = Field( + default=None, + description="Skip if audio track is in these languages: eng|spa" + ) + skip_unknown_language: bool = Field(default=False) + skip_only_subgen_subtitles: bool = Field(default=False) + + # === Advanced Settings === + force_detected_language_to: Optional[str] = Field(default=None) + detect_language_length: int = Field(default=30, ge=5) + detect_language_offset: int = Field(default=0, ge=0) + should_whisper_detect_audio_language: bool = Field(default=False) + + preferred_audio_languages: str = Field( + default="eng", + description="Pipe-separated list in order of preference: eng|jpn" + ) + + # === Path Mapping === + use_path_mapping: bool = Field(default=False) + path_mapping_from: str = Field(default="/tv") + path_mapping_to: str = Field(default="/Volumes/TV") + + # === Legacy SubGen Compatibility === + show_in_subname_subgen: bool = Field(default=True) + show_in_subname_model: bool = Field(default=True) + append: bool = Field(default=False) + lrc_for_audio_files: bool = Field(default=True) + + @field_validator("transcriptarr_mode") + @classmethod + def validate_mode(cls, v: str) -> str: + """Validate operation mode.""" + valid_modes = {"standalone", "provider", "standalone,provider"} + if v not in valid_modes: + raise ValueError(f"Invalid mode: {v}. 
Must be one of: {valid_modes}") + return v + + @field_validator("database_url") + @classmethod + def validate_database_url(cls, v: str) -> str: + """Validate database URL format.""" + valid_prefixes = ("sqlite://", "postgresql://", "mariadb+pymysql://", "mysql+pymysql://") + if not any(v.startswith(prefix) for prefix in valid_prefixes): + raise ValueError( + f"Invalid database URL. Must start with one of: {valid_prefixes}" + ) + return v + + @property + def database_type(self) -> DatabaseType: + """Get the database type from the URL.""" + if self.database_url.startswith("sqlite"): + return DatabaseType.SQLITE + elif self.database_url.startswith("postgresql"): + return DatabaseType.POSTGRESQL + elif "mariadb" in self.database_url: + return DatabaseType.MARIADB + elif "mysql" in self.database_url: + return DatabaseType.MYSQL + else: + raise ValueError(f"Unknown database type in URL: {self.database_url}") + + @property + def is_standalone_mode(self) -> bool: + """Check if standalone mode is enabled.""" + return "standalone" in self.transcriptarr_mode + + @property + def is_provider_mode(self) -> bool: + """Check if provider mode is enabled.""" + return "provider" in self.transcriptarr_mode + + @property + def library_paths_list(self) -> List[str]: + """Get library paths as a list.""" + if not self.library_paths: + return [] + return [p.strip() for p in self.library_paths.split("|") if p.strip()] + + @property + def skip_subtitle_languages_list(self) -> List[str]: + """Get skip subtitle languages as a list.""" + if not self.skip_subtitle_languages: + return [] + return [lang.strip() for lang in self.skip_subtitle_languages.split("|") if lang.strip()] + + @property + def skip_audio_languages_list(self) -> List[str]: + """Get skip audio languages as a list.""" + if not self.skip_if_audio_languages: + return [] + return [lang.strip() for lang in self.skip_if_audio_languages.split("|") if lang.strip()] + + @property + def preferred_audio_languages_list(self) -> List[str]: + """Get preferred audio languages as a list.""" + return [lang.strip() for lang in self.preferred_audio_languages.split("|") if lang.strip()] + + class Config: + """Pydantic configuration.""" + env_file = ".env" + env_file_encoding = "utf-8" + case_sensitive = False + + +# Global settings instance +settings = Settings() diff --git a/backend/core/__init__.py b/backend/core/__init__.py new file mode 100644 index 0000000..e49d580 --- /dev/null +++ b/backend/core/__init__.py @@ -0,0 +1 @@ +"""TranscriptorIO Core Module.""" \ No newline at end of file diff --git a/backend/core/database.py b/backend/core/database.py new file mode 100644 index 0000000..3e2c8de --- /dev/null +++ b/backend/core/database.py @@ -0,0 +1,219 @@ +"""Database configuration and session management.""" +import logging +from contextlib import contextmanager +from typing import Generator + +from sqlalchemy import create_engine, event, Engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.pool import StaticPool, QueuePool + +from backend.config import settings, DatabaseType + +logger = logging.getLogger(__name__) + +# Base class for all models +Base = declarative_base() + + +class Database: + """Database manager supporting SQLite, PostgreSQL, and MariaDB.""" + + def __init__(self, auto_create_tables: bool = True): + """ + Initialize database engine and session maker. 
+ + Args: + auto_create_tables: If True, automatically create tables if they don't exist + """ + self.engine = self._create_engine() + self.SessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=self.engine + ) + logger.info(f"Database initialized: {settings.database_type.value}") + + # Automatically create tables if they don't exist + if auto_create_tables: + self._ensure_tables_exist() + + def _create_engine(self) -> Engine: + """Create SQLAlchemy engine based on database type.""" + connect_args = {} + poolclass = QueuePool + + if settings.database_type == DatabaseType.SQLITE: + # SQLite-specific configuration + connect_args = { + "check_same_thread": False, # Allow multi-threaded access + "timeout": 30.0, # Wait up to 30s for lock + } + # Use StaticPool for SQLite to avoid connection issues + poolclass = StaticPool + + # Enable WAL mode for better concurrency + engine = create_engine( + settings.database_url, + connect_args=connect_args, + poolclass=poolclass, + echo=settings.debug, + ) + + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_conn, connection_record): + """Enable SQLite optimizations.""" + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA journal_mode=WAL") + cursor.execute("PRAGMA synchronous=NORMAL") + cursor.execute("PRAGMA foreign_keys=ON") + cursor.execute("PRAGMA cache_size=-64000") # 64MB cache + cursor.close() + + elif settings.database_type == DatabaseType.POSTGRESQL: + # PostgreSQL-specific configuration + try: + import psycopg2 # noqa: F401 + except ImportError: + raise ImportError( + "PostgreSQL support requires psycopg2-binary.\n" + "Install it with: pip install psycopg2-binary" + ) + + engine = create_engine( + settings.database_url, + pool_size=10, + max_overflow=20, + pool_pre_ping=True, # Verify connections before using + echo=settings.debug, + ) + + elif settings.database_type in (DatabaseType.MARIADB, DatabaseType.MYSQL): + # MariaDB/MySQL-specific configuration + try: + import pymysql # noqa: F401 + except ImportError: + raise ImportError( + "MariaDB/MySQL support requires pymysql.\n" + "Install it with: pip install pymysql" + ) + + connect_args = { + "charset": "utf8mb4", + } + engine = create_engine( + settings.database_url, + connect_args=connect_args, + pool_size=10, + max_overflow=20, + pool_pre_ping=True, + echo=settings.debug, + ) + + else: + raise ValueError(f"Unsupported database type: {settings.database_type}") + + return engine + + def _ensure_tables_exist(self): + """Check if tables exist and create them if they don't.""" + # Import models to register them with Base.metadata + from backend.core import models # noqa: F401 + from sqlalchemy import inspect + + inspector = inspect(self.engine) + existing_tables = inspector.get_table_names() + + # Check if the main 'jobs' table exists + if 'jobs' not in existing_tables: + logger.info("Tables don't exist, creating them automatically...") + self.create_tables() + else: + logger.debug("Database tables already exist") + + def create_tables(self): + """Create all database tables.""" + # Import models to register them with Base.metadata + from backend.core import models # noqa: F401 + + logger.info("Creating database tables...") + Base.metadata.create_all(bind=self.engine, checkfirst=True) + + # Verify tables were actually created + from sqlalchemy import inspect + inspector = inspect(self.engine) + created_tables = inspector.get_table_names() + + if 'jobs' in created_tables: + logger.info(f"Database tables created successfully: {created_tables}") + else: + 
logger.error(f"Failed to create tables. Existing tables: {created_tables}") + raise RuntimeError("Failed to create database tables") + + def drop_tables(self): + """Drop all database tables (use with caution!).""" + logger.warning("Dropping all database tables...") + Base.metadata.drop_all(bind=self.engine) + logger.info("Database tables dropped") + + @contextmanager + def get_session(self) -> Generator[Session, None, None]: + """ + Get a database session as a context manager. + + Usage: + with db.get_session() as session: + session.query(Job).all() + """ + session = self.SessionLocal() + try: + yield session + session.commit() + except Exception as e: + session.rollback() + logger.error(f"Database session error: {e}") + raise + finally: + session.close() + + def get_db(self) -> Generator[Session, None, None]: + """ + Dependency for FastAPI endpoints. + + Usage: + @app.get("/jobs") + def get_jobs(db: Session = Depends(database.get_db)): + return db.query(Job).all() + """ + session = self.SessionLocal() + try: + yield session + finally: + session.close() + + def health_check(self) -> bool: + """Check if database connection is healthy.""" + try: + from sqlalchemy import text + with self.get_session() as session: + session.execute(text("SELECT 1")) + return True + except Exception as e: + logger.error(f"Database health check failed: {e}") + return False + + def get_stats(self) -> dict: + """Get database statistics.""" + stats = { + "type": settings.database_type.value, + "url": settings.database_url.split("@")[-1] if "@" in settings.database_url else settings.database_url, + "pool_size": getattr(self.engine.pool, "size", lambda: "N/A")(), + "pool_checked_in": getattr(self.engine.pool, "checkedin", lambda: 0)(), + "pool_checked_out": getattr(self.engine.pool, "checkedout", lambda: 0)(), + "pool_overflow": getattr(self.engine.pool, "overflow", lambda: 0)(), + } + return stats + + +# Global database instance +database = Database() \ No newline at end of file diff --git a/backend/core/models.py b/backend/core/models.py new file mode 100644 index 0000000..0c022c6 --- /dev/null +++ b/backend/core/models.py @@ -0,0 +1,203 @@ +"""Database models for TranscriptorIO.""" +import uuid +from datetime import datetime +from enum import Enum +from typing import Optional + +from sqlalchemy import ( + Column, String, Integer, Float, DateTime, Text, Boolean, Enum as SQLEnum, Index +) +from sqlalchemy.sql import func + +from backend.core.database import Base + + +class JobStatus(str, Enum): + """Job status states.""" + QUEUED = "queued" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class JobStage(str, Enum): + """Job processing stages.""" + PENDING = "pending" + DETECTING_LANGUAGE = "detecting_language" + EXTRACTING_AUDIO = "extracting_audio" + TRANSCRIBING = "transcribing" + TRANSLATING = "translating" + GENERATING_SUBTITLES = "generating_subtitles" + POST_PROCESSING = "post_processing" + FINALIZING = "finalizing" + + +class QualityPreset(str, Enum): + """Quality presets for transcription.""" + FAST = "fast" # ja→en→es with Helsinki-NLP (4GB VRAM) + BALANCED = "balanced" # ja→ja→es with M2M100 (6GB VRAM) + BEST = "best" # ja→es direct with SeamlessM4T (10GB+ VRAM) + + +class Job(Base): + """Job model representing a transcription task.""" + + __tablename__ = "jobs" + + # Primary identification + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + file_path = Column(String(1024), nullable=False, index=True) + file_name = 
Column(String(512), nullable=False)
+
+    # Job status
+    status = Column(
+        SQLEnum(JobStatus),
+        nullable=False,
+        default=JobStatus.QUEUED,
+        index=True
+    )
+    priority = Column(Integer, nullable=False, default=0, index=True)
+
+    # Configuration
+    source_lang = Column(String(10), nullable=True)
+    target_lang = Column(String(10), nullable=True)
+    quality_preset = Column(
+        SQLEnum(QualityPreset),
+        nullable=False,
+        default=QualityPreset.FAST
+    )
+    transcribe_or_translate = Column(String(20), nullable=False, default="transcribe")
+
+    # Progress tracking
+    progress = Column(Float, nullable=False, default=0.0)  # 0-100
+    current_stage = Column(
+        SQLEnum(JobStage),
+        nullable=False,
+        default=JobStage.PENDING
+    )
+    eta_seconds = Column(Integer, nullable=True)
+
+    # Timestamps
+    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True)
+    started_at = Column(DateTime(timezone=True), nullable=True)
+    completed_at = Column(DateTime(timezone=True), nullable=True)
+
+    # Results
+    output_path = Column(String(1024), nullable=True)
+    srt_content = Column(Text, nullable=True)
+    segments_count = Column(Integer, nullable=True)
+
+    # Error handling
+    error = Column(Text, nullable=True)
+    retry_count = Column(Integer, nullable=False, default=0)
+    max_retries = Column(Integer, nullable=False, default=3)
+
+    # Worker information
+    worker_id = Column(String(64), nullable=True)
+    vram_used_mb = Column(Integer, nullable=True)
+    processing_time_seconds = Column(Float, nullable=True)
+
+    # Provider mode specific
+    bazarr_callback_url = Column(String(512), nullable=True)
+    is_manual_request = Column(Boolean, nullable=False, default=False)
+
+    # Additional metadata
+    model_used = Column(String(64), nullable=True)
+    device_used = Column(String(32), nullable=True)
+    compute_type = Column(String(32), nullable=True)
+
+    def __repr__(self):
+        """String representation of Job."""
+        return f"<Job(id={self.id}, status={self.status}, file={self.file_name})>"
+
+    @property
+    def duration_seconds(self) -> Optional[float]:
+        """Calculate job duration in seconds."""
+        if self.started_at and self.completed_at:
+            delta = self.completed_at - self.started_at
+            return delta.total_seconds()
+        return None
+
+    @property
+    def is_terminal_state(self) -> bool:
+        """Check if job is in a terminal state (completed/failed/cancelled)."""
+        return self.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
+
+    @property
+    def can_retry(self) -> bool:
+        """Check if job can be retried."""
+        return self.status == JobStatus.FAILED and self.retry_count < self.max_retries
+
+    def to_dict(self) -> dict:
+        """Convert job to dictionary for API responses."""
+        return {
+            "id": self.id,
+            "file_path": self.file_path,
+            "file_name": self.file_name,
+            "status": self.status.value,
+            "priority": self.priority,
+            "source_lang": self.source_lang,
+            "target_lang": self.target_lang,
+            "quality_preset": self.quality_preset.value if self.quality_preset else None,
+            "transcribe_or_translate": self.transcribe_or_translate,
+            "progress": self.progress,
+            "current_stage": self.current_stage.value if self.current_stage else None,
+            "eta_seconds": self.eta_seconds,
+            "created_at": self.created_at.isoformat() if self.created_at else None,
+            "started_at": self.started_at.isoformat() if self.started_at else None,
+            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
+            "output_path": self.output_path,
+            "segments_count": self.segments_count,
+            "error": self.error,
+            "retry_count": self.retry_count,
+            "worker_id": self.worker_id,
+            "vram_used_mb":
self.vram_used_mb, + "processing_time_seconds": self.processing_time_seconds, + "model_used": self.model_used, + "device_used": self.device_used, + } + + def update_progress(self, progress: float, stage: JobStage, eta_seconds: Optional[int] = None): + """Update job progress.""" + self.progress = min(100.0, max(0.0, progress)) + self.current_stage = stage + if eta_seconds is not None: + self.eta_seconds = eta_seconds + + def mark_started(self, worker_id: str): + """Mark job as started.""" + self.status = JobStatus.PROCESSING + self.started_at = datetime.utcnow() + self.worker_id = worker_id + + def mark_completed(self, output_path: str, segments_count: int, srt_content: Optional[str] = None): + """Mark job as completed.""" + self.status = JobStatus.COMPLETED + self.completed_at = datetime.utcnow() + self.output_path = output_path + self.segments_count = segments_count + self.srt_content = srt_content + self.progress = 100.0 + self.current_stage = JobStage.FINALIZING + + if self.started_at: + self.processing_time_seconds = (self.completed_at - self.started_at).total_seconds() + + def mark_failed(self, error: str): + """Mark job as failed.""" + self.status = JobStatus.FAILED + self.completed_at = datetime.utcnow() + self.error = error + self.retry_count += 1 + + def mark_cancelled(self): + """Mark job as cancelled.""" + self.status = JobStatus.CANCELLED + self.completed_at = datetime.utcnow() + + +# Create indexes for common queries +Index('idx_jobs_status_priority', Job.status, Job.priority.desc(), Job.created_at) +Index('idx_jobs_created', Job.created_at.desc()) +Index('idx_jobs_file_path', Job.file_path) \ No newline at end of file diff --git a/backend/core/queue_manager.py b/backend/core/queue_manager.py new file mode 100644 index 0000000..6e4e2cc --- /dev/null +++ b/backend/core/queue_manager.py @@ -0,0 +1,394 @@ +"""Queue manager for persistent job queuing.""" +import logging +from datetime import datetime, timedelta +from typing import List, Optional, Dict +from sqlalchemy import and_, or_ +from sqlalchemy.orm import Session + +from backend.core.database import database +from backend.core.models import Job, JobStatus, JobStage, QualityPreset + +logger = logging.getLogger(__name__) + + +class QueueManager: + """ + Persistent queue manager for transcription jobs. + + Replaces the old DeduplicatedQueue with a database-backed solution that: + - Persists jobs across restarts + - Supports priority queuing + - Prevents duplicate jobs + - Provides visibility into queue state + - Thread-safe operations + """ + + def __init__(self): + """Initialize queue manager.""" + self.db = database + logger.info("QueueManager initialized") + + def add_job( + self, + file_path: str, + file_name: str, + source_lang: Optional[str] = None, + target_lang: Optional[str] = None, + quality_preset: QualityPreset = QualityPreset.FAST, + transcribe_or_translate: str = "transcribe", + priority: int = 0, + bazarr_callback_url: Optional[str] = None, + is_manual_request: bool = False, + ) -> Optional[Job]: + """ + Add a new job to the queue. 
+ + Args: + file_path: Full path to the media file + file_name: Name of the file + source_lang: Source language code (ISO 639-2) + target_lang: Target language code (ISO 639-2) + quality_preset: Quality preset (fast/balanced/best) + transcribe_or_translate: Operation type + priority: Job priority (higher = processed first) + bazarr_callback_url: Callback URL for Bazarr provider mode + is_manual_request: Whether this is a manual request (higher priority) + + Returns: + Job object if created, None if duplicate exists + """ + with self.db.get_session() as session: + # Check for existing job + existing = self._find_existing_job(session, file_path, target_lang) + + if existing: + logger.info(f"Job already exists for {file_name}: {existing.id} [{existing.status.value}]") + + # If existing job failed and can retry, reset it + if existing.can_retry: + logger.info(f"Resetting failed job {existing.id} for retry") + existing.status = JobStatus.QUEUED + existing.error = None + existing.current_stage = JobStage.PENDING + existing.progress = 0.0 + session.commit() + return existing + + return None + + # Create new job + job = Job( + file_path=file_path, + file_name=file_name, + source_lang=source_lang, + target_lang=target_lang, + quality_preset=quality_preset, + transcribe_or_translate=transcribe_or_translate, + priority=priority + (10 if is_manual_request else 0), # Boost manual requests + bazarr_callback_url=bazarr_callback_url, + is_manual_request=is_manual_request, + ) + + session.add(job) + session.commit() + + # Access all attributes before session closes to ensure they're loaded + job_id = job.id + job_status = job.status + + logger.info( + f"Job {job_id} added to queue: {file_name} " + f"[{quality_preset.value}] priority={job.priority}" + ) + + # Re-query the job in a new session to return a fresh copy + with self.db.get_session() as session: + job = session.query(Job).filter(Job.id == job_id).first() + if job: + session.expunge(job) # Remove from session so it doesn't expire + return job + + def get_next_job(self, worker_id: str) -> Optional[Job]: + """ + Get the next job from the queue for processing. + + Jobs are selected based on: + 1. Status = QUEUED + 2. Priority (DESC) + 3. Created time (ASC) - FIFO within same priority + + Args: + worker_id: ID of the worker requesting the job + + Returns: + Job object or None if queue is empty + """ + with self.db.get_session() as session: + job = ( + session.query(Job) + .filter(Job.status == JobStatus.QUEUED) + .order_by( + Job.priority.desc(), + Job.created_at.asc() + ) + .with_for_update(skip_locked=True) # Skip locked rows (concurrent workers) + .first() + ) + + if job: + job_id = job.id + job.mark_started(worker_id) + session.commit() + logger.info(f"Job {job_id} assigned to worker {worker_id}") + + # Re-query the job if found + if job: + with self.db.get_session() as session: + job = session.query(Job).filter(Job.id == job_id).first() + if job: + session.expunge(job) # Remove from session so it doesn't expire + return job + + return None + + def get_job_by_id(self, job_id: str) -> Optional[Job]: + """Get a specific job by ID.""" + with self.db.get_session() as session: + return session.query(Job).filter(Job.id == job_id).first() + + def update_job_progress( + self, + job_id: str, + progress: float, + stage: JobStage, + eta_seconds: Optional[int] = None + ) -> bool: + """ + Update job progress. 
+ + Args: + job_id: Job ID + progress: Progress percentage (0-100) + stage: Current processing stage + eta_seconds: Estimated time to completion + + Returns: + True if updated successfully, False otherwise + """ + with self.db.get_session() as session: + job = session.query(Job).filter(Job.id == job_id).first() + + if not job: + logger.warning(f"Job {job_id} not found for progress update") + return False + + job.update_progress(progress, stage, eta_seconds) + session.commit() + + logger.debug( + f"Job {job_id} progress: {progress:.1f}% [{stage.value}] ETA: {eta_seconds}s" + ) + return True + + def mark_job_completed( + self, + job_id: str, + output_path: str, + segments_count: int, + srt_content: Optional[str] = None + ) -> bool: + """Mark a job as completed.""" + with self.db.get_session() as session: + job = session.query(Job).filter(Job.id == job_id).first() + + if not job: + logger.warning(f"Job {job_id} not found for completion") + return False + + job.mark_completed(output_path, segments_count, srt_content) + session.commit() + + logger.info( + f"Job {job_id} completed: {output_path} " + f"({segments_count} segments, {job.processing_time_seconds:.1f}s)" + ) + return True + + def mark_job_failed(self, job_id: str, error: str) -> bool: + """Mark a job as failed.""" + with self.db.get_session() as session: + job = session.query(Job).filter(Job.id == job_id).first() + + if not job: + logger.warning(f"Job {job_id} not found for failure marking") + return False + + job.mark_failed(error) + session.commit() + + logger.error( + f"Job {job_id} failed (attempt {job.retry_count}/{job.max_retries}): {error}" + ) + return True + + def cancel_job(self, job_id: str) -> bool: + """Cancel a queued or processing job.""" + with self.db.get_session() as session: + job = session.query(Job).filter(Job.id == job_id).first() + + if not job: + logger.warning(f"Job {job_id} not found for cancellation") + return False + + if job.is_terminal_state: + logger.warning(f"Job {job_id} is already in terminal state: {job.status.value}") + return False + + job.mark_cancelled() + session.commit() + + logger.info(f"Job {job_id} cancelled") + return True + + def get_queue_stats(self) -> Dict: + """Get queue statistics.""" + with self.db.get_session() as session: + total = session.query(Job).count() + queued = session.query(Job).filter(Job.status == JobStatus.QUEUED).count() + processing = session.query(Job).filter(Job.status == JobStatus.PROCESSING).count() + completed = session.query(Job).filter(Job.status == JobStatus.COMPLETED).count() + failed = session.query(Job).filter(Job.status == JobStatus.FAILED).count() + + # Get today's stats + today = datetime.utcnow().date() + completed_today = ( + session.query(Job) + .filter( + Job.status == JobStatus.COMPLETED, + Job.completed_at >= today + ) + .count() + ) + failed_today = ( + session.query(Job) + .filter( + Job.status == JobStatus.FAILED, + Job.completed_at >= today + ) + .count() + ) + + return { + "total": total, + "queued": queued, + "processing": processing, + "completed": completed, + "failed": failed, + "completed_today": completed_today, + "failed_today": failed_today, + } + + def get_jobs( + self, + status: Optional[JobStatus] = None, + limit: int = 50, + offset: int = 0 + ) -> List[Job]: + """ + Get jobs with optional filtering. 
+ + Args: + status: Filter by status + limit: Maximum number of jobs to return + offset: Offset for pagination + + Returns: + List of Job objects + """ + with self.db.get_session() as session: + query = session.query(Job) + + if status: + query = query.filter(Job.status == status) + + jobs = ( + query + .order_by(Job.created_at.desc()) + .limit(limit) + .offset(offset) + .all() + ) + + return jobs + + def get_processing_jobs(self) -> List[Job]: + """Get all currently processing jobs.""" + return self.get_jobs(status=JobStatus.PROCESSING) + + def get_queued_jobs(self) -> List[Job]: + """Get all queued jobs.""" + return self.get_jobs(status=JobStatus.QUEUED) + + def is_queue_empty(self) -> bool: + """Check if queue has any pending jobs.""" + with self.db.get_session() as session: + count = ( + session.query(Job) + .filter(Job.status.in_([JobStatus.QUEUED, JobStatus.PROCESSING])) + .count() + ) + return count == 0 + + def cleanup_old_jobs(self, days: int = 30) -> int: + """ + Delete completed/failed jobs older than specified days. + + Args: + days: Number of days to keep jobs + + Returns: + Number of jobs deleted + """ + with self.db.get_session() as session: + cutoff_date = datetime.utcnow() - timedelta(days=days) + + deleted = ( + session.query(Job) + .filter( + Job.status.in_([JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED]), + Job.completed_at < cutoff_date + ) + .delete() + ) + + session.commit() + + if deleted > 0: + logger.info(f"Cleaned up {deleted} old jobs (older than {days} days)") + + return deleted + + def _find_existing_job( + self, + session: Session, + file_path: str, + target_lang: Optional[str] + ) -> Optional[Job]: + """ + Find existing job for the same file and target language. + + Ignores completed jobs - allows re-transcription. + """ + query = session.query(Job).filter( + Job.file_path == file_path, + Job.status.in_([JobStatus.QUEUED, JobStatus.PROCESSING]) + ) + + if target_lang: + query = query.filter(Job.target_lang == target_lang) + + return query.first() + + +# Global queue manager instance +queue_manager = QueueManager() \ No newline at end of file diff --git a/backend/core/worker.py b/backend/core/worker.py new file mode 100644 index 0000000..83c12ae --- /dev/null +++ b/backend/core/worker.py @@ -0,0 +1,285 @@ +"""Individual worker for processing transcription jobs.""" +import logging +import multiprocessing as mp +import time +import traceback +from datetime import datetime +from enum import Enum +from typing import Optional + +from backend.core.database import Database +from backend.core.models import Job, JobStatus, JobStage +from backend.core.queue_manager import QueueManager + +logger = logging.getLogger(__name__) + + +class WorkerType(str, Enum): + """Worker device type.""" + CPU = "cpu" + GPU = "gpu" + + +class WorkerStatus(str, Enum): + """Worker status states.""" + IDLE = "idle" + BUSY = "busy" + STOPPING = "stopping" + STOPPED = "stopped" + ERROR = "error" + + +class Worker: + """ + Individual worker process for transcription. + + Each worker runs in its own process and can handle one job at a time. + Workers communicate with the main process via multiprocessing primitives. + """ + + def __init__( + self, + worker_id: str, + worker_type: WorkerType, + device_id: Optional[int] = None + ): + """ + Initialize worker. 
+
+        Args:
+            worker_id: Unique identifier for this worker
+            worker_type: CPU or GPU
+            device_id: GPU device ID (only for GPU workers)
+        """
+        self.worker_id = worker_id
+        self.worker_type = worker_type
+        self.device_id = device_id
+
+        # Multiprocessing primitives
+        self.process: Optional[mp.Process] = None
+        self.stop_event = mp.Event()
+        # mp.Value('i', ...) holds a C int, so store the index of the WorkerStatus member
+        self.status = mp.Value('i', list(WorkerStatus).index(WorkerStatus.IDLE))
+        self.current_job_id = mp.Array('c', 36)  # type: ignore  # UUID string
+
+        # Stats
+        self.jobs_completed = mp.Value('i', 0)  # type: ignore
+        self.jobs_failed = mp.Value('i', 0)  # type: ignore
+        self.started_at: Optional[datetime] = None
+
+    def start(self):
+        """Start the worker process."""
+        if self.process and self.process.is_alive():
+            logger.warning(f"Worker {self.worker_id} is already running")
+            return
+
+        self.stop_event.clear()
+        self.process = mp.Process(
+            target=self._worker_loop,
+            name=f"Worker-{self.worker_id}",
+            daemon=True
+        )
+        self.process.start()
+        self.started_at = datetime.utcnow()
+        logger.info(
+            f"Worker {self.worker_id} started (PID: {self.process.pid}, "
+            f"Type: {self.worker_type.value})"
+        )
+
+    def stop(self, timeout: float = 30.0):
+        """
+        Stop the worker process gracefully.
+
+        Args:
+            timeout: Maximum time to wait for worker to stop
+        """
+        if not self.process or not self.process.is_alive():
+            logger.warning(f"Worker {self.worker_id} is not running")
+            return
+
+        logger.info(f"Stopping worker {self.worker_id}...")
+        self.stop_event.set()
+        self.process.join(timeout=timeout)
+
+        if self.process.is_alive():
+            logger.warning(f"Worker {self.worker_id} did not stop gracefully, terminating...")
+            self.process.terminate()
+            self.process.join(timeout=5.0)
+
+        if self.process.is_alive():
+            logger.error(f"Worker {self.worker_id} did not terminate, killing...")
+            self.process.kill()
+
+        logger.info(f"Worker {self.worker_id} stopped")
+
+    def is_alive(self) -> bool:
+        """Check if worker process is alive."""
+        return self.process is not None and self.process.is_alive()
+
+    def get_status(self) -> dict:
+        """Get worker status information."""
+        # The shared int holds the index of the current WorkerStatus member
+        status_enum = list(WorkerStatus)[self.status.value]
+
+        current_job = self.current_job_id.value.decode('utf-8').strip('\x00')
+
+        return {
+            "worker_id": self.worker_id,
+            "type": self.worker_type.value,
+            "device_id": self.device_id,
+            "status": status_enum.value,
+            "current_job_id": current_job if current_job else None,
+            "jobs_completed": self.jobs_completed.value,
+            "jobs_failed": self.jobs_failed.value,
+            "is_alive": self.is_alive(),
+            "pid": self.process.pid if self.process else None,
+            "started_at": self.started_at.isoformat() if self.started_at else None,
+        }
+
+    def _worker_loop(self):
+        """
+        Main worker loop (runs in separate process).
+
+        This is the entry point for the worker process.
+ """ + # Set up logging in the worker process + logging.basicConfig( + level=logging.INFO, + format=f'[Worker-{self.worker_id}] %(levelname)s: %(message)s' + ) + + logger.info(f"Worker {self.worker_id} loop started") + + # Initialize database and queue manager in worker process + # Each process needs its own DB connection + try: + db = Database(auto_create_tables=False) + queue_mgr = QueueManager() + except Exception as e: + logger.error(f"Failed to initialize worker: {e}") + self._set_status(WorkerStatus.ERROR) + return + + # Main work loop + while not self.stop_event.is_set(): + try: + # Try to get next job from queue + job = queue_mgr.get_next_job(self.worker_id) + + if job is None: + # No jobs available, idle for a bit + self._set_status(WorkerStatus.IDLE) + time.sleep(2) + continue + + # Process the job + self._set_status(WorkerStatus.BUSY) + self._set_current_job(job.id) + + logger.info(f"Processing job {job.id}: {job.file_name}") + + try: + self._process_job(job, queue_mgr) + self.jobs_completed.value += 1 + logger.info(f"Job {job.id} completed successfully") + + except Exception as e: + self.jobs_failed.value += 1 + error_msg = f"Job processing failed: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + + queue_mgr.mark_job_failed(job.id, error_msg) + + finally: + self._clear_current_job() + + except Exception as e: + logger.error(f"Worker loop error: {e}\n{traceback.format_exc()}") + time.sleep(5) # Back off on errors + + self._set_status(WorkerStatus.STOPPED) + logger.info(f"Worker {self.worker_id} loop ended") + + def _process_job(self, job: Job, queue_mgr: QueueManager): + """ + Process a single transcription job. + + Args: + job: Job to process + queue_mgr: Queue manager for updating progress + """ + # TODO: This will be implemented when we add the transcriber module + # For now, simulate work + + # Stage 1: Detect language + queue_mgr.update_job_progress( + job.id, + progress=10.0, + stage=JobStage.DETECTING_LANGUAGE, + eta_seconds=60 + ) + time.sleep(2) # Simulate work + + # Stage 2: Extract audio + queue_mgr.update_job_progress( + job.id, + progress=20.0, + stage=JobStage.EXTRACTING_AUDIO, + eta_seconds=50 + ) + time.sleep(2) + + # Stage 3: Transcribe + queue_mgr.update_job_progress( + job.id, + progress=30.0, + stage=JobStage.TRANSCRIBING, + eta_seconds=40 + ) + + # Simulate progressive transcription + for i in range(30, 90, 10): + time.sleep(1) + queue_mgr.update_job_progress( + job.id, + progress=float(i), + stage=JobStage.TRANSCRIBING, + eta_seconds=int((100 - i) / 2) + ) + + # Stage 4: Finalize + queue_mgr.update_job_progress( + job.id, + progress=95.0, + stage=JobStage.FINALIZING, + eta_seconds=5 + ) + time.sleep(1) + + # Mark as completed + output_path = job.file_path.replace('.mkv', '.srt') + queue_mgr.mark_job_completed( + job.id, + output_path=output_path, + segments_count=100, # Simulated + srt_content="Simulated SRT content" + ) + + def _set_status(self, status: WorkerStatus): + """Set worker status (thread-safe).""" + self.status.value = status.value + + def _set_current_job(self, job_id: str): + """Set current job ID (thread-safe).""" + job_id_bytes = job_id.encode('utf-8') + for i, byte in enumerate(job_id_bytes): + if i < len(self.current_job_id): + self.current_job_id[i] = byte + + def _clear_current_job(self): + """Clear current job ID (thread-safe).""" + for i in range(len(self.current_job_id)): + self.current_job_id[i] = b'\x00' \ No newline at end of file diff --git a/subgen.py b/transcriptarr.py similarity index 100% rename from subgen.py 
rename to transcriptarr.py
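
Taken together, config.py supplies settings from the environment, queue_manager.py owns the persistent job table, and worker.py drains it in separate processes. A minimal usage sketch of the new API, not part of the patch: it assumes the backend package is importable, the default SQLite database_url, and placeholder file path, languages, and worker ID.

    from backend.core.models import QualityPreset
    from backend.core.queue_manager import queue_manager
    from backend.core.worker import Worker, WorkerType

    # On platforms that spawn processes (Windows/macOS), run this under an
    # `if __name__ == "__main__":` guard so mp.Process can import the module safely.

    # Enqueue a transcription job; returns None if an identical job is already
    # queued or processing for the same file/target language.
    job = queue_manager.add_job(
        file_path="/media/anime/episode01.mkv",  # placeholder path
        file_name="episode01.mkv",
        source_lang="jpn",
        target_lang="eng",
        quality_preset=QualityPreset.FAST,
        priority=5,
    )
    if job:
        print(f"queued {job.id}")

    # Spawn a single CPU worker process that polls the queue until stopped.
    worker = Worker(worker_id="cpu-0", worker_type=WorkerType.CPU)
    worker.start()

    # ...later, during shutdown:
    worker.stop()
    print(queue_manager.get_queue_stats())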