|
|
import torch
|
|
|
from modules import script_callbacks, shared
|
|
|
|
|
|
# Filename of the most recently loaded checkpoint, or None before the first load.
last_model = None


def on_model_loaded(checkpoint_info):
    """Free VRAM held by the previously loaded checkpoint when the model changes.

    Registered as a ``script_callbacks.on_model_loaded`` callback.  When the
    loaded checkpoint's filename differs from the one recorded on the previous
    call, the resident model object (if any) is moved to the CPU and the CUDA
    allocator cache is emptied.

    Args:
        checkpoint_info: object describing the loaded checkpoint; only its
            ``filename`` attribute is read here.
    """
    # NOTE(review): by the time this callback fires, ``shared.sd_model`` may
    # already point at the *newly* loaded model rather than the previous one,
    # in which case ``.to('cpu')`` would offload the fresh checkpoint instead
    # of the stale one — confirm against the WebUI model-load order.
    global last_model

    current_model = checkpoint_info.filename

    # First load (last_model is None) or a reload of the same checkpoint
    # requires no cleanup — just record the filename below.
    if last_model is not None and last_model != current_model:
        # Single lookup instead of hasattr() + a second attribute access.
        sd_model = getattr(shared, 'sd_model', None)
        if sd_model is not None:
            sd_model.to('cpu')
            _empty_cuda_cache()
            print(f"Offloaded model {last_model} from VRAM")
        else:
            _empty_cuda_cache()
            print("Cleared VRAM (no model object available)")

    last_model = current_model


def _empty_cuda_cache():
    """Release cached CUDA allocator memory; no-op on CPU-only setups."""
    # empty_cache() already no-ops when CUDA is uninitialized, but the explicit
    # guard makes the CPU-only path obvious and future-proof.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
# Register the offload hook with the WebUI callback system so it runs each
# time a checkpoint finishes loading.
script_callbacks.on_model_loaded(on_model_loaded)