Added kwargs so you can change almost any option on transcribe from faster-whisper, stable-ts, or whisper

This commit is contained in:
McCloudS
2024-08-03 07:02:32 -06:00
committed by GitHub
parent 13462b3eed
commit 43a33b1e89

View File

@@ -24,6 +24,7 @@ import av
import ffmpeg import ffmpeg
import whisper import whisper
import re import re
import ast
from watchdog.observers.polling import PollingObserver as Observer from watchdog.observers.polling import PollingObserver as Observer
from watchdog.events import FileSystemEventHandler from watchdog.events import FileSystemEventHandler
import faster_whisper import faster_whisper
@@ -102,6 +103,7 @@ def update_env_variables():
custom_regroup = os.getenv('CUSTOM_REGROUP', 'cm_sl=84_sl=42++++++1') custom_regroup = os.getenv('CUSTOM_REGROUP', 'cm_sl=84_sl=42++++++1')
detect_language_length = os.getenv('DETECT_LANGUAGE_LENGTH', 30) detect_language_length = os.getenv('DETECT_LANGUAGE_LENGTH', 30)
skipifexternalsub = convert_to_bool(os.getenv('SKIPIFEXTERNALSUB', False)) skipifexternalsub = convert_to_bool(os.getenv('SKIPIFEXTERNALSUB', False))
kwargs = ast.literal_eval(os.getenv('SUBGEN_KWARGS', ''))
set_env_variables('subgen.env') set_env_variables('subgen.env')
@@ -444,9 +446,9 @@ async def asr(
if model_prompt: if model_prompt:
custom_prompt = greetings_translations.get(language, '') or custom_model_prompt custom_prompt = greetings_translations.get(language, '') or custom_model_prompt
if custom_regroup: if custom_regroup:
result = model.transcribe_stable(audio_data, task=task, input_sr=16000, language=language, progress_callback=progress, initial_prompt=custom_prompt, regroup=custom_regroup) result = model.transcribe_stable(audio_data, task=task, input_sr=16000, language=language, progress_callback=progress, initial_prompt=custom_prompt, regroup=custom_regroup, **kwargs)
else: else:
result = model.transcribe_stable(audio_data, task=task, input_sr=16000, language=language, progress_callback=progress, initial_prompt=custom_prompt) result = model.transcribe_stable(audio_data, task=task, input_sr=16000, language=language, progress_callback=progress, initial_prompt=custom_prompt, **kwargs)
appendLine(result) appendLine(result)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
minutes, seconds = divmod(int(elapsed_time), 60) minutes, seconds = divmod(int(elapsed_time), 60)
@@ -485,7 +487,7 @@ async def detect_language(
task_queue.put(task_id) task_queue.put(task_id)
audio_data = np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0 audio_data = np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0
detected_language = model.transcribe_stable(whisper.pad_or_trim(audio_data, int(detect_language_length) * 16000), input_sr=16000).language detected_language = model.transcribe_stable(whisper.pad_or_trim(audio_data, int(detect_language_length) * 16000), input_sr=16000, **kwargs).language
# reverse lookup of language -> code, ex: "english" -> "en", "nynorsk" -> "nn", ... # reverse lookup of language -> code, ex: "english" -> "en", "nynorsk" -> "nn", ...
language_code = get_key_by_value(whisper_languages, detected_language) language_code = get_key_by_value(whisper_languages, detected_language)
@@ -546,10 +548,10 @@ def gen_subtitles(file_path: str, transcription_type: str, force_language=None)
if custom_regroup: if custom_regroup:
result = model.transcribe_stable(file_path, language=force_language, task=transcription_type, result = model.transcribe_stable(file_path, language=force_language, task=transcription_type,
progress_callback=progress, initial_prompt=custom_model_prompt, progress_callback=progress, initial_prompt=custom_model_prompt,
regroup=custom_regroup) regroup=custom_regroup, **kwargs)
else: else:
result = model.transcribe_stable(file_path, language=force_language, task=transcription_type, result = model.transcribe_stable(file_path, language=force_language, task=transcription_type,
progress_callback=progress, initial_prompt=custom_model_prompt) progress_callback=progress, initial_prompt=custom_model_prompt, **kwargs)
appendLine(result) appendLine(result)
file_name, file_extension = os.path.splitext(file_path) file_name, file_extension = os.path.splitext(file_path)