Update subgen.py

Added some suggestions from JaiZed (https://github.com/JaiZed/subgen).
Added a check for an existing SDH subtitle, added a /batch endpoint, and a more elegant forceLanguage implementation.
Commit f7b1bef2cc (parent e2dd80e3dd)
McCloudS, committed via GitHub, 2024-02-10 13:10:21 -07:00


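Note: a quick sketch of how the new /batch endpoint can be exercised once the server is running. The port is just the WEBHOOKPORT default (8090) and localhost is assumed; the query parameter names come from the endpoint added below.

    import requests

    # Queue a folder (or a single video file) for subtitle generation,
    # optionally forcing the detected audio language for this call only.
    requests.post("http://localhost:8090/batch",
                  params={"directory": "/tv", "forceLanguage": "en"})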
@@ -28,18 +28,18 @@ def convert_to_bool(in_bool):
     return value not in ('false', 'off', '0')
 
 # Replace your getenv calls with appropriate default values here
-plextoken = os.getenv('PLEXTOKEN', "token here")
-plexserver = os.getenv('PLEXSERVER', "http://192.168.1.111:32400")
-jellyfintoken = os.getenv('JELLYFINTOKEN', "token here")
-jellyfinserver = os.getenv('JELLYFINSERVER', "http://192.168.1.111:8096")
-whisper_model = os.getenv('WHISPER_MODEL', "medium")
+plextoken = os.getenv('PLEXTOKEN', 'token here')
+plexserver = os.getenv('PLEXSERVER', 'http://192.168.1.111:32400')
+jellyfintoken = os.getenv('JELLYFINTOKEN', 'token here')
+jellyfinserver = os.getenv('JELLYFINSERVER', 'http://192.168.1.111:8096')
+whisper_model = os.getenv('WHISPER_MODEL', 'medium')
 whisper_threads = int(os.getenv('WHISPER_THREADS', 4))
-concurrent_transcriptions = int(os.getenv('CONCURRENT_TRANSCRIPTIONS', '2'))
-transcribe_device = os.getenv('TRANSCRIBE_DEVICE', "cpu")
+concurrent_transcriptions = int(os.getenv('CONCURRENT_TRANSCRIPTIONS', 2))
+transcribe_device = os.getenv('TRANSCRIBE_DEVICE', 'cpu')
 procaddedmedia = convert_to_bool(os.getenv('PROCADDEDMEDIA', True))
 procmediaonplay = convert_to_bool(os.getenv('PROCMEDIAONPLAY', True))
-namesublang = os.getenv('NAMESUBLANG', "aa")
-skipifinternalsublang = os.getenv('SKIPIFINTERNALSUBLANG', "eng")
+namesublang = os.getenv('NAMESUBLANG', 'aa')
+skipifinternalsublang = os.getenv('SKIPIFINTERNALSUBLANG', 'eng')
 webhookport = int(os.getenv('WEBHOOKPORT', 8090))
 word_level_highlight = convert_to_bool(os.getenv('WORD_LEVEL_HIGHLIGHT', False))
 debug = convert_to_bool(os.getenv('DEBUG', False))
@@ -49,7 +49,10 @@ path_mapping_to = os.getenv('PATH_MAPPING_TO', '/Volumes/TV')
 model_location = os.getenv('MODEL_PATH', '.')
 transcribe_folders = os.getenv('TRANSCRIBE_FOLDERS', '')
 transcribe_or_translate = os.getenv('TRANSCRIBE_OR_TRANSLATE', 'translate')
-force_detected_language_to = os.getenv('FORCE_DETECTED_LANGUAGE_TO', '')
+force_detected_language_to = os.getenv('FORCE_DETECTED_LANGUAGE_TO', None)
+hf_transformers = os.getenv('HF_TRANSFORMERS', False)
+hf_batch_size = os.getenv('HF_BATCH_SIZE', 24)
+clear_vram_on_complete = os.getenv('CLEAR_VRAM_ON_COMPLETE', True)
 compute_type = os.getenv('COMPUTE_TYPE', 'auto')
 if transcribe_device == "gpu":
     transcribe_device = "cuda"
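Note: unlike the flags earlier in this block, the three new getenv calls are not wrapped in convert_to_bool or int, so an environment value such as HF_TRANSFORMERS=False arrives as the non-empty (truthy) string "False", and HF_BATCH_SIZE stays a str. A safer sketch, reusing the helpers this file already uses:

    hf_transformers = convert_to_bool(os.getenv('HF_TRANSFORMERS', False))
    hf_batch_size = int(os.getenv('HF_BATCH_SIZE', 24))
    clear_vram_on_complete = convert_to_bool(os.getenv('CLEAR_VRAM_ON_COMPLETE', True))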
@@ -58,8 +61,7 @@ app = FastAPI()
 model = None
 files_to_transcribe = []
 subextension = f".subgen.{whisper_model.split('.')[0]}.{namesublang}.srt"
-print(f"Transcriptions are limited to running {str(concurrent_transcriptions)} at a time")
-print(f"Running {str(whisper_threads)} threads per transcription")
+subextensionSDH = f".subgen.{whisper_model.split('.')[0]}.{namesublang}.sdh.srt"
 
 if debug:
     logging.basicConfig(stream=sys.stderr, level=logging.NOTSET)
@@ -72,6 +74,7 @@ else:
 
 @app.get("/asr")
 @app.get("/emby")
 @app.get("/detect-language")
+@app.get("/")
 def handle_get_request(request: Request):
     return "You accessed this request incorrectly via a GET request. See https://github.com/McCloudS/subgen for proper configuration"
@@ -172,6 +175,14 @@ def receive_emby_webhook(
         print("This doesn't appear to be a properly configured Emby webhook, please review the instructions again!")
         return ""
 
+@app.post("/batch")
+def batch(
+        directory: Union[str, None] = Query(default=None),
+        forceLanguage: Union[str, None] = Query(default=None)
+):
+    transcribe_existing(directory, forceLanguage)
+
 # idea and some code for asr and detect language from https://github.com/ahmetoner/whisper-asr-webservice
 @app.post("/asr")
 def asr(
@@ -185,55 +196,81 @@ def asr(
 ):
     try:
         print(f"Transcribing file from Bazarr/ASR webhook")
-        #give the 'process' a random name so mutliple Bazaar transcribes can operate at the same time.
-        random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
+        result = None
         start_time = time.time()
         start_model()
+        #give the 'process' a random name so mutliple Bazaar transcribes can operate at the same time.
+        random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
         files_to_transcribe.insert(0, f"Bazarr-detect-langauge-{random_name}")
-        result = model.transcribe_stable(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0, task=task, input_sr=16000, language=language)
+        if(hf_transformers):
+            result = model.transcribe(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0, task=task, input_sr=16000, language=language, batch_size=hf_batch_size)
+        else:
+            result = model.transcribe_stable(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0, task=task, input_sr=16000, language=language)
         elapsed_time = time.time() - start_time
         minutes, seconds = divmod(int(elapsed_time), 60)
         print(f"Bazarr transcription is completed, it took {minutes} minutes and {seconds} seconds to complete.")
     except Exception as e:
         print(f"Error processing or transcribing Bazarr {audio_file.filename}: {e}")
-    files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}")
-    delete_model()
-    return StreamingResponse(
-        iter(result.to_srt_vtt(filepath = None, word_level=word_level_highlight)),
-        media_type="text/plain",
-        headers={
-            'Source': 'Transcribed using stable-ts, faster-whisper from Subgen!',
-        })
+    finally:
+        if f"Bazarr-detect-langauge-{random_name}" in files_to_transcribe:
+            files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}")
+        delete_model()
+    if result:
+        return StreamingResponse(
+            iter(result.to_srt_vtt(filepath = None, word_level=word_level_highlight)),
+            media_type="text/plain",
+            headers={
+                'Source': 'Transcribed using stable-ts from Subgen!',
+            })
+    else:
+        return
 
 @app.post("/detect-language")
 def detect_language(
         audio_file: UploadFile = File(...),
         #encode: bool = Query(default=True, description="Encode audio first through ffmpeg") # This is always false from Bazarr
 ):
-    start_model()
-    #give the 'process' a random name so mutliple Bazaar transcribes can operate at the same time.
-    random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
-    files_to_transcribe.insert(0, f"Bazarr-detect-langauge-{random_name}")
-    detected_lang_code = model.transcribe_stable(whisper.pad_or_trim(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0), input_sr=16000).language
-    files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}")
-    delete_model()
-    return {"detected_language": get_lang_pair(whisper_languages, detected_lang_code), "language_code": detected_lang_code}
+    try:
+        result = None
+        start_model()
+        #give the 'process' a random name so mutliple Bazaar transcribes can operate at the same time.
+        random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
+        files_to_transcribe.insert(0, f"Bazarr-detect-langauge-{random_name}")
+        detected_lang_code = model.transcribe_stable(whisper.pad_or_trim(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0), input_sr=16000).language
+    except Exception as e:
+        print(f"Error processing or transcribing Bazarr {audio_file.filename}: {e}")
+    finally:
+        if f"Bazarr-detect-langauge-{random_name}" in files_to_transcribe:
+            files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}")
+        delete_model()
+    if result:
+        return {"detected_language": get_lang_pair(whisper_languages, detected_lang_code), "language_code": detected_lang_code}
+    else:
+        return
 
 def start_model():
     global model
     if model is None:
         logging.debug("Model was purged, need to re-create")
-        model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions, compute_type=compute_type)
+        if(hf_transformers):
+            logging.debug("Use Hugging Face Transformers, whisper_threads, concurrent_transcriptions, and model_location variables are ignored!")
+            model = stable_whisper.load_hf_whisper(whisper_model, device=transcribe_device)
+        else:
+            model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions, compute_type=compute_type)
 
 def delete_model():
-    if len(files_to_transcribe) == 0:
-        global model
-        logging.debug("Queue is empty, clearing/releasing VRAM")
-        model = None
-        gc.collect()
+    if clear_vram_on_complete:
+        if len(files_to_transcribe) == 0:
+            global model
+            logging.debug("Queue is empty, clearing/releasing VRAM")
+            model = None
+            gc.collect()
 
 def get_lang_pair(whisper_languages, key):
     """Returns the other side of the pair in the Whisper languages dictionary.
@@ -253,7 +290,7 @@ def get_lang_pair(whisper_languages, key):
     else:
         return whisper_languages[other_side]
 
-def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True) -> None:
+def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True, forceLanguage=force_detected_language_to) -> None:
     """Generates subtitles for a video file.
 
     Args:
@@ -268,12 +305,16 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True)
         return None
 
     if file_path not in files_to_transcribe:
+        message = None
         if has_subtitle_language(file_path, skipifinternalsublang):
-            logging.debug(f"{file_path} already has an internal sub we want, skipping generation")
-            return f"{file_path} already has an internal sub we want, skipping generation"
-        elif os.path.exists(get_file_name_without_extension(file_path) + subextension):
-            logging.debug(f"{file_path} already has a subgen created for this, skipping it")
-            return f"{file_path} already has a subgen created for this, skipping it"
+            message = f"{file_path} already has an internal subtitle we want, skipping generation"
+        elif os.path.exists(file_path.rsplit('.', 1)[0] + subextension):
+            message = f"{file_path} already has a subtitle created for this, skipping it"
+        elif os.path.exists(file_path.rsplit('.', 1)[0] + subextensionSDH):
+            message = f"{file_path} already has a SDH subtitle created for this, skipping it"
+        if message != None:
+            print(message)
+            #return message
 
         if front:
             files_to_transcribe.insert(0, file_path)
@@ -287,11 +328,10 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True)
         start_time = time.time()
         start_model()
 
-        if(force_detected_language_to):
-            logging.debug(f"Forcing detected audio language to {force_detected_language_to}")
-            result = model.transcribe_stable(file_path, language=force_detected_language_to, task=transcribe_or_translate_str)
+        if(hf_transformers):
+            result = model.transcribe(file_path, language=forceLanguage, batch_size=hf_batch_size, task=transcribe_or_translate_str)
         else:
-            result = model.transcribe_stable(file_path, task=transcribe_or_translate_str)
+            result = model.transcribe_stable(file_path, language=forceLanguage, task=transcribe_or_translate_str)
         result.to_srt_vtt(get_file_name_without_extension(file_path) + subextension, word_level=word_level_highlight)
         elapsed_time = time.time() - start_time
         minutes, seconds = divmod(int(elapsed_time), 60)
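Note: the language precedence after this change works out as follows: a per-call forceLanguage (as /batch passes) wins, the FORCE_DETECTED_LANGUAGE_TO value still applies when nothing is passed (it is baked in as the parameter default), and None leaves Whisper to auto-detect. The old debug line announcing a forced language is dropped along the way. For illustration:

    gen_subtitles("/tv/pilot.mkv", "translate")              # falls back to FORCE_DETECTED_LANGUAGE_TO, else auto-detect
    gen_subtitles("/tv/pilot.mkv", "translate", True, "en")  # explicit per-call override, as /batch does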
@@ -416,7 +456,7 @@ def refresh_jellyfin_metadata(itemid: str, server_ip: str, jellyfin_token: str)
     users = json.loads(requests.get(f"{server_ip}/Users", headers=headers).content)
     jellyfin_admin = get_jellyfin_admin(users)
 
-    response = requests.get(f"{server_ip}/Users/{jellyfin_admin}/Items/{item_id}/Refresh", headers=headers)
+    response = requests.get(f"{server_ip}/Users/{jellyfin_admin}/Items/{itemid}/Refresh", headers=headers)
 
     # Sending the PUT request to refresh metadata
     response = requests.post(url, headers=headers)
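Note: this hunk is a straight NameError fix; the function's parameter (see the hunk header) is itemid, so the old f-string's {item_id} would have raised as soon as the refresh ran.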
@@ -480,8 +520,7 @@ def path_mapping(fullpath):
         return fullpath.replace(path_mapping_from, path_mapping_to)
     return fullpath
 
-def transcribe_existing():
-    global transcribe_folders
+def transcribe_existing(transcribe_folders, forceLanguage=None):
     transcribe_folders = transcribe_folders.split("|")
     print("Starting to search folders to see if we need to create subtitles.")
     logging.debug("The folders are:")
@@ -490,12 +529,13 @@
         for root, dirs, files in os.walk(path):
             for file in files:
                 file_path = os.path.join(root, file)
-                gen_subtitles(path_mapping(file_path), transcribe_or_translate, False)
+                gen_subtitles(path_mapping(file_path), transcribe_or_translate, False, forceLanguage)
+        # if the path specified was actually a single file and not a folder, process it
+        if os.path.isfile(path):
+            if is_video_file(path):
+                gen_subtitles(path_mapping(path), transcribe_or_translate, False, forceLanguage)
     print("Finished searching and queueing files for transcription")
-
-if transcribe_folders:
-    transcribe_existing()
 
 whisper_languages = {
     "en": "english",
@@ -599,8 +639,12 @@ whisper_languages = {
     "su": "sundanese",
 }
 
 if __name__ == "__main__":
     import uvicorn
     print("Starting Subgen with listening webhooks!")
+    print(f"Transcriptions are limited to running {str(concurrent_transcriptions)} at a time")
+    print(f"Running {str(whisper_threads)} threads per transcription")
     print(f"Using {transcribe_device} to encode")
+    if transcribe_folders:
+        transcribe_existing(transcribe_folders)
     uvicorn.run("subgen:app", host="0.0.0.0", port=int(webhookport), reload=debug, use_colors=True)
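Note: transcribe_existing now takes its folder list as an argument rather than reading the module-level global, so the startup scan only happens under __main__ (above) and the same routine backs the /batch endpoint. A sketch of the inputs it accepts:

    transcribe_existing("/tv|/movies")           # '|'-separated folders, walked recursively
    transcribe_existing("/tv/pilot.mkv", "en")   # a single video file also works, with an optional forced language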