- Add WhisperTranscriber wrapper for stable-ts/faster-whisper
- Add audio utilities for ffmpeg/ffprobe operations
- Add translator for two-stage translation workflow
- Support CPU/GPU with graceful degradation
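A minimal usage sketch of the transcriber wrapper (the import path, file paths, and model/device choices below are placeholders; real defaults come from `settings_service`):

```python
from backend.transcription.transcriber import WhisperTranscriber  # assumed module path

# Entering the context manager loads the model; exiting unloads it
# when clear_vram_on_complete is enabled.
with WhisperTranscriber(model_name="medium", device="cpu") as transcriber:
    result = transcriber.transcribe_file("/media/show/episode.mkv")  # placeholder path
    result.to_srt("/media/show/episode.en.srt")
    print(result.language, result.segments_count)
```

Bazarr provider mode uses `transcribe_audio_data()` with raw audio bytes instead of a file path.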
"""Whisper transcription wrapper for worker processes."""
import logging
import os
import gc
import ctypes
import ctypes.util
from typing import Optional, Callable
from io import BytesIO

import numpy as np

# Optional imports - graceful degradation if not available
try:
    import stable_whisper
    import torch
    WHISPER_AVAILABLE = True
except ImportError:
    stable_whisper = None
    torch = None
    WHISPER_AVAILABLE = False
    logging.warning("stable_whisper or torch not available. Transcription will not work.")

logger = logging.getLogger(__name__)


class TranscriptionResult:
    """Result of a transcription operation."""

    def __init__(self, result, language: str, segments_count: int):
        """
        Initialize transcription result.

        Args:
            result: stable-ts result object
            language: Detected or forced language
            segments_count: Number of subtitle segments
        """
        self.result = result
        self.language = language
        self.segments_count = segments_count

    def to_srt(self, output_path: str, word_level: bool = False) -> str:
        """
        Save result as an SRT file.

        Args:
            output_path: Path to save SRT file
            word_level: Enable word-level timestamps

        Returns:
            Path to saved file
        """
        self.result.to_srt_vtt(output_path, word_level=word_level)
        return output_path

    def get_srt_content(self, word_level: bool = False) -> str:
        """
        Get SRT content as a string.

        Args:
            word_level: Enable word-level timestamps

        Returns:
            SRT content
        """
        return "".join(self.result.to_srt_vtt(filepath=None, word_level=word_level))


class WhisperTranscriber:
    """
    Whisper transcription engine wrapper.

    Manages Whisper model loading/unloading and transcription operations.
    Designed to run in worker processes with isolated model instances.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        device: Optional[str] = None,
        model_path: Optional[str] = None,
        compute_type: Optional[str] = None,
        threads: Optional[int] = None,
    ):
        """
        Initialize transcriber.

        Args:
            model_name: Whisper model name (tiny, base, small, medium, large, etc.)
            device: Device to use (cpu, cuda, gpu)
            model_path: Path to store/load models
            compute_type: Compute type (auto, int8, float16, etc.)
            threads: Number of CPU threads
        """
        # Import settings_service here to avoid circular imports
        from backend.core.settings_service import settings_service

        # Load from database settings with sensible defaults
        self.model_name = model_name or settings_service.get('whisper_model', 'medium')
        self.device = (device or settings_service.get('transcribe_device', 'cpu')).lower()
        if self.device == "gpu":
            self.device = "cuda"
        self.model_path = model_path or settings_service.get('model_path', './models')

        # Get compute_type from settings based on device type
        if compute_type:
            requested_compute_type = compute_type
        elif self.device == "cpu":
            requested_compute_type = settings_service.get('cpu_compute_type', 'auto')
        else:
            requested_compute_type = settings_service.get('gpu_compute_type', 'auto')

        # Auto-detect compatible compute_type based on device
        self.compute_type = self._get_compatible_compute_type(self.device, requested_compute_type)

        self.threads = threads or int(settings_service.get('whisper_threads', 4))

        self.model = None
        self.is_loaded = False

        if self.compute_type != requested_compute_type:
            logger.warning(
                f"Requested compute_type '{requested_compute_type}' is not compatible with device '{self.device}'. "
                f"Using '{self.compute_type}' instead."
            )

        logger.info(
            f"WhisperTranscriber initialized: model={self.model_name}, "
            f"device={self.device}, compute_type={self.compute_type}"
        )

    def _get_compatible_compute_type(self, device: str, requested: str) -> str:
        """
        Get a compatible compute type for the device.

        CPU: only supports int8 and float32
        GPU: supports float16, float32, int8, int8_float16

        Args:
            device: Device type (cpu, cuda)
            requested: Requested compute type

        Returns:
            Compatible compute type
        """
        if device == "cpu":
            # CPU only supports int8 and float32
            if requested == "auto":
                return "int8"  # int8 is faster on CPU
            elif requested in ("float16", "int8_float16"):
                logger.warning(f"CPU doesn't support {requested}, falling back to int8")
                return "int8"
            elif requested in ("int8", "float32"):
                return requested
            else:
                logger.warning(f"Unknown compute type {requested}, using int8")
                return "int8"
        else:
            # CUDA/GPU supports all types
            if requested == "auto":
                return "float16"  # float16 is recommended for GPU
            elif requested in ("float16", "float32", "int8", "int8_float16"):
                return requested
            else:
                logger.warning(f"Unknown compute type {requested}, using float16")
                return "float16"

    def load_model(self):
        """Load Whisper model into memory."""
        if not WHISPER_AVAILABLE:
            raise RuntimeError(
                "Whisper is not available. Install with: pip install stable-ts faster-whisper"
            )

        if self.is_loaded and self.model is not None:
            logger.debug("Model already loaded")
            return

        try:
            logger.info(f"Loading Whisper model: {self.model_name}")
            self.model = stable_whisper.load_faster_whisper(
                self.model_name,
                download_root=self.model_path,
                device=self.device,
                cpu_threads=self.threads,
                num_workers=1,  # Each worker has its own model
                # compute_type was already resolved for this device by
                # _get_compatible_compute_type() in __init__
                compute_type=self.compute_type,
            )
            self.is_loaded = True
            logger.info(f"Model {self.model_name} loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model {self.model_name}: {e}")
            raise

    def unload_model(self):
        """Unload model from memory and clear cache."""
        if not self.is_loaded or self.model is None:
            logger.debug("Model not loaded, nothing to unload")
            return

        try:
            logger.info("Unloading Whisper model")

            # Unload the model
            if hasattr(self.model, "model") and hasattr(self.model.model, "unload_model"):
                self.model.model.unload_model()

            del self.model
            self.model = None
            self.is_loaded = False

            # Clear CUDA cache if using GPU
            if self.device == "cuda" and torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("CUDA cache cleared")

            # Force garbage collection, then ask the C allocator to return
            # freed heap pages to the OS (malloc_trim is glibc-only, so skip
            # the trim on Windows)
            gc.collect()
            if os.name != "nt":
                try:
                    ctypes.CDLL(ctypes.util.find_library("c")).malloc_trim(0)
                except Exception:
                    pass

            logger.info("Model unloaded successfully")

        except Exception as e:
            logger.error(f"Error unloading model: {e}")

    def transcribe_file(
        self,
        file_path: str,
        language: Optional[str] = None,
        task: str = "transcribe",
        progress_callback: Optional[Callable] = None,
    ) -> TranscriptionResult:
        """
        Transcribe a media file.

        Args:
            file_path: Path to media file
            language: Language code (ISO 639-1) or None for auto-detect
            task: 'transcribe' or 'translate'
            progress_callback: Optional callback for progress updates

        Returns:
            TranscriptionResult object

        Raises:
            Exception: If transcription fails
        """
        # Ensure model is loaded
        if not self.is_loaded:
            self.load_model()

        try:
            logger.info(f"Transcribing file: {file_path} (language={language}, task={task})")

            # Prepare transcription arguments
            args = {}
            if progress_callback:
                args["progress_callback"] = progress_callback

            # Add custom regroup if configured
            from backend.core.settings_service import settings_service
            custom_regroup = settings_service.get('custom_regroup', 'cm_sl=84_sl=42++++++1')
            if custom_regroup:
                args["regroup"] = custom_regroup

            # Perform transcription
            result = self.model.transcribe(
                file_path,
                language=language,
                task=task,
                **args,
            )

            segments_count = len(result.segments) if hasattr(result, "segments") else 0
            detected_language = result.language if hasattr(result, "language") else language or "unknown"

            logger.info(
                f"Transcription completed: {segments_count} segments, "
                f"language={detected_language}"
            )

            return TranscriptionResult(
                result=result,
                language=detected_language,
                segments_count=segments_count,
            )

        except Exception as e:
            logger.error(f"Transcription failed for {file_path}: {e}")
            raise

    def transcribe_audio_data(
        self,
        audio_data: bytes,
        language: Optional[str] = None,
        task: str = "transcribe",
        sample_rate: int = 16000,
        progress_callback: Optional[Callable] = None,
    ) -> TranscriptionResult:
        """
        Transcribe raw audio data (for Bazarr provider mode).

        Args:
            audio_data: Raw audio bytes
            language: Language code or None
            task: 'transcribe' or 'translate'
            sample_rate: Audio sample rate
            progress_callback: Optional progress callback

        Returns:
            TranscriptionResult object
        """
        if not self.is_loaded:
            self.load_model()

        try:
            logger.info(f"Transcribing audio data (size={len(audio_data)} bytes)")

            args = {
                "audio": audio_data,
                "input_sr": sample_rate,
            }

            if progress_callback:
                args["progress_callback"] = progress_callback

            from backend.core.settings_service import settings_service
            custom_regroup = settings_service.get('custom_regroup', 'cm_sl=84_sl=42++++++1')
            if custom_regroup:
                args["regroup"] = custom_regroup

            result = self.model.transcribe(task=task, language=language, **args)

            segments_count = len(result.segments) if hasattr(result, "segments") else 0
            detected_language = result.language if hasattr(result, "language") else language or "unknown"

            logger.info(f"Audio transcription completed: {segments_count} segments")

            return TranscriptionResult(
                result=result,
                language=detected_language,
                segments_count=segments_count,
            )

        except Exception as e:
            logger.error(f"Audio transcription failed: {e}")
            raise

    def detect_language(
        self,
        file_path: str,
        offset: int = 0,
        length: int = 30,
    ) -> str:
        """
        Detect the language of a media file.

        Args:
            file_path: Path to media file
            offset: Start offset in seconds
            length: Duration to analyze in seconds

        Returns:
            Language code (ISO 639-1)
        """
        if not self.is_loaded:
            self.load_model()

        try:
            logger.info(f"Detecting language for: {file_path} (offset={offset}s, length={length}s)")

            # Extract audio segment for analysis
            from backend.transcription.audio_utils import extract_audio_segment

            audio_segment = extract_audio_segment(file_path, offset, length)

            result = self.model.transcribe(audio_segment.read())
            detected_language = result.language if hasattr(result, "language") else "unknown"

            logger.info(f"Detected language: {detected_language}")
            return detected_language

        except Exception as e:
            logger.error(f"Language detection failed for {file_path}: {e}")
            return "unknown"

    def __enter__(self):
        """Context manager entry."""
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        from backend.core.settings_service import settings_service
        if settings_service.get('clear_vram_on_complete', True) in (True, 'true', 'True', '1', 1):
            self.unload_model()

    def __del__(self):
        """Destructor - ensure model is unloaded."""
        try:
            if self.is_loaded:
                self.unload_model()
        except Exception:
            pass
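

# Manual smoke test (sketch only, not used by the worker processes).
# Assumes stable-ts/torch are installed, the settings database is reachable,
# and the placeholder media path below is replaced with a real file.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    media_file = "/path/to/sample.mkv"  # placeholder path (assumption)

    with WhisperTranscriber() as transcriber:
        detected = transcriber.detect_language(media_file)
        print(f"Detected language: {detected}")

        result = transcriber.transcribe_file(media_file)
        print(f"Transcribed {result.segments_count} segments")
        print(result.get_srt_content()[:300])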