diff --git a/subgen/subgen.py b/subgen/subgen.py index fb4fbfd..39f60da 100644 --- a/subgen/subgen.py +++ b/subgen/subgen.py @@ -8,8 +8,10 @@ import time import queue import logging import gc +import io from array import array -from typing import Union, Any +from typing import BinaryIO, Union, Any +import random # List of packages to install packages_to_install = [ @@ -20,22 +22,18 @@ packages_to_install = [ 'faster-whisper', 'uvicorn', 'python-multipart', + 'whisper', # Add more packages as needed ] -for package in packages_to_install: - print(f"Installing {package}...") - try: - subprocess.run(['pip3', 'install', package], check=True) - print(f"{package} has been successfully installed.") - except subprocess.CalledProcessError as e: - print(f"Failed to install {package}: {e}") - from fastapi import FastAPI, File, UploadFile, Query, Header, Body, Form, Request -from fastapi.responses import StreamingResponse, RedirectResponse +from fastapi.responses import StreamingResponse, RedirectResponse +import numpy as np import stable_whisper import requests import av +import ffmpeg +import whisper def convert_to_bool(in_bool): if isinstance(in_bool, bool): @@ -169,6 +167,86 @@ def receive_emby_webhook( return "" +@app.post("/asr") +async def asr( + task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]), + language: Union[str, None] = Query(default=None), + initial_prompt: Union[str, None] = Query(default=None), #not used by Bazarr + audio_file: UploadFile = File(...), + encode: bool = Query(default=True, description="Encode audio first through ffmpeg"), #not used by Bazarr/always False + output: Union[str, None] = Query(default="srt", enum=["txt", "vtt", "srt", "tsv", "json"]), + word_timestamps: bool = Query(default=False, description="Word level timestamps") #not used by Bazarr +): + try: + print(f"Transcribing file from Bazarr/ASR webhook") + start_time = time.time() + start_model() + + #give the 'process' a random name so multiple Bazarr 
transcribes can operate at the same time. + random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6) + files_to_transcribe.insert(0, f"Bazarr-detect-langauge-{random_name}") + result = model.transcribe_stable(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0, task=task, input_sr=16000) + elapsed_time = time.time() - start_time + minutes, seconds = divmod(int(elapsed_time), 60) + print(f"Bazarr transcription is completed, it took {minutes} minutes and {seconds} seconds to complete.") + except Exception as e: + print(f"Error processing or transcribing Bazarr {audio_file.filename}: {e}") + files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}") + delete_model() + return StreamingResponse( + iter(result.to_srt_vtt(filepath = None, word_level=word_level_highlight)), + media_type="text/plain", + headers={ + 'Source': 'Transcribed using stable-ts, faster-whisper from Subgen!', + }) + +@app.post("/detect-language") +async def detect_language( + audio_file: UploadFile = File(...), + #encode: bool = Query(default=True, description="Encode audio first through ffmpeg") # This is always false from Bazarr +): + start_model() + + #give the 'process' a random name so multiple Bazarr transcribes can operate at the same time. 
+ random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6) + files_to_transcribe.insert(0, f"Bazarr-detect-langauge-{random_name}") + detected_lang_code = model.transcribe_stable(whisper.pad_or_trim(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0), input_sr=16000).language + + files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}") + delete_model() + return {"detected_language": get_lang_pair(whisper_languages, detected_lang_code), "language_code": detected_lang_code} + +def start_model(): + global model + if model is None: + logging.debug("Model was purged, need to re-create") + model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions) + +def delete_model(): + if len(files_to_transcribe) == 0: + global model + logging.debug("Queue is empty, clearing/releasing VRAM") + del model + gc.collect() + +def get_lang_pair(whisper_languages, key): + """Returns the other side of the pair in the Whisper languages dictionary. + + Args: + whisper_languages: A dictionary of Whisper languages. + key: The key to look up in the dictionary. + + Returns: + The other side of the pair in the Whisper languages dictionary, or the key itself if the + key is not found in the dictionary. + """ + + other_side = whisper_languages.get(key) + if other_side is None: + return key + else: + return whisper_languages[other_side] + def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True) -> None: """Generates subtitles for a video file. @@ -177,7 +255,6 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True) transcription_or_translation: The type of transcription or translation to perform. front: Whether to add the file to the front of the transcription queue. 
""" - global model try: if not is_video_file(file_path): @@ -202,9 +279,7 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True) print(f"{len(files_to_transcribe)} files in the queue for transcription") print(f"Transcribing file: {os.path.basename(file_path)}") start_time = time.time() - if model is None: - logging.debug("Model was purged, need to re-create") - model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions) + start_model() result = model.transcribe_stable(file_path, task=transcribe_or_translate_str) result.to_srt_vtt(file_path.rsplit('.', 1)[0] + subextension, word_level=word_level_highlight) @@ -216,15 +291,9 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True) print(f"File {os.path.basename(file_path)} is already in the transcription list. Skipping.") except Exception as e: - print(f"Error processing or transcribing {file_path}: {e}") + print(f"Error processing or transcribing {video_file_path}: {e}") finally: - if len(files_to_transcribe) == 0: - logging.debug("Queue is empty, clearing/releasing VRAM") - try: - del model - except Exception as e: - None - gc.collect() + delete_model() def has_subtitle_language(video_file, target_language): try: @@ -346,6 +415,108 @@ if transcribe_folders: transcribe_folders = transcribe_folders.split(",") transcribe_existing() +whisper_languages = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": 
"danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", +} + print("Starting webhook!") if __name__ == "__main__": import uvicorn