From afc78c7e94742e1d660ff87d846d240115ebeab1 Mon Sep 17 00:00:00 2001 From: McCloudS <64094529+McCloudS@users.noreply.github.com> Date: Thu, 21 Mar 2024 11:07:03 -0600 Subject: [PATCH] Update subgen.py --- subgen.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/subgen.py b/subgen.py index 8072aee..1d7cfbd 100644 --- a/subgen.py +++ b/subgen.py @@ -1,4 +1,4 @@ -subgen_version = '2024.3.21.43' +subgen_version = '2024.3.21.44' from datetime import datetime import subprocess @@ -57,6 +57,7 @@ clear_vram_on_complete = convert_to_bool(os.getenv('CLEAR_VRAM_ON_COMPLETE', Tru compute_type = os.getenv('COMPUTE_TYPE', 'auto') append = convert_to_bool(os.getenv('APPEND', False)) reload_script_on_change = convert_to_bool(os.getenv('RELOAD_SCRIPT_ON_CHANGE', False)) +model_prompt = os.getenv('MODEL_PROMPT', 'Hello.') if transcribe_device == "gpu": transcribe_device = "cuda" @@ -332,7 +333,7 @@ def asr( start_model() files_to_transcribe.insert(0, f"Bazarr-asr-{random_name}") audio_data = np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0 - result = model.transcribe_stable(audio_data, task=task, input_sr=16000, language=language, progress_callback=progress) + result = model.transcribe_stable(audio_data, task=task, input_sr=16000, language=language, progress_callback=progress, initial_prompt=model_prompt) appendLine(result) elapsed_time = time.time() - start_time minutes, seconds = divmod(int(elapsed_time), 60) @@ -432,7 +433,7 @@ def gen_subtitles(file_path: str, transcribe_or_translate: str, front=True, forc if force_detected_language_to: forceLanguage = force_detected_language_to logging.info(f"Forcing language to {forceLanguage}") - result = model.transcribe_stable(file_path, language=forceLanguage, task=transcribe_or_translate, progress_callback=progress) + result = model.transcribe_stable(file_path, language=forceLanguage, task=transcribe_or_translate, progress_callback=progress, initial_prompt=model_prompt) appendLine(result) result.to_srt_vtt(get_file_name_without_extension(file_path) + subextension, word_level=word_level_highlight) elapsed_time = time.time() - start_time