Update subgen.py
Added some suggestions from JaiZed (https://github.com/JaiZed/subgen): added a check for an existing SDH subtitle, added a /batch endpoint, and a more elegant forceLanguage implementation.
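A quick sketch of how the new /batch endpoint could be driven, assuming a local Subgen install on the default WEBHOOKPORT of 8090 (the host, directory path, and language code below are placeholders, not part of this commit):

import requests

# Queue everything under /media/TV for subtitle generation, forcing the
# detected audio language to English; both query parameters are optional.
resp = requests.post("http://localhost:8090/batch",
                     params={"directory": "/media/TV", "forceLanguage": "en"})
print(resp.status_code)

The endpoint hands both values straight to transcribe_existing(), which now also accepts a single video file path, not just a folder.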
subgen/subgen.py (156 changed lines)
@@ -28,18 +28,18 @@ def convert_to_bool(in_bool):
     return value not in ('false', 'off', '0')
 
 # Replace your getenv calls with appropriate default values here
-plextoken = os.getenv('PLEXTOKEN', "token here")
-plexserver = os.getenv('PLEXSERVER', "http://192.168.1.111:32400")
-jellyfintoken = os.getenv('JELLYFINTOKEN', "token here")
-jellyfinserver = os.getenv('JELLYFINSERVER', "http://192.168.1.111:8096")
-whisper_model = os.getenv('WHISPER_MODEL', "medium")
+plextoken = os.getenv('PLEXTOKEN', 'token here')
+plexserver = os.getenv('PLEXSERVER', 'http://192.168.1.111:32400')
+jellyfintoken = os.getenv('JELLYFINTOKEN', 'token here')
+jellyfinserver = os.getenv('JELLYFINSERVER', 'http://192.168.1.111:8096')
+whisper_model = os.getenv('WHISPER_MODEL', 'medium')
 whisper_threads = int(os.getenv('WHISPER_THREADS', 4))
-concurrent_transcriptions = int(os.getenv('CONCURRENT_TRANSCRIPTIONS', '2'))
-transcribe_device = os.getenv('TRANSCRIBE_DEVICE', "cpu")
+concurrent_transcriptions = int(os.getenv('CONCURRENT_TRANSCRIPTIONS', 2))
+transcribe_device = os.getenv('TRANSCRIBE_DEVICE', 'cpu')
 procaddedmedia = convert_to_bool(os.getenv('PROCADDEDMEDIA', True))
 procmediaonplay = convert_to_bool(os.getenv('PROCMEDIAONPLAY', True))
-namesublang = os.getenv('NAMESUBLANG', "aa")
-skipifinternalsublang = os.getenv('SKIPIFINTERNALSUBLANG', "eng")
+namesublang = os.getenv('NAMESUBLANG', 'aa')
+skipifinternalsublang = os.getenv('SKIPIFINTERNALSUBLANG', 'eng')
 webhookport = int(os.getenv('WEBHOOKPORT', 8090))
 word_level_highlight = convert_to_bool(os.getenv('WORD_LEVEL_HIGHLIGHT', False))
 debug = convert_to_bool(os.getenv('DEBUG', False))
@@ -49,7 +49,10 @@ path_mapping_to = os.getenv('PATH_MAPPING_TO', '/Volumes/TV')
 model_location = os.getenv('MODEL_PATH', '.')
 transcribe_folders = os.getenv('TRANSCRIBE_FOLDERS', '')
 transcribe_or_translate = os.getenv('TRANSCRIBE_OR_TRANSLATE', 'translate')
-force_detected_language_to = os.getenv('FORCE_DETECTED_LANGUAGE_TO', '')
+force_detected_language_to = os.getenv('FORCE_DETECTED_LANGUAGE_TO', None)
+hf_transformers = os.getenv('HF_TRANSFORMERS', False)
+hf_batch_size = os.getenv('HF_BATCH_SIZE', 24)
+clear_vram_on_complete = os.getenv('CLEAR_VRAM_ON_COMPLETE', True)
 compute_type = os.getenv('COMPUTE_TYPE', 'auto')
 if transcribe_device == "gpu":
     transcribe_device = "cuda"
@@ -58,8 +61,7 @@ app = FastAPI()
 model = None
 files_to_transcribe = []
 subextension = f".subgen.{whisper_model.split('.')[0]}.{namesublang}.srt"
-print(f"Transcriptions are limited to running {str(concurrent_transcriptions)} at a time")
-print(f"Running {str(whisper_threads)} threads per transcription")
+subextensionSDH = f".subgen.{whisper_model.split('.')[0]}.{namesublang}.sdh.srt"
 
 if debug:
     logging.basicConfig(stream=sys.stderr, level=logging.NOTSET)
@@ -72,6 +74,7 @@ else:
 @app.get("/asr")
 @app.get("/emby")
 @app.get("/detect-language")
+@app.get("/")
 def handle_get_request(request: Request):
     return "You accessed this request incorrectly via a GET request. See https://github.com/McCloudS/subgen for proper configuration"
 
@@ -172,6 +175,14 @@ def receive_emby_webhook(
         print("This doesn't appear to be a properly configured Emby webhook, please review the instructions again!")
 
     return ""
 
+@app.post("/batch")
+def batch(
+        directory: Union[str, None] = Query(default=None),
+        forceLanguage: Union[str, None] = Query(default=None)
+):
+    transcribe_existing(directory, forceLanguage)
+
 # idea and some code for asr and detect language from https://github.com/ahmetoner/whisper-asr-webservice
 @app.post("/asr")
 def asr(
@@ -185,55 +196,81 @@ def asr(
 ):
     try:
         print(f"Transcribing file from Bazarr/ASR webhook")
+        result = None
+        #give the 'process' a random name so mutliple Bazaar transcribes can operate at the same time.
+        random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
 
         start_time = time.time()
         start_model()
 
-        #give the 'process' a random name so mutliple Bazaar transcribes can operate at the same time.
-        random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
         files_to_transcribe.insert(0, f"Bazarr-detect-langauge-{random_name}")
-        result = model.transcribe_stable(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0, task=task, input_sr=16000, language=language)
+        if(hf_transformers):
+            result = model.transcribe(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0, task=task, input_sr=16000, language=language, batch_size=hf_batch_size)
+        else:
+            result = model.transcribe_stable(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0, task=task, input_sr=16000, language=language)
         elapsed_time = time.time() - start_time
         minutes, seconds = divmod(int(elapsed_time), 60)
         print(f"Bazarr transcription is completed, it took {minutes} minutes and {seconds} seconds to complete.")
     except Exception as e:
         print(f"Error processing or transcribing Bazarr {audio_file.filename}: {e}")
-    files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}")
-    delete_model()
-    return StreamingResponse(
-        iter(result.to_srt_vtt(filepath = None, word_level=word_level_highlight)),
-        media_type="text/plain",
-        headers={
-            'Source': 'Transcribed using stable-ts, faster-whisper from Subgen!',
-        })
+    finally:
+        if f"Bazarr-detect-langauge-{random_name}" in files_to_transcribe:
+            files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}")
+        delete_model()
+    if result:
+        return StreamingResponse(
+            iter(result.to_srt_vtt(filepath = None, word_level=word_level_highlight)),
+            media_type="text/plain",
+            headers={
+                'Source': 'Transcribed using stable-ts from Subgen!',
+            })
+    else:
+        return
 
 @app.post("/detect-language")
 def detect_language(
     audio_file: UploadFile = File(...),
     #encode: bool = Query(default=True, description="Encode audio first through ffmpeg") # This is always false from Bazarr
 ):
-    start_model()
+    try:
+        #give the 'process' a random name so mutliple Bazaar transcribes can operate at the same time.
+        result = None
+        random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
+        start_model()
 
-    #give the 'process' a random name so mutliple Bazaar transcribes can operate at the same time.
-    random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
-    files_to_transcribe.insert(0, f"Bazarr-detect-langauge-{random_name}")
-    detected_lang_code = model.transcribe_stable(whisper.pad_or_trim(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0), input_sr=16000).language
-    files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}")
-    delete_model()
-    return {"detected_language": get_lang_pair(whisper_languages, detected_lang_code), "language_code": detected_lang_code}
+        files_to_transcribe.insert(0, f"Bazarr-detect-langauge-{random_name}")
+        detected_lang_code = model.transcribe_stable(whisper.pad_or_trim(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0), input_sr=16000).language
+    except Exception as e:
+        print(f"Error processing or transcribing Bazarr {audio_file.filename}: {e}")
+    finally:
+        if f"Bazarr-detect-langauge-{random_name}" in files_to_transcribe:
+            files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}")
+        delete_model()
+
+    if result:
+        return {"detected_language": get_lang_pair(whisper_languages, detected_lang_code), "language_code": detected_lang_code}
+    else:
+        return
 
 def start_model():
     global model
     if model is None:
         logging.debug("Model was purged, need to re-create")
-        model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions, compute_type=compute_type)
+        if(hf_transformers):
+            logging.debug("Use Hugging Face Transformers, whisper_threads, concurrent_transcriptions, and model_location variables are ignored!")
+            model = stable_whisper.load_hf_whisper(whisper_model, device=transcribe_device)
+        else:
+            model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions, compute_type=compute_type)
 
 def delete_model():
-    if len(files_to_transcribe) == 0:
-        global model
-        logging.debug("Queue is empty, clearing/releasing VRAM")
-        model = None
-        gc.collect()
+    if clear_vram_on_complete:
+        if len(files_to_transcribe) == 0:
+            global model
+            logging.debug("Queue is empty, clearing/releasing VRAM")
+            model = None
+            gc.collect()
 
 def get_lang_pair(whisper_languages, key):
     """Returns the other side of the pair in the Whisper languages dictionary.
@@ -253,7 +290,7 @@ def get_lang_pair(whisper_languages, key):
     else:
         return whisper_languages[other_side]
 
-def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True) -> None:
+def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True, forceLanguage=force_detected_language_to) -> None:
     """Generates subtitles for a video file.
 
     Args:
@@ -268,12 +305,16 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True)
         return None
 
     if file_path not in files_to_transcribe:
+        message = None
        if has_subtitle_language(file_path, skipifinternalsublang):
-            logging.debug(f"{file_path} already has an internal sub we want, skipping generation")
-            return f"{file_path} already has an internal sub we want, skipping generation"
-        elif os.path.exists(get_file_name_without_extension(file_path) + subextension):
-            logging.debug(f"{file_path} already has a subgen created for this, skipping it")
-            return f"{file_path} already has a subgen created for this, skipping it"
+            message = f"{file_path} already has an internal subtitle we want, skipping generation"
+        elif os.path.exists(file_path.rsplit('.', 1)[0] + subextension):
+            message = f"{file_path} already has a subtitle created for this, skipping it"
+        elif os.path.exists(file_path.rsplit('.', 1)[0] + subextensionSDH):
+            message = f"{file_path} already has a SDH subtitle created for this, skipping it"
+        if message != None:
+            print(message)
+            #return message
 
         if front:
             files_to_transcribe.insert(0, file_path)
@@ -287,11 +328,10 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True)
         start_time = time.time()
         start_model()
 
-        if(force_detected_language_to):
-            logging.debug(f"Forcing detected audio language to {force_detected_language_to}")
-            result = model.transcribe_stable(file_path, language=force_detected_language_to, task=transcribe_or_translate_str)
+        if(hf_transformers):
+            result = model.transcribe(file_path, language=forceLanguage, batch_size=hf_batch_size, task=transcribe_or_translate_str)
         else:
-            result = model.transcribe_stable(file_path, task=transcribe_or_translate_str)
+            result = model.transcribe_stable(file_path, language=forceLanguage, task=transcribe_or_translate_str)
         result.to_srt_vtt(get_file_name_without_extension(file_path) + subextension, word_level=word_level_highlight)
         elapsed_time = time.time() - start_time
         minutes, seconds = divmod(int(elapsed_time), 60)
@@ -416,7 +456,7 @@ def refresh_jellyfin_metadata(itemid: str, server_ip: str, jellyfin_token: str)
     users = json.loads(requests.get(f"{server_ip}/Users", headers=headers).content)
     jellyfin_admin = get_jellyfin_admin(users)
 
-    response = requests.get(f"{server_ip}/Users/{jellyfin_admin}/Items/{item_id}/Refresh", headers=headers)
+    response = requests.get(f"{server_ip}/Users/{jellyfin_admin}/Items/{itemid}/Refresh", headers=headers)
 
     # Sending the PUT request to refresh metadata
     response = requests.post(url, headers=headers)
@@ -480,8 +520,7 @@ def path_mapping(fullpath):
         return fullpath.replace(path_mapping_from, path_mapping_to)
     return fullpath
 
-def transcribe_existing():
-    global transcribe_folders
+def transcribe_existing(transcribe_folders, forceLanguage=None):
     transcribe_folders = transcribe_folders.split("|")
     print("Starting to search folders to see if we need to create subtitles.")
     logging.debug("The folders are:")
@@ -490,12 +529,13 @@ def transcribe_existing():
         for root, dirs, files in os.walk(path):
             for file in files:
                 file_path = os.path.join(root, file)
-                gen_subtitles(path_mapping(file_path), transcribe_or_translate, False)
+                gen_subtitles(path_mapping(file_path), transcribe_or_translate, False, forceLanguage)
+        # if the path specified was actually a single file and not a folder, process it
+        if os.path.isfile(path):
+            if is_video_file(path):
+                gen_subtitles(path_mapping(path), transcribe_or_translate, False, forceLanguage)
 
     print("Finished searching and queueing files for transcription")
 
-if transcribe_folders:
-    transcribe_existing()
-
 whisper_languages = {
     "en": "english",
@@ -599,8 +639,12 @@ whisper_languages = {
     "su": "sundanese",
 }
 
 
 if __name__ == "__main__":
     import uvicorn
     print("Starting Subgen with listening webhooks!")
+    print(f"Transcriptions are limited to running {str(concurrent_transcriptions)} at a time")
+    print(f"Running {str(whisper_threads)} threads per transcription")
+    print(f"Using {transcribe_device} to encode")
+    if transcribe_folders:
+        transcribe_existing(transcribe_folders)
     uvicorn.run("subgen:app", host="0.0.0.0", port=int(webhookport), reload=debug, use_colors=True)