# models.py
import spaces
import torch
from diffusers import AutoPipelineForImage2Image
from PIL import Image
import os
import numpy as np # Required for some internal diffusers operations / data types
# Model ID for Stable Diffusion XL Base
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
# Load the pipeline globally
# Use float16 for reduced memory usage and faster inference on GPU
# AutoPipelineForImage2Image resolves to the SDXL image-to-image pipeline, so the
# `image` and `strength` arguments used below are actually honored.
pipe = AutoPipelineForImage2Image.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    variant="fp16",  # Explicitly request the fp16 weights variant
    use_safetensors=True,
)
pipe.to("cuda")
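# Note: on ZeroGPU Spaces, moving the pipeline to CUDA at module scope follows the
# documented pattern; the GPU itself is attached on demand when @spaces.GPU-decorated
# functions (including the startup compilation task below) actually run.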
# --- ZeroGPU AoT Compilation (MANDATORY for local diffusion models) ---
# This function compiles key components of the diffusion pipeline ahead-of-time (AoT)
# to achieve significant performance improvements (1.3x-1.8x speedup) on Hugging Face Spaces.
# It uses the @spaces.GPU decorator with a long duration to ensure the compilation
# completes during the Space's startup phase.
@spaces.GPU(duration=1500)  # Long duration so AoT compilation can finish during Space startup
def compile_diffusion_pipeline_components():
    print("Starting AoT compilation for Diffusion Pipeline components...")

    # Compile text_encoder (CLIPTextModel)
    print("Compiling pipe.text_encoder...")
    with torch.no_grad():
        # Prepare a dummy input for text_encoder
        text_input_ids = pipe.tokenizer(
            "a test prompt",
            padding="max_length",
            max_length=pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to("cuda")
        # Capture a text_encoder call, export the traced graph, compile it, and swap it back in
        with spaces.aoti_capture(pipe.text_encoder) as call:
            pipe.text_encoder(text_input_ids)
        exported_text_encoder = torch.export.export(
            pipe.text_encoder,
            args=call.args,
            kwargs=call.kwargs,
        )
        compiled_text_encoder = spaces.aoti_compile(exported_text_encoder)
        spaces.aoti_apply(compiled_text_encoder, pipe.text_encoder)
    print("pipe.text_encoder compiled and applied.")
    # Compile text_encoder_2 (CLIPTextModelWithProjection)
    print("Compiling pipe.text_encoder_2...")
    with torch.no_grad():
        # Prepare a dummy input for text_encoder_2
        text_input_ids_2 = pipe.tokenizer_2(
            "a test prompt",
            padding="max_length",
            max_length=pipe.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to("cuda")
        # Capture a text_encoder_2 call, export the traced graph, compile it, and swap it back in
        with spaces.aoti_capture(pipe.text_encoder_2) as call:
            pipe.text_encoder_2(text_input_ids_2)
        exported_text_encoder_2 = torch.export.export(
            pipe.text_encoder_2,
            args=call.args,
            kwargs=call.kwargs,
        )
        compiled_text_encoder_2 = spaces.aoti_compile(exported_text_encoder_2)
        spaces.aoti_apply(compiled_text_encoder_2, pipe.text_encoder_2)
    print("pipe.text_encoder_2 compiled and applied.")
    # Compile the UNet (the most computationally intensive part).
    # `spaces.aoti_capture` needs to trace the UNet's forward pass inside a real pipeline call,
    # so we run a minimal single-step image-to-image generation to capture the UNet's inputs.
    print("Compiling pipe.unet...")
    with torch.no_grad():
        # Dummy 512x512 input image. Note that the exported graph is specialized to the
        # captured input shapes, so the capture resolution should match what inference uses.
        dummy_input_image = Image.new('RGB', (512, 512), color='white')
        dummy_prompt = "a small test image"
        # Capture the UNet's forward pass during a pipeline run.
        # This implicitly provides the complex inputs (latents, timestep, encoder_hidden_states, etc.)
        with spaces.aoti_capture(pipe.unet) as call:
            _ = pipe(
                prompt=dummy_prompt,
                image=dummy_input_image,
                num_inference_steps=1,  # Minimal steps for faster capture
                guidance_scale=7.5,
                strength=1.0,           # strength=1.0 guarantees at least one denoising step runs
                output_type="pil",      # Ensure PIL output for compatibility
            )
        exported_unet = torch.export.export(
            pipe.unet,
            args=call.args,
            kwargs=call.kwargs,
        )
        compiled_unet = spaces.aoti_compile(exported_unet)
        spaces.aoti_apply(compiled_unet, pipe.unet)
    print("pipe.unet compiled and applied.")
    print("AoT compilation complete.")
# Call the compilation function once during the startup of the Space
compile_diffusion_pipeline_components()
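# NOTE: the spaces.aoti_capture / aoti_compile / aoti_apply helpers used above are assumed to be
# available in the installed `spaces` package (they ship with ZeroGPU-enabled versions).
# Compilation adds noticeable time to Space startup; subsequent inference calls then reuse the
# compiled artifacts at no extra cost.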
@spaces.GPU(duration=60)  # Request a ZeroGPU slot for each inference call
def remix_image(
    image: Image.Image,
    prompt: str,
    negative_prompt: str,
    guidance_scale: float,
    denoising_strength: float,
) -> Image.Image:
    """
    Remixes an input image based on a text prompt using a diffusion model.

    Args:
        image (PIL.Image.Image): The input image to remix.
        prompt (str): The text prompt guiding the remixing.
        negative_prompt (str): The negative prompt to steer generation away from.
        guidance_scale (float): Classifier-free guidance scale.
        denoising_strength (float): The img2img strength applied to the image.
            Higher values allow more creative freedom (more changes from the original);
            lower values keep more of the original image's structure.

    Returns:
        PIL.Image.Image: The remixed image.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    print(f"Generating image with prompt: {prompt}")
    print(f"Negative prompt: {negative_prompt}")
    print(f"Guidance scale: {guidance_scale}, Denoising strength: {denoising_strength}")
    generated_images = pipe(
        prompt=prompt,
        image=image,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        strength=denoising_strength,  # diffusers img2img expects `strength`
        num_inference_steps=25,       # Good balance of quality and speed
        output_type="pil",
    ).images
    return generated_images[0]  # Return the first generated image
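
# Minimal local smoke test, not part of the Space itself: it assumes a CUDA machine, a
# hypothetical input file "example.jpg", and illustrative parameter values. On Spaces,
# app.py is expected to wire remix_image into the Gradio UI instead.
if __name__ == "__main__":
    test_image = Image.open("example.jpg")  # hypothetical example path
    remixed = remix_image(
        image=test_image,
        prompt="a watercolor painting of the same scene",
        negative_prompt="blurry, low quality, artifacts",
        guidance_scale=7.5,
        denoising_strength=0.6,
    )
    remixed.save("remixed.png")
    print("Saved remixed.png")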