diff --git a/subgen.py b/subgen.py
index 689c923..d1d976b 100644
--- a/subgen.py
+++ b/subgen.py
@@ -1,4 +1,4 @@
-subgen_version = '2025.02.54'
+subgen_version = '2025.02.55'
 
 from language_code import LanguageCode
 from datetime import datetime
@@ -682,18 +682,91 @@ def extract_audio_segment_to_memory(input_file, start_time, duration):
         logging.error(f"Error: {str(e)}")
         return None
 
-def start_model():
-    global model
-    if model is None:
-        logging.debug("Model was purged, need to re-create")
-        model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions, compute_type=compute_type)
+# --- Global model state ---
+_model_loading = False
+_model_lock = threading.Lock()  # Protects access to `model` and `_model_loading`
+_unload_timer = None
+_unload_timer_lock = threading.Lock()
+_unload_delay_seconds = 60  # Idle time before the model is unloaded; adjust as needed
+
+# --- Model loading/unloading ---
+
+async def _load_model():
+    """Asynchronously load the Whisper model."""
+    global model, _model_loading
+    logging.debug("Starting asynchronous model load...")
+    try:
+        with _model_lock:
+            _model_loading = True
+
+        model = stable_whisper.load_faster_whisper(
+            whisper_model,
+            download_root=model_location,
+            device=transcribe_device,
+            cpu_threads=whisper_threads,
+            num_workers=concurrent_transcriptions,
+            compute_type=compute_type
+        )
+        logging.debug("Model loaded asynchronously.")
+
+    except Exception as e:
+        logging.error(f"Error loading model asynchronously: {e}")
+        with _model_lock:
+            model = None  # Ensure model is None on load failure
+    finally:
+        with _model_lock:
+            _model_loading = False
+
+def _unload_model_callback():
+    """Timer callback: unload the model after the idle delay expires."""
+    # delete_model() takes _unload_timer_lock itself; acquiring it here as well
+    # would deadlock, since threading.Lock is not reentrant.
+    delete_model()
 
 def delete_model():
-    global model
-    if clear_vram_on_complete and task_queue.is_idle():
-        logging.debug("Queue idle; clearing model from memory.")
-        model = None
-        gc.collect()
+    """Unload the Whisper model, clear memory, and cancel any pending unload timer."""
+    global model, _unload_timer
+
+    with _model_lock:
+        if model is not None:
+            logging.debug("Clearing model from memory (delayed).")
+            model = None
+            # Explicitly release CUDA memory when transcribing on a CUDA device
+            try:
+                if transcribe_device.lower() == 'cuda' and torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                    logging.debug("CUDA cache cleared.")
+                gc.collect()
+            except Exception as e:
+                logging.error(f"Error clearing CUDA cache: {e}")
+
+    # Cancel any pending unload timer; cancel() is a no-op on an expired timer
+    with _unload_timer_lock:
+        if _unload_timer:
+            _unload_timer.cancel()
+            _unload_timer = None
+
+def start_model():
+    """Start loading the model in the background, or reset the unload timer."""
+    global model, _unload_timer
+
+    with _model_lock:
+        if model is None and not _model_loading:
+            logging.debug("Starting model loading")
+            # Requires a running asyncio event loop in the calling thread
+            asyncio.create_task(_load_model())
+        elif _model_loading:
+            logging.debug("Model is currently loading...")
+        else:
+            logging.debug("Model is already loaded.")
+
+    # (Re)arm the unload timer now that the model has been requested again
+    with _unload_timer_lock:
+        if _unload_timer:
+            _unload_timer.cancel()
+        _unload_timer = threading.Timer(_unload_delay_seconds, _unload_model_callback)
+        _unload_timer.daemon = True
+        _unload_timer.start()
 
 def isAudioFileExtension(file_extension):
     return file_extension.casefold() in \
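
Note for reviewers: the patch implements a debounced unload, where every start_model() call re-arms a timer and the model is only freed once the delay elapses with no further calls. Below is a minimal standalone sketch of that technique; the names (touch, _unload, _resource) and the 2-second delay are illustrative only and are not part of subgen.

# Sketch only: debounced resource unload via threading.Timer.
import threading
import time

_resource = None
_timer = None
_lock = threading.Lock()
DELAY_SECONDS = 2  # subgen uses 60

def _unload():
    """Timer callback: release the resource after the idle delay."""
    global _resource
    with _lock:
        print("unloading")
        _resource = None

def touch():
    """Load the resource if needed and (re)arm the unload timer."""
    global _resource, _timer
    with _lock:
        if _resource is None:
            print("loading")
            _resource = object()
        if _timer:
            _timer.cancel()  # safe even if the timer already fired
        _timer = threading.Timer(DELAY_SECONDS, _unload)
        _timer.daemon = True
        _timer.start()

touch()         # loads the resource and arms the timer
time.sleep(1)
touch()         # re-arms the timer; no reload happens
time.sleep(3)   # idle for more than DELAY_SECONDS, so _unload fires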