Update subgen.py

Added ASR webhook for Bazarr
McCloudS
2023-10-31 01:23:41 -06:00
committed by GitHub
parent 2b4956cfa1
commit 59802f9e01

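For context, here is a minimal sketch of how a client such as Bazarr might call the new /asr endpoint added by this commit. The handler reads the uploaded body as raw 16-bit PCM at 16 kHz (see the np.frombuffer(..., np.int16) / input_sr=16000 call in the diff), so the sketch posts raw PCM; the base URL is a hypothetical placeholder, not something this commit defines.

# Illustrative sketch only (not part of the commit): POST raw 16 kHz, 16-bit mono PCM
# to the new /asr endpoint and receive SRT back, roughly as Bazarr's Whisper provider does.
# The host and port below are hypothetical placeholders.
import requests

with open("sample.pcm", "rb") as f:  # raw s16le PCM, 16 kHz, mono
    response = requests.post(
        "http://127.0.0.1:8090/asr",
        params={"task": "transcribe", "language": "en", "output": "srt"},
        files={"audio_file": f},
    )
print(response.text)  # subtitles streamed back as text/plain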

@@ -8,8 +8,10 @@ import time
import queue
import logging
import gc
import io
from array import array
from typing import Union, Any
from typing import BinaryIO, Union, Any
import random

# List of packages to install
packages_to_install = [
@@ -20,22 +22,18 @@ packages_to_install = [
    'faster-whisper',
    'uvicorn',
    'python-multipart',
    'whisper',
    # Add more packages as needed
]

for package in packages_to_install:
    print(f"Installing {package}...")
    try:
        subprocess.run(['pip3', 'install', package], check=True)
        print(f"{package} has been successfully installed.")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install {package}: {e}")

from fastapi import FastAPI, File, UploadFile, Query, Header, Body, Form, Request
from fastapi.responses import StreamingResponse, RedirectResponse
import numpy as np
import stable_whisper
import requests
import av
import ffmpeg
import whisper

def convert_to_bool(in_bool):
    if isinstance(in_bool, bool):
@@ -169,6 +167,86 @@ def receive_emby_webhook(
    return ""
@app.post("/asr")
async def asr(
    task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
    language: Union[str, None] = Query(default=None),
    initial_prompt: Union[str, None] = Query(default=None),  # not used by Bazarr
    audio_file: UploadFile = File(...),
    encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),  # not used by Bazarr/always False
    output: Union[str, None] = Query(default="srt", enum=["txt", "vtt", "srt", "tsv", "json"]),
    word_timestamps: bool = Query(default=False, description="Word level timestamps")  # not used by Bazarr
):
    try:
        print(f"Transcribing file from Bazarr/ASR webhook")
        start_time = time.time()
        start_model()
        # give the 'process' a random name so multiple Bazarr transcriptions can run at the same time
        random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
        files_to_transcribe.insert(0, f"Bazarr-detect-langauge-{random_name}")
        result = model.transcribe_stable(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0, task=task, input_sr=16000)
        elapsed_time = time.time() - start_time
        minutes, seconds = divmod(int(elapsed_time), 60)
        print(f"Bazarr transcription is completed, it took {minutes} minutes and {seconds} seconds to complete.")
    except Exception as e:
        print(f"Error processing or transcribing Bazarr {audio_file.filename}: {e}")
    files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}")
    delete_model()
    return StreamingResponse(
        iter(result.to_srt_vtt(filepath = None, word_level=word_level_highlight)),
        media_type="text/plain",
        headers={
            'Source': 'Transcribed using stable-ts, faster-whisper from Subgen!',
        })
@app.post("/detect-language")
async def detect_language(
    audio_file: UploadFile = File(...),
    #encode: bool = Query(default=True, description="Encode audio first through ffmpeg")  # This is always false from Bazarr
):
    start_model()
    # give the 'process' a random name so multiple Bazarr transcriptions can run at the same time
    random_name = random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
    files_to_transcribe.insert(0, f"Bazarr-detect-langauge-{random_name}")
    detected_lang_code = model.transcribe_stable(whisper.pad_or_trim(np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0), input_sr=16000).language
    files_to_transcribe.remove(f"Bazarr-detect-langauge-{random_name}")
    delete_model()
    return {"detected_language": get_lang_pair(whisper_languages, detected_lang_code), "language_code": detected_lang_code}
def start_model():
    global model
    if model is None:
        logging.debug("Model was purged, need to re-create")
        model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions)

def delete_model():
    if len(files_to_transcribe) == 0:
        global model
        logging.debug("Queue is empty, clearing/releasing VRAM")
        del model
        gc.collect()
def get_lang_pair(whisper_languages, key):
    """Returns the other side of the pair in the Whisper languages dictionary.

    Args:
        whisper_languages: A dictionary of Whisper languages.
        key: The key to look up in the dictionary.

    Returns:
        The other side of the pair in the Whisper languages dictionary, or None if the
        key is not found in the dictionary.
    """
    other_side = whisper_languages.get(key)
    if other_side is None:
        return key
    else:
        return whisper_languages[other_side]
def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True) -> None:
    """Generates subtitles for a video file.
@@ -177,7 +255,6 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True)
        transcription_or_translation: The type of transcription or translation to perform.
        front: Whether to add the file to the front of the transcription queue.
    """
    global model
    try:
        if not is_video_file(file_path):
@@ -202,9 +279,7 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True)
            print(f"{len(files_to_transcribe)} files in the queue for transcription")
            print(f"Transcribing file: {os.path.basename(file_path)}")
            start_time = time.time()
            if model is None:
                logging.debug("Model was purged, need to re-create")
                model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions)
            start_model()
            result = model.transcribe_stable(file_path, task=transcribe_or_translate_str)
            result.to_srt_vtt(file_path.rsplit('.', 1)[0] + subextension, word_level=word_level_highlight)
@@ -216,15 +291,9 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True)
            print(f"File {os.path.basename(file_path)} is already in the transcription list. Skipping.")
    except Exception as e:
        print(f"Error processing or transcribing {file_path}: {e}")
        print(f"Error processing or transcribing {video_file_path}: {e}")
    finally:
        if len(files_to_transcribe) == 0:
            logging.debug("Queue is empty, clearing/releasing VRAM")
            try:
                del model
            except Exception as e:
                None
            gc.collect()
        delete_model()
def has_subtitle_language(video_file, target_language):
    try:
@@ -346,6 +415,108 @@ if transcribe_folders:
    transcribe_folders = transcribe_folders.split(",")
    transcribe_existing()
whisper_languages = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
print("Starting webhook!")

if __name__ == "__main__":
    import uvicorn
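The commit also adds a /detect-language endpoint, which pads or trims the uploaded clip to Whisper's 30-second window before detecting the language. Below is a minimal sketch of calling it under the same assumptions as the example above (raw 16 kHz, 16-bit PCM upload; hypothetical host and port).

# Illustrative sketch only (not part of the commit): ask the new /detect-language endpoint
# which language a clip is in. The base URL is a hypothetical placeholder.
import requests

with open("clip.pcm", "rb") as f:  # raw s16le PCM, 16 kHz, mono
    response = requests.post("http://127.0.0.1:8090/detect-language", files={"audio_file": f})
print(response.json())  # e.g. {"detected_language": "english", "language_code": "en"}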