diff --git a/subgen.py b/subgen.py index ae021fb..38146be 100644 --- a/subgen.py +++ b/subgen.py @@ -24,6 +24,7 @@ import av import ffmpeg import whisper import re +import ast from watchdog.observers.polling import PollingObserver as Observer from watchdog.events import FileSystemEventHandler import faster_whisper @@ -102,6 +103,7 @@ def update_env_variables(): custom_regroup = os.getenv('CUSTOM_REGROUP', 'cm_sl=84_sl=42++++++1') detect_language_length = os.getenv('DETECT_LANGUAGE_LENGTH', 30) skipifexternalsub = convert_to_bool(os.getenv('SKIPIFEXTERNALSUB', False)) + kwargs = ast.literal_eval(os.getenv('SUBGEN_KWARGS', '')) set_env_variables('subgen.env') @@ -444,9 +446,9 @@ async def asr( if model_prompt: custom_prompt = greetings_translations.get(language, '') or custom_model_prompt if custom_regroup: - result = model.transcribe_stable(audio_data, task=task, input_sr=16000, language=language, progress_callback=progress, initial_prompt=custom_prompt, regroup=custom_regroup) + result = model.transcribe_stable(audio_data, task=task, input_sr=16000, language=language, progress_callback=progress, initial_prompt=custom_prompt, regroup=custom_regroup, **kwargs) else: - result = model.transcribe_stable(audio_data, task=task, input_sr=16000, language=language, progress_callback=progress, initial_prompt=custom_prompt) + result = model.transcribe_stable(audio_data, task=task, input_sr=16000, language=language, progress_callback=progress, initial_prompt=custom_prompt, **kwargs) appendLine(result) elapsed_time = time.time() - start_time minutes, seconds = divmod(int(elapsed_time), 60) @@ -485,7 +487,7 @@ async def detect_language( task_queue.put(task_id) audio_data = np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0 - detected_language = model.transcribe_stable(whisper.pad_or_trim(audio_data, int(detect_language_length) * 16000), input_sr=16000).language + detected_language = model.transcribe_stable(whisper.pad_or_trim(audio_data, int(detect_language_length) * 16000), input_sr=16000, **kwargs).language # reverse lookup of language -> code, ex: "english" -> "en", "nynorsk" -> "nn", ... language_code = get_key_by_value(whisper_languages, detected_language) @@ -546,10 +548,10 @@ def gen_subtitles(file_path: str, transcription_type: str, force_language=None) if custom_regroup: result = model.transcribe_stable(file_path, language=force_language, task=transcription_type, progress_callback=progress, initial_prompt=custom_model_prompt, - regroup=custom_regroup) + regroup=custom_regroup, **kwargs) else: result = model.transcribe_stable(file_path, language=force_language, task=transcription_type, - progress_callback=progress, initial_prompt=custom_model_prompt) + progress_callback=progress, initial_prompt=custom_model_prompt, **kwargs) appendLine(result) file_name, file_extension = os.path.splitext(file_path)