feat(transcription): add Whisper transcriber and audio utilities
- Add WhisperTranscriber wrapper for stable-ts/faster-whisper - Add audio utilities for ffmpeg/ffprobe operations - Add translator for two-stage translation workflow - Support CPU/GPU with graceful degradation
This commit is contained in:
5
backend/transcription/__init__.py
Normal file
5
backend/transcription/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Whisper transcription module."""
|
||||
from backend.transcription.transcriber import WhisperTranscriber
|
||||
from backend.transcription.translator import SRTTranslator, translate_srt_file
|
||||
|
||||
__all__ = ["WhisperTranscriber", "SRTTranslator", "translate_srt_file"]
|
||||
354
backend/transcription/audio_utils.py
Normal file
354
backend/transcription/audio_utils.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""Audio processing utilities extracted from transcriptarr.py."""
|
||||
import logging
|
||||
import os
|
||||
from io import BytesIO
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
import ffmpeg
|
||||
|
||||
# Optional import - graceful degradation if not available
|
||||
try:
|
||||
import av
|
||||
AV_AVAILABLE = True
|
||||
except ImportError:
|
||||
av = None
|
||||
AV_AVAILABLE = False
|
||||
logging.warning("av (PyAV) not available. Some audio features may not work.")
|
||||
|
||||
from backend.core.language_code import LanguageCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_audio_segment(
    input_file: str,
    start_time: int,
    duration: int,
) -> BytesIO:
    """
    Extract a slice of a media file's audio into an in-memory WAV buffer.

    The slice is transcoded to 16 kHz signed 16-bit PCM WAV, the format
    expected by the Whisper pipeline.

    Args:
        input_file: Path to input media file
        start_time: Start time in seconds
        duration: Duration in seconds

    Returns:
        BytesIO object containing the extracted audio segment

    Raises:
        ValueError: If FFmpeg produced no output
        ffmpeg.Error: If FFmpeg itself failed
    """
    try:
        logger.debug(f"Extracting audio: {input_file}, start={start_time}s, duration={duration}s")

        # Build the pipeline step by step: seek/trim on input, then pipe
        # 16 kHz PCM WAV to stdout.
        stream = ffmpeg.input(input_file, ss=start_time, t=duration)
        stream = stream.output("pipe:1", format="wav", acodec="pcm_s16le", ar=16000)
        wav_bytes, _ = stream.run(capture_stdout=True, capture_stderr=True)

        if not wav_bytes:
            raise ValueError("FFmpeg output is empty")

        return BytesIO(wav_bytes)

    except ffmpeg.Error as e:
        logger.error(f"FFmpeg error: {e.stderr.decode()}")
        raise
    except Exception as e:
        logger.error(f"Error extracting audio: {e}")
        raise
|
||||
|
||||
|
||||
def get_audio_tracks(video_file: str) -> List[Dict]:
    """
    Get information about audio tracks in a media file.

    Probes the file with ffprobe (audio streams only) and normalizes each
    stream into a dict with index, codec, channel count, language, title,
    and disposition flags.

    Args:
        video_file: Path to media file

    Returns:
        List of dicts with audio track information; empty list on failure
    """
    try:
        probe = ffmpeg.probe(video_file, select_streams="a")
        audio_streams = probe.get("streams", [])

        audio_tracks = []
        for stream in audio_streams:
            # Get all possible language tags - check multiple locations
            tags = stream.get("tags", {})

            # Try different common tag names (MKV uses different conventions)
            lang_tag = (
                tags.get("language") or  # Standard location
                tags.get("LANGUAGE") or  # Uppercase variant
                tags.get("lang") or  # Short form
                stream.get("language") or  # Sometimes at stream level
                "und"  # Default: undefined
            )

            # Log ALL tags for debugging
            logger.debug(
                f"Audio track {stream.get('index')}: "
                f"codec={stream.get('codec_name')}, "
                f"lang_tag='{lang_tag}', "
                f"all_tags={tags}"
            )

            language = LanguageCode.from_iso_639_2(lang_tag)

            # Log when language is undefined
            if lang_tag == "und" or language is None:
                logger.warning(
                    f"Audio track {stream.get('index')} in {video_file}: "
                    f"Language undefined (tag='{lang_tag}'). "
                    f"Available tags: {list(tags.keys())}"
                )

            audio_track = {
                "index": int(stream.get("index", 0)),
                "codec": stream.get("codec_name", "unknown"),
                "channels": int(stream.get("channels", 0)),
                "language": language,
                "title": tags.get("title", ""),
                "default": stream.get("disposition", {}).get("default", 0) == 1,
                "forced": stream.get("disposition", {}).get("forced", 0) == 1,
                "original": stream.get("disposition", {}).get("original", 0) == 1,
                "commentary": "commentary" in tags.get("title", "").lower(),
            }
            audio_tracks.append(audio_track)

        return audio_tracks

    except ffmpeg.Error as e:
        # FIX: e.stderr is bytes when stderr is captured; the other handlers
        # in this module decode it before logging. Decode defensively so the
        # log shows readable text instead of a bytes repr.
        stderr = e.stderr.decode(errors="replace") if isinstance(e.stderr, bytes) else e.stderr
        logger.error(f"FFmpeg error: {stderr}")
        return []
    except Exception as e:
        logger.error(f"Error reading audio tracks: {e}")
        return []
|
||||
|
||||
|
||||
def extract_audio_track_to_memory(
    input_video_path: str, track_index: int
) -> Optional[BytesIO]:
    """
    Extract a specific audio track into an in-memory WAV buffer.

    The track is downmixed to mono and resampled to 16 kHz, the format the
    transcription pipeline expects.

    Args:
        input_video_path: Path to video file
        track_index: Audio track index (ffprobe stream index)

    Returns:
        BytesIO with audio data, or None when the index is missing or
        FFmpeg fails
    """
    # Guard clause: nothing to extract without a concrete track index.
    if track_index is None:
        logger.warning(f"Skipping audio track extraction for {input_video_path}")
        return None

    try:
        pipeline = ffmpeg.input(input_video_path).output(
            "pipe:",
            map=f"0:{track_index}",
            format="wav",
            ac=1,
            ar=16000,
            loglevel="quiet",
        )
        raw_audio, _ = pipeline.run(capture_stdout=True, capture_stderr=True)
        return BytesIO(raw_audio)

    except ffmpeg.Error as e:
        logger.error(f"FFmpeg error extracting track: {e.stderr.decode()}")
        return None
|
||||
|
||||
|
||||
def get_audio_languages(video_path: str) -> List[LanguageCode]:
    """
    Extract language codes from the file's audio streams.

    Args:
        video_path: Path to video file

    Returns:
        List of LanguageCode objects, one per audio track (in track order)
    """
    languages = []
    for track in get_audio_tracks(video_path):
        languages.append(track["language"])
    return languages
|
||||
|
||||
|
||||
def get_subtitle_languages(video_path: str) -> List[LanguageCode]:
    """
    Extract language codes from embedded subtitle streams.

    Args:
        video_path: Path to video file

    Returns:
        List of LanguageCode objects; LanguageCode.NONE for untagged streams,
        empty list when PyAV is unavailable or the file cannot be read
    """
    languages = []

    # FIX: the other PyAV-based helpers in this module guard on AV_AVAILABLE
    # before touching `av`. Without this guard a missing PyAV install
    # surfaced here as an opaque AttributeError swallowed by the broad
    # except below instead of an explicit warning.
    if not AV_AVAILABLE or av is None:
        logger.warning(f"av (PyAV) not available, cannot read subtitle languages for {video_path}")
        return languages

    try:
        with av.open(video_path) as container:
            for stream in container.streams.subtitles:
                lang_code = stream.metadata.get("language")
                if lang_code:
                    languages.append(LanguageCode.from_iso_639_2(lang_code))
                else:
                    # Stream carries no language tag at all.
                    languages.append(LanguageCode.NONE)
    except Exception as e:
        logger.error(f"Error reading subtitle languages: {e}")

    return languages
|
||||
|
||||
|
||||
def has_audio(file_path: str) -> bool:
    """
    Check if a file has at least one valid (decodable) audio stream.

    Args:
        file_path: Path to media file

    Returns:
        True if file has audio, False otherwise. When PyAV is missing the
        check is skipped and True is assumed.
    """
    if not AV_AVAILABLE or av is None:
        logger.warning(f"av (PyAV) not available, cannot check audio for {file_path}")
        # Assume file has audio if we can't check
        return True

    try:
        # A missing path is definitively "no audio".
        if not os.path.isfile(file_path):
            return False

        with av.open(file_path) as container:
            # Valid audio = an audio stream whose codec is recognized.
            return any(
                stream.type == "audio"
                and stream.codec_context
                and stream.codec_context.name != "none"
                for stream in container.streams
            )

    except Exception as e:
        # Catch all exceptions since av.FFmpegError might not exist if av is None
        logger.debug(f"Error checking audio in {file_path}: {e}")
        return False
|
||||
|
||||
|
||||
def has_subtitle_language_in_file(
    video_file: str, target_language: LanguageCode
) -> bool:
    """
    Check if the video has embedded subtitles in the target language.

    Args:
        video_file: Path to video file
        target_language: Language to check for

    Returns:
        True if an embedded subtitle stream matches the target language,
        False otherwise (including when PyAV is unavailable)
    """
    if not AV_AVAILABLE or av is None:
        logger.warning(f"av (PyAV) not available, cannot check subtitles for {video_file}")
        return False

    try:
        with av.open(video_file) as container:
            for stream in container.streams:
                # Only subtitle streams that actually carry a language tag
                # are comparable.
                if stream.type != "subtitle" or "language" not in stream.metadata:
                    continue

                stream_language = LanguageCode.from_string(
                    stream.metadata.get("language", "").lower()
                )
                if stream_language == target_language:
                    logger.debug(f"Found subtitles in '{target_language}' in video")
                    return True

        return False

    except Exception as e:
        logger.error(f"Error checking subtitles: {e}")
        return False
|
||||
|
||||
|
||||
def has_subtitle_of_language_in_folder(
    video_file: str, target_language: LanguageCode
) -> bool:
    """
    Check if an external subtitle file exists next to the video.

    Matches sidecar files that share the video's base name and embed a
    language code in their suffix, e.g. "Movie.en.srt".

    Args:
        video_file: Path to video file
        target_language: Language to check for

    Returns:
        True if a matching external subtitle exists, False otherwise
    """
    subtitle_extensions = {".srt", ".vtt", ".sub", ".ass", ".ssa"}

    video_folder = os.path.dirname(video_file)
    video_name = os.path.splitext(os.path.basename(video_file))[0]

    try:
        for file_name in os.listdir(video_folder):
            # Only consider known subtitle extensions.
            if not file_name.endswith(tuple(subtitle_extensions)):
                continue

            subtitle_name = os.path.splitext(file_name)[0]

            # Sidecar files must share the video's base name.
            if not subtitle_name.startswith(video_name):
                continue

            # Language codes live between the video name and the extension.
            suffix_parts = subtitle_name[len(video_name):].lstrip(".").split(".")

            if any(LanguageCode.from_string(part) == target_language for part in suffix_parts):
                logger.debug(f"Found external subtitle: {file_name}")
                return True

        return False

    except Exception as e:
        logger.error(f"Error checking external subtitles: {e}")
        return False
|
||||
|
||||
|
||||
def handle_multiple_audio_tracks(
    file_path: str, language: Optional[LanguageCode] = None
) -> Optional[BytesIO]:
    """
    Pick and extract one audio track from a multi-track file.

    Prefers a track matching *language* when given; otherwise falls back to
    the first track.

    Args:
        file_path: Path to media file
        language: Preferred language

    Returns:
        BytesIO with extracted audio, or None when the file has at most one
        audio track (callers then use the file as-is)
    """
    tracks = get_audio_tracks(file_path)

    # Single-track (or trackless) files need no selection logic.
    if len(tracks) <= 1:
        return None

    logger.debug(f"Handling {len(tracks)} audio tracks")

    # Prefer the requested language; fall back to the first track.
    chosen = None
    if language:
        chosen = next((t for t in tracks if t["language"] == language), None)
    if not chosen:
        chosen = tracks[0]

    return extract_audio_track_to_memory(file_path, chosen["index"])
|
||||
408
backend/transcription/transcriber.py
Normal file
408
backend/transcription/transcriber.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""Whisper transcription wrapper for worker processes."""
|
||||
import logging
|
||||
import os
|
||||
import gc
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
from typing import Optional, Callable
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Optional imports - graceful degradation if not available
|
||||
try:
|
||||
import stable_whisper
|
||||
import torch
|
||||
WHISPER_AVAILABLE = True
|
||||
except ImportError:
|
||||
stable_whisper = None
|
||||
torch = None
|
||||
WHISPER_AVAILABLE = False
|
||||
logging.warning("stable_whisper or torch not available. Transcription will not work.")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranscriptionResult:
    """Outcome of a single transcription run.

    Wraps the stable-ts result object together with the detected (or
    forced) language and the number of subtitle segments produced.
    """

    def __init__(self, result, language: str, segments_count: int):
        """
        Args:
            result: stable-ts result object
            language: Detected or forced language
            segments_count: Number of subtitle segments
        """
        self.result = result
        self.language = language
        self.segments_count = segments_count

    def to_srt(self, output_path: str, word_level: bool = False) -> str:
        """Write the result to *output_path* as SRT and return that path.

        Args:
            output_path: Path to save SRT file
            word_level: Enable word-level timestamps

        Returns:
            Path to saved file
        """
        self.result.to_srt_vtt(output_path, word_level=word_level)
        return output_path

    def get_srt_content(self, word_level: bool = False) -> str:
        """Return the SRT content as a single string (no file written).

        Args:
            word_level: Enable word-level timestamps

        Returns:
            SRT content
        """
        fragments = self.result.to_srt_vtt(filepath=None, word_level=word_level)
        return "".join(fragments)
|
||||
|
||||
|
||||
class WhisperTranscriber:
    """
    Whisper transcription engine wrapper.

    Manages Whisper model loading/unloading and transcription operations.
    Designed to run in worker processes with isolated model instances.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        device: Optional[str] = None,
        model_path: Optional[str] = None,
        compute_type: Optional[str] = None,
        threads: Optional[int] = None,
    ):
        """
        Initialize transcriber.

        Args:
            model_name: Whisper model name (tiny, base, small, medium, large, etc.)
            device: Device to use (cpu, cuda, gpu)
            model_path: Path to store/load models
            compute_type: Compute type (auto, int8, float16, etc.)
            threads: Number of CPU threads
        """
        # Import settings_service here to avoid circular imports
        from backend.core.settings_service import settings_service

        # Load from database settings with sensible defaults
        self.model_name = model_name or settings_service.get('whisper_model', 'medium')
        self.device = (device or settings_service.get('transcribe_device', 'cpu')).lower()
        if self.device == "gpu":
            # Normalize: faster-whisper/CTranslate2 expects "cuda", not "gpu"
            self.device = "cuda"
        self.model_path = model_path or settings_service.get('model_path', './models')

        # Get compute_type from settings based on device type
        if compute_type:
            requested_compute_type = compute_type
        elif self.device == "cpu":
            requested_compute_type = settings_service.get('cpu_compute_type', 'auto')
        else:
            requested_compute_type = settings_service.get('gpu_compute_type', 'auto')

        # Auto-detect compatible compute_type based on device
        self.compute_type = self._get_compatible_compute_type(self.device, requested_compute_type)

        self.threads = threads or int(settings_service.get('whisper_threads', 4))

        self.model = None
        self.is_loaded = False

        if self.compute_type != requested_compute_type:
            logger.warning(
                f"Requested compute_type '{requested_compute_type}' is not compatible with device '{self.device}'. "
                f"Using '{self.compute_type}' instead."
            )

        logger.info(
            f"WhisperTranscriber initialized: model={self.model_name}, "
            f"device={self.device}, compute_type={self.compute_type}"
        )

    def _get_compatible_compute_type(self, device: str, requested: str) -> str:
        """
        Get compatible compute type for the device.

        CPU: Only supports int8 and float32
        GPU: Supports float16, float32, int8, int8_float16

        Args:
            device: Device type (cpu, cuda)
            requested: Requested compute type

        Returns:
            Compatible compute type
        """
        if device == "cpu":
            # CPU only supports int8 and float32
            if requested == "auto":
                return "int8"  # int8 is faster on CPU
            elif requested in ("float16", "int8_float16"):
                logger.warning(f"CPU doesn't support {requested}, falling back to int8")
                return "int8"
            elif requested in ("int8", "float32"):
                return requested
            else:
                logger.warning(f"Unknown compute type {requested}, using int8")
                return "int8"
        else:
            # CUDA/GPU supports all types
            if requested == "auto":
                return "float16"  # float16 is recommended for GPU
            elif requested in ("float16", "float32", "int8", "int8_float16"):
                return requested
            else:
                logger.warning(f"Unknown compute type {requested}, using float16")
                return "float16"

    def load_model(self):
        """Load Whisper model into memory.

        Raises:
            RuntimeError: If stable-ts/faster-whisper are not installed.
        """
        if not WHISPER_AVAILABLE:
            raise RuntimeError(
                "Whisper is not available. Install with: pip install stable-ts faster-whisper"
            )

        if self.is_loaded and self.model is not None:
            logger.debug("Model already loaded")
            return

        try:
            logger.info(f"Loading Whisper model: {self.model_name}")
            self.model = stable_whisper.load_faster_whisper(
                self.model_name,
                download_root=self.model_path,
                device=self.device,
                cpu_threads=self.threads,
                num_workers=1,  # Each worker has own model
                # FIX: previously this forced "float32" on CPU, silently
                # discarding self.compute_type (and with it the user's
                # cpu_compute_type setting). _get_compatible_compute_type()
                # already guarantees device compatibility, so pass it through
                # unconditionally.
                compute_type=self.compute_type,
            )
            self.is_loaded = True
            logger.info(f"Model {self.model_name} loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model {self.model_name}: {e}")
            raise

    def unload_model(self):
        """Unload model from memory and clear cache."""
        if not self.is_loaded or self.model is None:
            logger.debug("Model not loaded, nothing to unload")
            return

        try:
            logger.info("Unloading Whisper model")

            # Unload the model
            if hasattr(self.model, "model") and hasattr(self.model.model, "unload_model"):
                self.model.model.unload_model()

            del self.model
            self.model = None
            self.is_loaded = False

            # Clear CUDA cache if using GPU.
            # FIX: guard on torch being importable - torch is None when the
            # optional import failed, and torch.cuda would raise here.
            if self.device == "cuda" and torch is not None and torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("CUDA cache cleared")

            # Garbage collection; malloc_trim releases freed heap pages back
            # to the OS (glibc only, hence the non-Windows guard).
            if os.name != "nt":  # Don't run on Windows
                gc.collect()
                try:
                    ctypes.CDLL(ctypes.util.find_library("c")).malloc_trim(0)
                except Exception:
                    pass

            logger.info("Model unloaded successfully")

        except Exception as e:
            logger.error(f"Error unloading model: {e}")

    def transcribe_file(
        self,
        file_path: str,
        language: Optional[str] = None,
        task: str = "transcribe",
        progress_callback: Optional[Callable] = None,
    ) -> "TranscriptionResult":
        """
        Transcribe a media file.

        Args:
            file_path: Path to media file
            language: Language code (ISO 639-1) or None for auto-detect
            task: 'transcribe' or 'translate'
            progress_callback: Optional callback for progress updates

        Returns:
            TranscriptionResult object

        Raises:
            Exception: If transcription fails
        """
        # Ensure model is loaded
        if not self.is_loaded:
            self.load_model()

        try:
            logger.info(f"Transcribing file: {file_path} (language={language}, task={task})")

            # Prepare transcription arguments
            args = {}
            if progress_callback:
                args["progress_callback"] = progress_callback

            # Add custom regroup if configured
            from backend.core.settings_service import settings_service
            custom_regroup = settings_service.get('custom_regroup', 'cm_sl=84_sl=42++++++1')
            if custom_regroup:
                args["regroup"] = custom_regroup

            # Perform transcription
            result = self.model.transcribe(
                file_path,
                language=language,
                task=task,
                **args,
            )

            segments_count = len(result.segments) if hasattr(result, "segments") else 0
            detected_language = result.language if hasattr(result, "language") else language or "unknown"

            logger.info(
                f"Transcription completed: {segments_count} segments, "
                f"language={detected_language}"
            )

            return TranscriptionResult(
                result=result,
                language=detected_language,
                segments_count=segments_count,
            )

        except Exception as e:
            logger.error(f"Transcription failed for {file_path}: {e}")
            raise

    def transcribe_audio_data(
        self,
        audio_data: bytes,
        language: Optional[str] = None,
        task: str = "transcribe",
        sample_rate: int = 16000,
        progress_callback: Optional[Callable] = None,
    ) -> "TranscriptionResult":
        """
        Transcribe raw audio data (for Bazarr provider mode).

        Args:
            audio_data: Raw audio bytes
            language: Language code or None
            task: 'transcribe' or 'translate'
            sample_rate: Audio sample rate
            progress_callback: Optional progress callback

        Returns:
            TranscriptionResult object
        """
        if not self.is_loaded:
            self.load_model()

        try:
            logger.info(f"Transcribing audio data (size={len(audio_data)} bytes)")

            args = {
                "audio": audio_data,
                "input_sr": sample_rate,
            }

            if progress_callback:
                args["progress_callback"] = progress_callback

            from backend.core.settings_service import settings_service
            custom_regroup = settings_service.get('custom_regroup', 'cm_sl=84_sl=42++++++1')
            if custom_regroup:
                args["regroup"] = custom_regroup

            result = self.model.transcribe(task=task, language=language, **args)

            segments_count = len(result.segments) if hasattr(result, "segments") else 0
            detected_language = result.language if hasattr(result, "language") else language or "unknown"

            logger.info(f"Audio transcription completed: {segments_count} segments")

            return TranscriptionResult(
                result=result,
                language=detected_language,
                segments_count=segments_count,
            )

        except Exception as e:
            logger.error(f"Audio transcription failed: {e}")
            raise

    def detect_language(
        self,
        file_path: str,
        offset: int = 0,
        length: int = 30,
    ) -> str:
        """
        Detect language of a media file.

        Args:
            file_path: Path to media file
            offset: Start offset in seconds
            length: Duration to analyze in seconds

        Returns:
            Language code (ISO 639-1), or "unknown" on failure
        """
        if not self.is_loaded:
            self.load_model()

        try:
            logger.info(f"Detecting language for: {file_path} (offset={offset}s, length={length}s)")

            # Extract audio segment for analysis
            from backend.transcription.audio_utils import extract_audio_segment

            audio_segment = extract_audio_segment(file_path, offset, length)

            # NOTE(review): this passes WAV container bytes with no input_sr,
            # unlike transcribe_audio_data - confirm stable-ts accepts raw
            # container bytes here.
            result = self.model.transcribe(audio_segment.read())
            detected_language = result.language if hasattr(result, "language") else "unknown"

            logger.info(f"Detected language: {detected_language}")
            return detected_language

        except Exception as e:
            logger.error(f"Language detection failed for {file_path}: {e}")
            return "unknown"

    def __enter__(self):
        """Context manager entry."""
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - optionally frees the model per settings."""
        from backend.core.settings_service import settings_service
        if settings_service.get('clear_vram_on_complete', True) in (True, 'true', 'True', '1', 1):
            self.unload_model()

    def __del__(self):
        """Destructor - ensure model is unloaded."""
        try:
            if self.is_loaded:
                self.unload_model()
        except Exception:
            pass
|
||||
198
backend/transcription/translator.py
Normal file
198
backend/transcription/translator.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""SRT translation service using Google Translate or DeepL."""
|
||||
import logging
|
||||
from typing import Optional
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Check for translation library availability
|
||||
try:
|
||||
from deep_translator import GoogleTranslator
|
||||
TRANSLATOR_AVAILABLE = True
|
||||
except ImportError:
|
||||
GoogleTranslator = None
|
||||
TRANSLATOR_AVAILABLE = False
|
||||
|
||||
|
||||
class SRTTranslator:
    """
    Translate SRT subtitle files from English into a target language.

    Backed by the deep-translator library with Google Translate as the
    engine; raises at construction time when the library is missing so
    callers can degrade gracefully.
    """

    def __init__(self, target_language: str):
        """
        Initialize translator.

        Args:
            target_language: ISO 639-1 code (e.g., 'es', 'fr', 'ja')

        Raises:
            RuntimeError: If deep-translator is not installed.
        """
        if not TRANSLATOR_AVAILABLE:
            raise RuntimeError(
                "Translation library not available. Install with: pip install deep-translator"
            )

        # Google Translate accepts ISO 639-1 codes directly
        self.target_language = target_language
        logger.info(f"Initializing translator for language: {target_language}")

        # Created lazily on first use.
        self.translator = None

    def _get_translator(self):
        """Create the GoogleTranslator on first access and cache it."""
        if self.translator is None:
            self.translator = GoogleTranslator(source='en', target=self.target_language)
        return self.translator

    def translate_srt_content(self, srt_content: str) -> str:
        """
        Translate SRT content from English to the target language.

        Indexes and timestamps are preserved; only the subtitle text is sent
        to the backend. A block that fails to translate keeps its original
        English text.

        Args:
            srt_content: SRT formatted string in English

        Returns:
            SRT formatted string in target language

        Raises:
            Exception: If translation fails
        """
        if not srt_content or not srt_content.strip():
            logger.warning("Empty SRT content, nothing to translate")
            return srt_content

        try:
            logger.info(f"Translating SRT content to {self.target_language}")

            parsed = self._parse_srt(srt_content)
            if not parsed:
                logger.warning("No subtitle blocks found in SRT")
                return srt_content

            engine = self._get_translator()
            out_blocks = []
            for entry in parsed:
                try:
                    # Translate only the text; keep index and timestamp.
                    out_blocks.append({
                        'index': entry['index'],
                        'timestamp': entry['timestamp'],
                        'text': engine.translate(entry['text']),
                    })
                except Exception as e:
                    logger.error(f"Failed to translate block {entry['index']}: {e}")
                    # Keep original text on error
                    out_blocks.append(entry)

            rebuilt = self._reconstruct_srt(out_blocks)

            logger.info(f"Successfully translated {len(out_blocks)} subtitle blocks")
            return rebuilt

        except Exception as e:
            logger.error(f"Translation failed: {e}")
            raise

    def _parse_srt(self, srt_content: str) -> list:
        """
        Parse SRT content into structured blocks.

        Args:
            srt_content: Raw SRT string

        Returns:
            List of dicts with 'index', 'timestamp', 'text'
        """
        parsed = []

        # Subtitle blocks are separated by a blank line.
        for raw_block in re.split(r'\n\s*\n', srt_content.strip()):
            lines = raw_block.strip().split('\n')

            # Need at least index + timestamp + one text line.
            if len(lines) < 3:
                continue

            try:
                parsed.append({
                    'index': lines[0].strip(),
                    'timestamp': lines[1].strip(),
                    'text': '\n'.join(lines[2:]),
                })
            except Exception as e:
                logger.warning(f"Failed to parse SRT block: {e}")
                continue

        return parsed

    def _reconstruct_srt(self, blocks: list) -> str:
        """
        Reconstruct SRT content from structured blocks.

        Args:
            blocks: List of dicts with 'index', 'timestamp', 'text'

        Returns:
            SRT formatted string (each block ends with a newline)
        """
        return '\n'.join(
            f"{block['index']}\n{block['timestamp']}\n{block['text']}\n"
            for block in blocks
        )
|
||||
|
||||
|
||||
def translate_srt_file(
    input_path: str,
    output_path: str,
    target_language: str
) -> bool:
    """
    Translate an English SRT file on disk into the target language.

    Args:
        input_path: Path to input SRT file (English)
        output_path: Path to output SRT file (target language)
        target_language: ISO 639-1 code

    Returns:
        True if successful, False otherwise
    """
    try:
        # Read input SRT
        with open(input_path, 'r', encoding='utf-8') as src:
            original_content = src.read()

        # Translate
        translated = SRTTranslator(target_language=target_language).translate_srt_content(original_content)

        # Write output SRT
        with open(output_path, 'w', encoding='utf-8') as dst:
            dst.write(translated)

        logger.info(f"Translated SRT saved to {output_path}")
        return True

    except Exception as e:
        logger.error(f"Failed to translate SRT file: {e}")
        return False
|
||||
Reference in New Issue
Block a user