Transcriptarr/backend/transcription/transcriber.py
Dasemu cbf5ef9623 feat(transcription): add Whisper transcriber and audio utilities
- Add WhisperTranscriber wrapper for stable-ts/faster-whisper
- Add audio utilities for ffmpeg/ffprobe operations
- Add translator for two-stage translation workflow
- Support CPU/GPU with graceful degradation
2026-01-16 16:55:02 +01:00


"""Whisper transcription wrapper for worker processes."""
import logging
import os
import gc
import ctypes
import ctypes.util
from typing import Optional, Callable
from io import BytesIO
import numpy as np
# Optional imports - graceful degradation if not available
try:
import stable_whisper
import torch
WHISPER_AVAILABLE = True
except ImportError:
stable_whisper = None
torch = None
WHISPER_AVAILABLE = False
logging.warning("stable_whisper or torch not available. Transcription will not work.")
logger = logging.getLogger(__name__)
class TranscriptionResult:
    """Result of a transcription operation."""

    def __init__(self, result, language: str, segments_count: int):
        """
        Initialize transcription result.

        Args:
            result: stable-ts result object
            language: Detected or forced language
            segments_count: Number of subtitle segments
        """
        self.result = result
        self.language = language
        self.segments_count = segments_count

    def to_srt(self, output_path: str, word_level: bool = False) -> str:
        """
        Save result as SRT file.

        Args:
            output_path: Path to save SRT file
            word_level: Enable word-level timestamps

        Returns:
            Path to saved file
        """
        self.result.to_srt_vtt(output_path, word_level=word_level)
        return output_path

    def get_srt_content(self, word_level: bool = False) -> str:
        """
        Get SRT content as string.

        Args:
            word_level: Enable word-level timestamps

        Returns:
            SRT content
        """
        return "".join(self.result.to_srt_vtt(filepath=None, word_level=word_level))


class WhisperTranscriber:
    """
    Whisper transcription engine wrapper.

    Manages Whisper model loading/unloading and transcription operations.
    Designed to run in worker processes with isolated model instances.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        device: Optional[str] = None,
        model_path: Optional[str] = None,
        compute_type: Optional[str] = None,
        threads: Optional[int] = None,
    ):
        """
        Initialize transcriber.

        Args:
            model_name: Whisper model name (tiny, base, small, medium, large, etc.)
            device: Device to use (cpu, cuda, gpu)
            model_path: Path to store/load models
            compute_type: Compute type (auto, int8, float16, etc.)
            threads: Number of CPU threads
        """
        # Import settings_service here to avoid circular imports
        from backend.core.settings_service import settings_service

        # Load from database settings with sensible defaults
        self.model_name = model_name or settings_service.get('whisper_model', 'medium')
        self.device = (device or settings_service.get('transcribe_device', 'cpu')).lower()
        if self.device == "gpu":
            self.device = "cuda"
        self.model_path = model_path or settings_service.get('model_path', './models')

        # Get compute_type from settings based on device type
        if compute_type:
            requested_compute_type = compute_type
        elif self.device == "cpu":
            requested_compute_type = settings_service.get('cpu_compute_type', 'auto')
        else:
            requested_compute_type = settings_service.get('gpu_compute_type', 'auto')

        # Auto-detect a compatible compute_type for the device
        self.compute_type = self._get_compatible_compute_type(self.device, requested_compute_type)
        self.threads = threads or int(settings_service.get('whisper_threads', 4))
        self.model = None
        self.is_loaded = False

        if self.compute_type != requested_compute_type:
            logger.warning(
                f"Requested compute_type '{requested_compute_type}' is not compatible with device '{self.device}'. "
                f"Using '{self.compute_type}' instead."
            )
        logger.info(
            f"WhisperTranscriber initialized: model={self.model_name}, "
            f"device={self.device}, compute_type={self.compute_type}"
        )
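
    # Construction sketch: explicit arguments override the database-backed
    # settings; the values shown here are illustrative, not project defaults.
    #
    #     transcriber = WhisperTranscriber(model_name="small", device="cpu", threads=8)
    #     transcriber = WhisperTranscriber(device="gpu")  # normalized to "cuda"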

    def _get_compatible_compute_type(self, device: str, requested: str) -> str:
        """
        Get compatible compute type for the device.

        CPU: Only supports int8 and float32
        GPU: Supports float16, float32, int8, int8_float16

        Args:
            device: Device type (cpu, cuda)
            requested: Requested compute type

        Returns:
            Compatible compute type
        """
        if device == "cpu":
            # CPU only supports int8 and float32
            if requested == "auto":
                return "int8"  # int8 is faster on CPU
            elif requested in ("float16", "int8_float16"):
                logger.warning(f"CPU doesn't support {requested}, falling back to int8")
                return "int8"
            elif requested in ("int8", "float32"):
                return requested
            else:
                logger.warning(f"Unknown compute type {requested}, using int8")
                return "int8"
        else:
            # CUDA/GPU supports all of these types
            if requested == "auto":
                return "float16"  # float16 is recommended for GPU
            elif requested in ("float16", "float32", "int8", "int8_float16"):
                return requested
            else:
                logger.warning(f"Unknown compute type {requested}, using float16")
                return "float16"

    def load_model(self):
        """Load Whisper model into memory."""
        if not WHISPER_AVAILABLE:
            raise RuntimeError(
                "Whisper is not available. Install with: pip install stable-ts faster-whisper"
            )
        if self.is_loaded and self.model is not None:
            logger.debug("Model already loaded")
            return

        try:
            logger.info(f"Loading Whisper model: {self.model_name}")
            self.model = stable_whisper.load_faster_whisper(
                self.model_name,
                download_root=self.model_path,
                device=self.device,
                cpu_threads=self.threads,
                num_workers=1,  # Each worker has its own model
                # self.compute_type was already validated for this device by
                # _get_compatible_compute_type, so pass it through unchanged
                compute_type=self.compute_type,
            )
            self.is_loaded = True
            logger.info(f"Model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model {self.model_name}: {e}")
            raise

    def unload_model(self):
        """Unload model from memory and clear cache."""
        if not self.is_loaded or self.model is None:
            logger.debug("Model not loaded, nothing to unload")
            return

        try:
            logger.info("Unloading Whisper model")
            # Unload the model
            if hasattr(self.model, "model") and hasattr(self.model.model, "unload_model"):
                self.model.model.unload_model()
            del self.model
            self.model = None
            self.is_loaded = False

            # Clear CUDA cache if using GPU
            if self.device == "cuda" and torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("CUDA cache cleared")

            # Garbage collection, then ask the allocator to return freed pages
            # to the OS (malloc_trim is glibc-specific, so skip it on Windows)
            gc.collect()
            if os.name != "nt":
                try:
                    ctypes.CDLL(ctypes.util.find_library("c")).malloc_trim(0)
                except Exception:
                    pass

            logger.info("Model unloaded successfully")
        except Exception as e:
            logger.error(f"Error unloading model: {e}")

    def transcribe_file(
        self,
        file_path: str,
        language: Optional[str] = None,
        task: str = "transcribe",
        progress_callback: Optional[Callable] = None,
    ) -> TranscriptionResult:
        """
        Transcribe a media file.

        Args:
            file_path: Path to media file
            language: Language code (ISO 639-1) or None for auto-detect
            task: 'transcribe' or 'translate'
            progress_callback: Optional callback for progress updates

        Returns:
            TranscriptionResult object

        Raises:
            Exception: If transcription fails
        """
        # Ensure model is loaded
        if not self.is_loaded:
            self.load_model()

        try:
            logger.info(f"Transcribing file: {file_path} (language={language}, task={task})")

            # Prepare transcription arguments
            args = {}
            if progress_callback:
                args["progress_callback"] = progress_callback

            # Add custom regroup if configured
            from backend.core.settings_service import settings_service
            custom_regroup = settings_service.get('custom_regroup', 'cm_sl=84_sl=42++++++1')
            if custom_regroup:
                args["regroup"] = custom_regroup

            # Perform transcription
            result = self.model.transcribe(
                file_path,
                language=language,
                task=task,
                **args,
            )

            segments_count = len(result.segments) if hasattr(result, "segments") else 0
            detected_language = result.language if hasattr(result, "language") else language or "unknown"
            logger.info(
                f"Transcription completed: {segments_count} segments, "
                f"language={detected_language}"
            )
            return TranscriptionResult(
                result=result,
                language=detected_language,
                segments_count=segments_count,
            )
        except Exception as e:
            logger.error(f"Transcription failed for {file_path}: {e}")
            raise
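
    # Usage sketch (path and language code are illustrative):
    #
    #     transcriber = WhisperTranscriber()
    #     result = transcriber.transcribe_file("/media/show.mkv", language="en")
    #     result.to_srt("/media/show.en.srt")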

    def transcribe_audio_data(
        self,
        audio_data: bytes,
        language: Optional[str] = None,
        task: str = "transcribe",
        sample_rate: int = 16000,
        progress_callback: Optional[Callable] = None,
    ) -> TranscriptionResult:
        """
        Transcribe raw audio data (for Bazarr provider mode).

        Args:
            audio_data: Raw audio bytes
            language: Language code or None
            task: 'transcribe' or 'translate'
            sample_rate: Audio sample rate
            progress_callback: Optional progress callback

        Returns:
            TranscriptionResult object
        """
        if not self.is_loaded:
            self.load_model()

        try:
            logger.info(f"Transcribing audio data (size={len(audio_data)} bytes)")
            args = {
                "audio": audio_data,
                "input_sr": sample_rate,
            }
            if progress_callback:
                args["progress_callback"] = progress_callback

            from backend.core.settings_service import settings_service
            custom_regroup = settings_service.get('custom_regroup', 'cm_sl=84_sl=42++++++1')
            if custom_regroup:
                args["regroup"] = custom_regroup

            result = self.model.transcribe(task=task, language=language, **args)

            segments_count = len(result.segments) if hasattr(result, "segments") else 0
            detected_language = result.language if hasattr(result, "language") else language or "unknown"
            logger.info(f"Audio transcription completed: {segments_count} segments")
            return TranscriptionResult(
                result=result,
                language=detected_language,
                segments_count=segments_count,
            )
        except Exception as e:
            logger.error(f"Audio transcription failed: {e}")
            raise
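
    # Provider-mode sketch: the caller supplies raw audio, e.g. 16 kHz mono PCM
    # extracted with ffmpeg (`extract_audio` is a hypothetical helper):
    #
    #     pcm = extract_audio(video_path)  # -> bytes
    #     result = transcriber.transcribe_audio_data(pcm, language="en", sample_rate=16000)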

    def detect_language(
        self,
        file_path: str,
        offset: int = 0,
        length: int = 30,
    ) -> str:
        """
        Detect language of a media file.

        Args:
            file_path: Path to media file
            offset: Start offset in seconds
            length: Duration to analyze in seconds

        Returns:
            Language code (ISO 639-1), or "unknown" if detection fails
        """
        if not self.is_loaded:
            self.load_model()

        try:
            logger.info(f"Detecting language for: {file_path} (offset={offset}s, length={length}s)")
            # Extract a short audio segment and transcribe it; the detected
            # language is reported on the result object
            from backend.transcription.audio_utils import extract_audio_segment
            audio_segment = extract_audio_segment(file_path, offset, length)
            result = self.model.transcribe(audio_segment.read())
            detected_language = result.language if hasattr(result, "language") else "unknown"
            logger.info(f"Detected language: {detected_language}")
            return detected_language
        except Exception as e:
            logger.error(f"Language detection failed for {file_path}: {e}")
            return "unknown"

    def __enter__(self):
        """Context manager entry."""
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        from backend.core.settings_service import settings_service
        if settings_service.get('clear_vram_on_complete', True) in (True, 'true', 'True', '1', 1):
            self.unload_model()

    def __del__(self):
        """Destructor - ensure model is unloaded."""
        try:
            if self.is_loaded:
                self.unload_model()
        except Exception:
            pass
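

# Minimal end-to-end sketch for manual smoke-testing. Assumptions: this runs
# inside the Transcriptarr backend (so backend.core.settings_service is
# importable), stable-ts/faster-whisper are installed, and the media path
# passed on the command line exists.
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)
    media_path = sys.argv[1] if len(sys.argv) > 1 else "sample.mkv"

    # __enter__ loads the model; __exit__ unloads it when the
    # clear_vram_on_complete setting is enabled
    with WhisperTranscriber(model_name="tiny", device="cpu") as transcriber:
        lang = transcriber.detect_language(media_path)
        result = transcriber.transcribe_file(
            media_path,
            language=lang if lang != "unknown" else None,
        )
        print(f"language={result.language}, segments={result.segments_count}")
        result.to_srt(media_path + ".srt")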