Unload the model after 60 seconds of inactivity. Attempt to clear more CUDA VRAM.
This commit is contained in:
100
subgen.py
100
subgen.py
@@ -1,4 +1,4 @@
|
||||
subgen_version = '2025.02.54'
|
||||
subgen_version = '2025.02.55'
|
||||
|
||||
from language_code import LanguageCode
|
||||
from datetime import datetime
|
||||
@@ -682,18 +682,96 @@ def extract_audio_segment_to_memory(input_file, start_time, duration):
|
||||
logging.error(f"Error: {str(e)}")
|
||||
return None
|
||||
|
||||
def start_model():
    """Ensure the global Whisper model exists, re-creating it when purged."""
    global model
    if model is not None:
        return
    logging.debug("Model was purged, need to re-create")
    model = stable_whisper.load_faster_whisper(
        whisper_model,
        download_root=model_location,
        device=transcribe_device,
        cpu_threads=whisper_threads,
        num_workers=concurrent_transcriptions,
        compute_type=compute_type,
    )
|
||||
# --- Global Model Variables ---
# True while _load_model() is in progress; checked so concurrent callers
# do not launch duplicate loads.
_model_loading = False
_model_lock = threading.Lock() # Protects access to `model` and `_model_loading`
# Pending threading.Timer that unloads the model after inactivity;
# None when no unload is currently scheduled.
_unload_timer = None
_unload_timer_lock = threading.Lock()  # Protects access to `_unload_timer`
_unload_delay_seconds = 60 # Adjust as needed
|
||||
|
||||
# --- Model Loading/Unloading Functions ---
|
||||
|
||||
async def _load_model():
    """Asynchronously load the Whisper model into the global `model`.

    Sets `_model_loading` for the duration of the load so start_model()
    does not launch duplicate loads.  The blocking load itself is run in
    a worker thread via asyncio.to_thread so the event loop stays
    responsive.  On failure, `model` is reset to None.
    """
    global model, _model_loading
    logging.debug("Starting asynchronous model load...")
    try:
        with _model_lock:
            _model_loading = True

        # Run the heavy, blocking load off the event loop.  Use the same
        # stable_whisper entry point the rest of the file uses (a bare
        # `load_faster_whisper` would be a NameError here).
        loaded = await asyncio.to_thread(
            stable_whisper.load_faster_whisper,
            whisper_model,
            download_root=model_location,
            device=transcribe_device,
            cpu_threads=whisper_threads,
            num_workers=concurrent_transcriptions,
            compute_type=compute_type,
        )
        # Publish the model under the lock so readers never see a
        # half-initialized reference.
        with _model_lock:
            model = loaded
        logging.debug("Model loaded asynchronously.")

    except Exception as e:
        logging.error(f"Error loading model asynchronously: {e}")
        with _model_lock:
            model = None  # Ensure model is None on load failure
    finally:
        with _model_lock:
            _model_loading = False
|
||||
|
||||
def _unload_model_callback():
    """Timer callback: unload the model once the inactivity delay expires.

    delete_model() acquires _unload_timer_lock itself (to cancel the
    pending timer), and threading.Lock is NOT reentrant -- calling it
    while holding the lock here deadlocked.  The callback therefore
    takes no lock; delete_model() does its own locking.
    """
    delete_model()
|
||||
|
||||
def delete_model():
    """Unload the Whisper model, release memory, and cancel any pending unload timer."""
    global model, _unload_timer

    with _model_lock:
        if model is not None:
            logging.debug("Clearing model from memory (delayed).")
            model = None
            # Explicitly release CUDA memory, but only when transcription
            # actually runs on CUDA and CUDA is available.
            try:
                if transcribe_device.lower() == 'cuda' and torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logging.debug("CUDA cache cleared.")
            except Exception as e:
                logging.error(f"Error clearing CUDA cache: {e}")
            # Collect garbage on every path (not just CUDA) so the dropped
            # model is actually reclaimed on CPU runs too.
            gc.collect()

    # Cancel the unload timer.  threading.Timer.cancel() is idempotent and
    # never raises, so no try/except guard is needed.
    with _unload_timer_lock:
        if _unload_timer:
            _unload_timer.cancel()
            _unload_timer = None
|
||||
|
||||
def start_model():
    """Start loading the model asynchronously (if needed) and (re)arm the unload timer.

    Safe to call repeatedly: a load is only started when no model is
    present and none is already loading.  Every call pushes the
    inactivity unload back by `_unload_delay_seconds`.

    NOTE(review): asyncio.create_task requires a running event loop on the
    calling thread -- confirm all callers invoke this from inside the loop.
    """
    global model, _unload_timer

    with _model_lock:
        if model is None and not _model_loading:
            logging.debug("Starting model loading")
            asyncio.create_task(_load_model())  # Start loading in background
        elif _model_loading:
            logging.debug("Model is currently loading...")
        else:
            logging.debug("Model is already loaded.")

    # (Re)start the inactivity unload timer.  Timer.cancel() is idempotent
    # and never raises, so the former except-ValueError guard was dead code.
    with _unload_timer_lock:
        if _unload_timer:
            _unload_timer.cancel()
        _unload_timer = threading.Timer(_unload_delay_seconds, _unload_model_callback)
        _unload_timer.daemon = True
        _unload_timer.start()
|
||||
|
||||
def isAudioFileExtension(file_extension):
|
||||
return file_extension.casefold() in \
|
||||
|
||||
Reference in New Issue
Block a user