Update subgen.py

This commit is contained in:
McCloudS
2024-12-04 11:10:34 -07:00
committed by GitHub
parent 4f890efaec
commit 5b14057ddd

View File

@@ -12,7 +12,7 @@ import queue
import logging import logging
import gc import gc
import random import random
from typing import Union, Any from typing import Union, Any, Optional
from fastapi import FastAPI, File, UploadFile, Query, Header, Body, Form, Request from fastapi import FastAPI, File, UploadFile, Query, Header, Body, Form, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import numpy as np import numpy as np
@@ -26,9 +26,9 @@ import ast
from watchdog.observers.polling import PollingObserver as Observer from watchdog.observers.polling import PollingObserver as Observer
from watchdog.events import FileSystemEventHandler from watchdog.events import FileSystemEventHandler
import faster_whisper import faster_whisper
from io import BytesIO
import io import io
def get_key_by_value(d, value): def get_key_by_value(d, value):
reverse_dict = {v: k for k, v in d.items()} reverse_dict = {v: k for k, v in d.items()}
return reverse_dict.get(value) return reverse_dict.get(value)
@@ -65,7 +65,7 @@ reload_script_on_change = convert_to_bool(os.getenv('RELOAD_SCRIPT_ON_CHANGE', F
lrc_for_audio_files = convert_to_bool(os.getenv('LRC_FOR_AUDIO_FILES', True)) lrc_for_audio_files = convert_to_bool(os.getenv('LRC_FOR_AUDIO_FILES', True))
custom_regroup = os.getenv('CUSTOM_REGROUP', 'cm_sl=84_sl=42++++++1') custom_regroup = os.getenv('CUSTOM_REGROUP', 'cm_sl=84_sl=42++++++1')
detect_language_length = int(os.getenv('DETECT_LANGUAGE_LENGTH', 30)) detect_language_length = int(os.getenv('DETECT_LANGUAGE_LENGTH', 30))
detect_language_start_offset = int(os.getenv('DETECT_LANGUAGE_START_OFFSET', 0)) detect_language_offset = int(os.getenv('DETECT_LANGUAGE_START_OFFSET', 90))
skipifexternalsub = convert_to_bool(os.getenv('SKIPIFEXTERNALSUB', False)) skipifexternalsub = convert_to_bool(os.getenv('SKIPIFEXTERNALSUB', False))
skip_if_to_transcribe_sub_already_exist = convert_to_bool(os.getenv('SKIP_IF_TO_TRANSCRIBE_SUB_ALREADY_EXIST', True)) skip_if_to_transcribe_sub_already_exist = convert_to_bool(os.getenv('SKIP_IF_TO_TRANSCRIBE_SUB_ALREADY_EXIST', True))
skipifinternalsublang = LanguageCode.from_string(os.getenv('SKIPIFINTERNALSUBLANG', '')) skipifinternalsublang = LanguageCode.from_string(os.getenv('SKIPIFINTERNALSUBLANG', ''))
@@ -430,7 +430,7 @@ async def detect_language(
audio_file: UploadFile = File(...), audio_file: UploadFile = File(...),
encode: bool = Query(default=True, description="Encode audio first through ffmpeg"), # This is always false from Bazarr encode: bool = Query(default=True, description="Encode audio first through ffmpeg"), # This is always false from Bazarr
detect_lang_length: int = Query(default=detect_language_length, description="Detect language on X seconds of the file"), detect_lang_length: int = Query(default=detect_language_length, description="Detect language on X seconds of the file"),
detect_lang_offset: int = Query(default=detect_language_start_offset, description="Start Detect language X seconds into the file") detect_lang_offset: int = Query(default=detect_language_offset, description="Start Detect language X seconds into the file")
): ):
if force_detected_language_to: if force_detected_language_to:
@@ -453,10 +453,10 @@ async def detect_language(
logging.info(f"Detecting language on the first {detect_lang_length} seconds of the audio.") logging.info(f"Detecting language on the first {detect_lang_length} seconds of the audio.")
detect_language_length = detect_lang_length detect_language_length = detect_lang_length
if detect_lang_offset != detect_language_start_offset: if detect_lang_offset != detect_language_offset:
logging.info(f"Offsetting language detection by {detect_language_start_offset} seconds.") logging.info(f"Offsetting language detection by {detect_language_offset} seconds.")
detect_language_offset_length = detect_lang_offset detect_language_offset = detect_lang_offset
audio_file = extract_audio_segment_to_memory(audio_file, detect_language_start_offset, detect_language_length) #audio_file = extract_audio_segment_to_memory(audio_file, detect_language_offset, detect_language_length)
try: try:
start_model() start_model()
random_name = ''.join(random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)) random_name = ''.join(random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6))
@@ -470,9 +470,11 @@ async def detect_language(
args['progress_callback'] = progress args['progress_callback'] = progress
if encode: if encode:
args['audio'] = whisper.pad_or_trim(audio_file.file.read() , sample_rate * int(detect_language_length)) args['audio'] = extract_audio_segment_to_memory(audio_file, detect_language_offset, detect_language_length).read()
args['input_sr'] = 16000
else: else:
args['audio'] = whisper.pad_or_trim(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0, args['input_sr'] * int(detect_language_length)) #args['audio'] = whisper.pad_or_trim(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0, args['input_sr'] * int(detect_language_length))
args['audio'] = extract_audio_segment_to_memory(audio_file, detect_language_offset, detect_language_length).read()
args['input_sr'] = 16000 args['input_sr'] = 16000
args.update(kwargs) args.update(kwargs)
@@ -502,7 +504,7 @@ def detect_language_task(path):
try: try:
start_model() start_model()
audio_segment = extract_audio_segment_to_memory(path, detect_language_start_offset, int(detect_language_length)).read() audio_segment = extract_audio_segment_to_memory(path, detect_language_offset, int(detect_language_length)).read()
detected_language = LanguageCode.from_name(model.transcribe_stable(audio_segment).language) detected_language = LanguageCode.from_name(model.transcribe_stable(audio_segment).language)
@@ -529,24 +531,44 @@ def extract_audio_segment_to_memory(input_file, start_time, duration):
""" """
Extract a segment of audio from input_file, starting at start_time for duration seconds. Extract a segment of audio from input_file, starting at start_time for duration seconds.
:param input_file: Path to the input audio file :param input_file: UploadFile object or path to the input audio file
:param start_time: Start time in seconds (e.g., 60 for 1 minute) :param start_time: Start time in seconds (e.g., 60 for 1 minute)
:param duration: Duration in seconds (e.g., 30 for 30 seconds) :param duration: Duration in seconds (e.g., 30 for 30 seconds)
:return: BytesIO object containing the audio segment :return: BytesIO object containing the audio segment
""" """
try: try:
if hasattr(input_file, 'file') and hasattr(input_file.file, 'read'): # Handling UploadFile
input_file.file.seek(0) # Ensure the file pointer is at the beginning
input_stream = 'pipe:0'
input_kwargs = {'input': input_file.file.read()}
elif isinstance(input_file, str): # Handling local file path
input_stream = input_file
input_kwargs = {}
else:
raise ValueError("Invalid input: input_file must be a file path or an UploadFile object.")
logging.info(f"Extracting audio from: {input_stream}, start_time: {start_time}, duration: {duration}")
# Run FFmpeg to extract the desired segment # Run FFmpeg to extract the desired segment
out, _ = ( out, _ = (
ffmpeg ffmpeg
.input(input_file, ss=start_time, t=duration) # Start time and duration .input(input_stream, ss=start_time, t=duration) # Set start time and duration
.output('pipe:1', format='wav', acodec='pcm_s16le', ar=16000) # Output to pipe as WAV .output('pipe:1', format='wav', acodec='pcm_s16le', ar=16000) # Output to pipe as WAV
.run(capture_stdout=True, capture_stderr=True) .run(capture_stdout=True, capture_stderr=True, **input_kwargs)
) )
return io.BytesIO(out) # Convert output to BytesIO for in-memory processing
except ffmpeg.Error as e:
print("Error occurred:", e.stderr.decode())
return None
# Check if the output is empty or null
if not out:
raise ValueError("FFmpeg output is empty, possibly due to invalid input.")
return io.BytesIO(out) # Convert output to BytesIO for in-memory processing
except ffmpeg.Error as e:
logging.error(f"FFmpeg error: {e.stderr.decode()}")
return None
except Exception as e:
logging.error(f"Error: {str(e)}")
return None
def start_model(): def start_model():
global model global model
@@ -665,7 +687,7 @@ def name_subtitle(file_path: str, language: LanguageCode) -> str:
""" """
return f"{os.path.splitext(file_path)[0]}.subgen.{whisper_model.split('.')[0]}.{define_subtitle_language_naming(language, subtitle_language_naming_type)}.srt" return f"{os.path.splitext(file_path)[0]}.subgen.{whisper_model.split('.')[0]}.{define_subtitle_language_naming(language, subtitle_language_naming_type)}.srt"
def handle_multiple_audio_tracks(file_path: str, language: LanguageCode | None = None) -> io.BytesIO | None: def handle_multiple_audio_tracks(file_path: str, language: LanguageCode | None = None) -> BytesIO | None:
""" """
Handles the possibility of a media file having multiple audio tracks. Handles the possibility of a media file having multiple audio tracks.
@@ -699,7 +721,7 @@ def handle_multiple_audio_tracks(file_path: str, language: LanguageCode | None =
return None return None
return audio_bytes return audio_bytes
def extract_audio_track_to_memory(input_video_path, track_index) -> io.BytesIO | None: def extract_audio_track_to_memory(input_video_path, track_index) -> BytesIO | None:
""" """
Extract a specific audio track from a video file to memory using FFmpeg. Extract a specific audio track from a video file to memory using FFmpeg.
@@ -729,7 +751,7 @@ def extract_audio_track_to_memory(input_video_path, track_index) -> io.BytesIO |
.run(capture_stdout=True, capture_stderr=True) # Capture output in memory .run(capture_stdout=True, capture_stderr=True) # Capture output in memory
) )
# Return the audio data as a BytesIO object # Return the audio data as a BytesIO object
return io.BytesIO(out) return BytesIO(out)
except ffmpeg.Error as e: except ffmpeg.Error as e:
print("An error occurred:", e.stderr.decode()) print("An error occurred:", e.stderr.decode())