From 476320878fbd6ae3a745c11105ae9a26b0dac33f Mon Sep 17 00:00:00 2001
From: McCloudS
Date: Thu, 6 Feb 2025 11:11:55 -0700
Subject: [PATCH] partial reversion and attempt to clear cuda vram

---
 subgen.py | 103 ++++++++----------------------------------------------
 1 file changed, 14 insertions(+), 89 deletions(-)

diff --git a/subgen.py b/subgen.py
index d1d976b..b2dc9ec 100644
--- a/subgen.py
+++ b/subgen.py
@@ -1,4 +1,4 @@
-subgen_version = '2025.02.55'
+subgen_version = '2025.02.56'
 
 from language_code import LanguageCode
 from datetime import datetime
@@ -682,96 +682,21 @@ def extract_audio_segment_to_memory(input_file, start_time, duration):
         logging.error(f"Error: {str(e)}")
         return None
 
-# --- Global Model Variables ---
-_model_loading = False
-_model_lock = threading.Lock() # Protects access to `model` and `_model_loading`
-_unload_timer = None
-_unload_timer_lock = threading.Lock()
-_unload_delay_seconds = 60 # Adjust as needed
-
-# --- Model Loading/Unloading Functions ---
-
-async def _load_model():
-    """Asynchronously load the Whisper model."""
-    global model, _model_loading
-    logging.debug("Starting asynchronous model load...")
-    try:
-        with _model_lock:
-            _model_loading = True
-
-        model = load_faster_whisper(
-            whisper_model,
-            download_root=model_location,
-            device=transcribe_device,
-            cpu_threads=whisper_threads,
-            num_workers=concurrent_transcriptions,
-            compute_type=compute_type
-        )
-        logging.debug("Model loaded asynchronously.")
-
-    except Exception as e:
-        logging.error(f"Error loading model asynchronously: {e}")
-        # Handle the error (e.g., set model to None, log it, etc.)
-        with _model_lock:
-            model = None # Ensure model is None on load failure
-    finally:
-        with _model_lock:
-            _model_loading = False
-
-def _unload_model_callback():
-    """Callback function to unload the model after a delay."""
-    with _unload_timer_lock:
-        delete_model() # Call delete_model inside the lock
+def start_model():
+    global model
+    if model is None:
+        logging.debug("Model was purged, need to re-create")
+        model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions, compute_type=compute_type)
 
 def delete_model():
-    """Unload the Whisper model, clear memory, and cancel any pending timers."""
-    global model, _unload_timer
-
-    with _model_lock:
-        if model is not None:
-            logging.debug("Clearing model from memory (delayed).")
-            model = None
-            # Explicitly release CUDA memory if using CUDA *and* the device is CUDA
-            try:
-                if transcribe_device.lower() == 'cuda' and torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-                    logging.debug("CUDA cache cleared.")
-                gc.collect()
-            except Exception as e:
-                logging.error(f"Error clearing CUDA cache: {e}")
-
-    # Cancel the unload timer
-    with _unload_timer_lock:
-        if _unload_timer:
-            try: # Add a try-except block in case the timer has already been cancelled
-                _unload_timer.cancel()
-            except ValueError:
-                pass # Timer was already cancelled
-            _unload_timer = None
-
-def start_model():
-    """Start the model loading process (asynchronously) or reset the unload timer."""
-    global model, _unload_timer
-
-    with _model_lock:
-        if model is None and not _model_loading:
-            logging.debug("Starting model loading")
-            asyncio.create_task(_load_model()) # Start loading in background
-        elif _model_loading:
-            logging.debug("Model is currently loading...")
-        else:
-            logging.debug("Model is already loaded.")
-
-        # Reset the timer if the model is loaded or loading
-        with _unload_timer_lock:
-            if _unload_timer:
-                try:
-                    _unload_timer.cancel()
-                except ValueError:
-                    pass
-            _unload_timer = threading.Timer(_unload_delay_seconds, _unload_model_callback)
-            _unload_timer.daemon = True
-            _unload_timer.start()
+    global model
+    if clear_vram_on_complete and task_queue.is_idle():
+        logging.debug("Queue idle; clearing model from memory.")
+        model = None
+        if transcribe_device.lower() == 'cuda' and torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            logging.debug("CUDA cache cleared.")
+        gc.collect()
 
 def isAudioFileExtension(file_extension):
    return file_extension.casefold() in \
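For reference, a minimal sketch (not part of the patch) of how the simplified model lifecycle could be driven from a transcription task, assuming the `model`, `task_queue`, and `clear_vram_on_complete` globals that subgen.py defines elsewhere; `run_transcription_task` and the `transcribe_stable` call are illustrative placeholders, not necessarily the names the project actually uses:

    def run_transcription_task(audio_path):
        # Hypothetical wrapper; subgen's real task functions differ.
        start_model()  # re-creates the faster-whisper model if it was purged
        try:
            result = model.transcribe_stable(audio_path)  # placeholder transcribe call
        finally:
            # Only releases the model and empties the CUDA cache when
            # clear_vram_on_complete is enabled and the task queue is idle.
            delete_model()
        return result

Because delete_model() checks task_queue.is_idle(), calling it after every task is safe: the model stays resident while work remains queued and is only purged once the queue drains, after which the next call to start_model() reloads it.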