Text-to-Image SDXL UNet Fine-Tuning in Practice, with Code (Local Dataset)
Introduction
In this article we explore how to use a local dataset to train the UNet of a text-to-image SDXL model. The model is available here: deewaiREALCN. There are two important branches:
- main: It has the original model and can be used in Spaces
- training: It has some extra files and a sample dataset for fine-tuning.
1) Goal and scope
Goal: Adapt an existing SDXL model to your data by training the UNet denoiser while keeping the VAE and text encoders frozen.
Why UNet-only finetuning: The UNet is the module that learns to denoise latents conditioned on text. Training just the UNet is a common, efficient approach to specialize a diffusion model without fully retraining the encoders or VAE. It also reduces compute and risk of destabilizing the full pipeline.
2) Repository and data layout assumptions
The script assumes this structure:
training/data/train/images/
training/data/train/captions.jsonl
training/data/val/images/
training/data/val/captions.jsonl
training/output/ for saved weights and checkpoints
A local SDXL diffusers repo at REPO_ROOT (it loads the model from the local folder via StableDiffusionXLPipeline.from_pretrained(MODEL_ROOT, ...))
Each captions.jsonl line is expected to be JSON like:
{"file_name": "0001.png", "text": "a caption ..."}
3) Step-by-step pipeline
Step A: Parse CLI args and set training plan
You configure:
device selection (auto/cuda/cpu)
precision (bf16/fp16/fp32)
training steps
gradient accumulation
LR scheduler type (constant/cosine) plus warmup
eval cadence and checkpoint steps
Why: This makes experiments repeatable and supports stable training recipes (warmup, clipping, periodic eval, etc.).
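A minimal sketch of that CLI surface; the flag names and defaults here are illustrative assumptions, not the script's exact arguments.
import argparse

def parse_args():
    p = argparse.ArgumentParser(description="SDXL UNet fine-tuning")
    p.add_argument("--device", choices=["auto", "cuda", "cpu"], default="auto")
    p.add_argument("--precision", choices=["bf16", "fp16", "fp32"], default="bf16")
    p.add_argument("--max_train_steps", type=int, default=1000)
    p.add_argument("--grad_accum", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-5)
    p.add_argument("--lr_scheduler", choices=["constant", "cosine"], default="constant")
    p.add_argument("--lr_warmup_steps", type=int, default=100)
    p.add_argument("--eval_every", type=int, default=100)
    p.add_argument("--checkpoint_steps", type=int, nargs="*", default=[600, 680, 750])
    return p.parse_args()

args = parse_args()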
Step B: Device and precision setup (AMP choices)
If CPU: force fp32 and disable autocast.
If CUDA + bf16: enable autocast (no GradScaler).
If CUDA + fp16: enable autocast and enable GradScaler.
If CUDA + fp32: disable autocast.
Why: Mixed precision can reduce memory and increase throughput. For fp16, dynamic loss scaling via GradScaler is standard to avoid underflow and training instability.
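A hedged sketch of that decision logic; setup_amp and the returned names are my own, and args comes from the Step A sketch. The actual script may structure this differently.
import torch

def setup_amp(device_arg, precision):
    use_cuda = device_arg != "cpu" and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if not use_cuda:
        return device, torch.float32, False, None                        # CPU: fp32, no autocast
    if precision == "bf16":
        return device, torch.bfloat16, True, None                        # autocast, no GradScaler
    if precision == "fp16":
        return device, torch.float16, True, torch.cuda.amp.GradScaler()  # autocast + GradScaler
    return device, torch.float32, False, None                            # fp32: no autocast

device, amp_dtype, use_autocast, scaler = setup_amp(args.device, args.precision)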
Step C: Load SDXL pipeline from local model folder
pipe = StableDiffusionXLPipeline.from_pretrained(MODEL_ROOT, ...)
pipe.to(device)
Why: Diffusers pipelines provide a consistent way to load and save model components. You are using the pipeline primarily as a container for UNet, VAE, tokenizers, and text encoders.
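The same call, expanded with the imports it needs; the extra keyword arguments are assumptions, and MODEL_ROOT is a placeholder for the local folder.
import torch
from diffusers import StableDiffusionXLPipeline

MODEL_ROOT = "path/to/local/sdxl"   # placeholder for the local diffusers folder
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ROOT,
    torch_dtype=torch.float32,      # keep trainable weights in fp32; autocast handles mixed precision
    use_safetensors=True,
)
pipe.to(device)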
Step D: Freeze VAE and text encoders, train only UNet
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)
pipe.text_encoder_2.requires_grad_(False)
pipe.unet.train()
Why: Focus learning capacity on denoising behavior for your domain while keeping the latent space and text embedding spaces stable.
Step E: Build dataset and image preprocessing
Your CaptionDataset:
loads images and captions
resizes to (image_size, image_size)
normalizes to roughly [-1, 1] by Normalize([0.5],[0.5])
Why: Latent diffusion pipelines expect normalized image tensors; consistent preprocessing is critical for stable VAE encoding and training.
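A minimal sketch of such a dataset, matching the layout from section 2; the class and field names follow the description above but the rest is an assumption.
import json
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CaptionDataset(Dataset):
    def __init__(self, root, image_size=1024):
        root = Path(root)
        self.image_dir = root / "images"
        self.items = [json.loads(line) for line in open(root / "captions.jsonl", encoding="utf-8")]
        self.tf = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),                  # scales to [0, 1]
            transforms.Normalize([0.5], [0.5]),     # shifts to roughly [-1, 1]
        ])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        image = Image.open(self.image_dir / item["file_name"]).convert("RGB")
        return {"pixel_values": self.tf(image), "text": item["text"]}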
Step F: Encode text prompts into SDXL conditioning
Function encode_text():
tokenizes prompts with both SDXL tokenizers (tokenizer, tokenizer_2)
runs both text encoders with output_hidden_states=True
uses penultimate hidden states (common SD practice) and concatenates them
constructs pooled embeddings via text_projection
Why: SDXL conditioning uses both text encoders, plus pooled text embeddings. This matches SDXL’s expected conditioning format.
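A hedged sketch of encode_text() following the standard SDXL recipe; the article's actual implementation may differ in details.
import torch

def encode_text(pipe, prompts, device):
    embeds_list = []
    pooled_prompt_embeds = None
    for tokenizer, text_encoder in (
        (pipe.tokenizer, pipe.text_encoder),
        (pipe.tokenizer_2, pipe.text_encoder_2),
    ):
        tokens = tokenizer(
            prompts, padding="max_length", max_length=tokenizer.model_max_length,
            truncation=True, return_tensors="pt",
        ).input_ids.to(device)
        out = text_encoder(tokens, output_hidden_states=True)
        embeds_list.append(out.hidden_states[-2])      # penultimate hidden states
        # For text_encoder_2 (CLIPTextModelWithProjection), out[0] is the pooled
        # embedding already passed through text_projection; the loop keeps that one.
        pooled_prompt_embeds = out[0]
    prompt_embeds = torch.cat(embeds_list, dim=-1)     # (B, 77, 768 + 1280)
    return prompt_embeds, pooled_prompt_embeds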
Step G: Convert images to latents with the VAE
Function prepare_latents():
encodes images with the frozen VAE
samples from the latent distribution and scales by vae.config.scaling_factor
Why: Stable Diffusion-style models operate in latent space for efficiency.
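A minimal sketch of prepare_latents(), assuming the frozen VAE never needs gradients:
import torch

@torch.no_grad()
def prepare_latents(pipe, pixel_values):
    # pixel_values: (B, 3, H, W) in roughly [-1, 1]
    posterior = pipe.vae.encode(pixel_values).latent_dist
    return posterior.sample() * pipe.vae.config.scaling_factor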
Step H: Add noise with a DDPM scheduler (training-time corruption)
You instantiate DDPMScheduler(...) and do:
sample random timesteps
sample Gaussian noise
create noisy_latents = scheduler.add_noise(latents, noise, timesteps)
Why: Diffusion training learns denoising from partially noised samples. DDPM scheduling and add_noise are the standard mechanics for training-time forward diffusion.
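A sketch of the training-time forward diffusion; reusing the pipeline's scheduler config is one common choice, and the script may instead instantiate DDPMScheduler with explicit arguments.
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

def add_training_noise(latents):
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps,
        (latents.shape[0],), device=latents.device,
    ).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    return noisy_latents, noise, timesteps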
Step I: Provide SDXL “added conditions” (time ids)
You compute: add_time_ids = pipe._get_add_time_ids(...) repeated for the batch and pass it as:
added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
Why: SDXL UNet expects additional conditioning for resolution and related metadata (the “added conditions” pathway).
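The same tensor can also be built by hand instead of calling the pipeline's private helper; a hedged sketch assuming square images and no cropping:
import torch

def build_add_time_ids(batch_size, image_size, device, dtype):
    # SDXL time ids: original_size + crops_coords_top_left + target_size
    add_time_ids = torch.tensor(
        [[image_size, image_size, 0, 0, image_size, image_size]],
        device=device, dtype=dtype,
    )
    return add_time_ids.repeat(batch_size, 1)   # (B, 6)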
Step J: UNet forward pass and loss (epsilon prediction)
You run:
model_pred = unet(noisy_latents, timesteps, ...)
loss: MSE(model_pred, noise) (prediction type epsilon)
Why: With epsilon prediction, the UNet learns to predict the exact noise that was added, and MSE is the standard objective.
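Putting the previous sketches together, the training forward pass looks roughly like this; with epsilon prediction the regression target is the sampled noise.
import torch.nn.functional as F

model_pred = pipe.unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=prompt_embeds,
    added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
).sample
loss = F.mse_loss(model_pred.float(), noise.float())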
Step K: Gradient accumulation, clipping, and optimizer step
divide loss by grad_accum
backward (scaled when fp16)
every grad_accum micro-steps:
unscale (if fp16) then clip gradients
compute grad norm for logging
optimizer.step(), scheduler.step(), zero_grad()
Why:
- accumulation simulates a larger batch size when VRAM is limited
- gradient clipping helps stabilize spikes
- unscale-before-clip is important for fp16 scaling correctness.
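A hedged sketch of that loop; compute_loss, grad_accum, max_grad_norm, optimizer, lr_scheduler, and train_dataloader are illustrative names, and the fp16-only branches are guarded by whether a GradScaler exists.
import torch

max_grad_norm = 1.0
for step, batch in enumerate(train_dataloader):
    with torch.autocast("cuda", dtype=amp_dtype, enabled=use_autocast):
        loss = compute_loss(batch) / grad_accum          # scale loss for accumulation
    if scaler is not None:
        scaler.scale(loss).backward()                    # fp16: scaled backward
    else:
        loss.backward()
    if (step + 1) % grad_accum == 0:
        if scaler is not None:
            scaler.unscale_(optimizer)                   # unscale BEFORE clipping
        grad_norm = torch.nn.utils.clip_grad_norm_(pipe.unet.parameters(), max_grad_norm)
        if scaler is not None:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)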
Step L: LR warmup and (optional) cosine decay
make_lr_scheduler() produces a LambdaLR:
warmup ramps LR from near 0 to base LR over lr_warmup_steps
then constant or cosine schedule
Why: Warmup helps avoid early divergence, cosine decay can improve convergence later.
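A sketch of make_lr_scheduler() under those assumptions (linear warmup, then constant or cosine):
import math
from torch.optim.lr_scheduler import LambdaLR

def make_lr_scheduler(optimizer, kind, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / max(1, warmup_steps)      # ramp from near 0 to the base LR
        if kind == "constant":
            return 1.0
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))   # cosine decay
    return LambdaLR(optimizer, lr_lambda)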
Step M: Periodic evaluation (validation loss)
Every eval_every optimizer steps:
switch UNet to eval
compute the same denoising MSE on val batches
log eval/loss to W&B
Why: Validation loss helps detect overfitting and choose an early stopping point (for example, when eval loss rises while train loss continues to fall).
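A sketch of the evaluation pass, reusing the same denoising loss as training; compute_loss is the illustrative helper from the Step K sketch.
import torch

@torch.no_grad()
def evaluate(val_dataloader):
    pipe.unet.eval()
    losses = []
    for batch in val_dataloader:
        with torch.autocast("cuda", dtype=amp_dtype, enabled=use_autocast):
            losses.append(compute_loss(batch).item())
    pipe.unet.train()
    return sum(losses) / max(1, len(losses))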
Step N: Checkpointing and final saving
Mid-train: save at configured optimizer steps (example: 600, 680, 750)
End: save UNet to training/output/ in:
fp32 safetensors
fp16 variant
Why: Mid checkpoints let you pick the best step retrospectively using eval curves. Diffusers save_pretrained format makes it easy to reload later.
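The final export can be done with diffusers save_pretrained; a sketch of the fp32 save plus fp16 variant:
import torch

pipe.unet.save_pretrained("training/output", safe_serialization=True)   # diffusion_pytorch_model.safetensors
pipe.unet.to(torch.float16).save_pretrained(
    "training/output", variant="fp16", safe_serialization=True,         # fp16 variant file
)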
4) Experiment tracking in W&B (what is logged and why)
You initialize W&B with config (hyperparameters, device, dtype), then log:
train/loss, train/lr, train/grad_norm
optional GPU memory stats
eval/loss
final artifacts: UNet weight files as a W&B Artifact
Why: W&B provides run comparison, metric charts, and structured model/version tracking via Artifacts.
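A hedged sketch of the W&B wiring; the project name is a placeholder, the metric keys mirror the list above, and args, loss, lr_scheduler, grad_norm, eval_loss, and global_step come from the training loop sketches.
import wandb

run = wandb.init(project="sdxl-unet-finetune", config=vars(args))
wandb.log({
    "train/loss": loss.item(),
    "train/lr": lr_scheduler.get_last_lr()[0],
    "train/grad_norm": float(grad_norm),
}, step=global_step)
wandb.log({"eval/loss": eval_loss}, step=global_step)

artifact = wandb.Artifact("sdxl-unet", type="model")   # model artifact with the saved UNet files
artifact.add_dir("training/output")
run.log_artifact(artifact)
run.finish()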
5) Outputs you get
Local outputs:
training/output/
diffusion_pytorch_model.safetensors (fp32)
diffusion_pytorch_model.fp16.safetensors (fp16 variant)
training/output/checkpoints/checkpoint-/... (if enabled)
Remote outputs (W&B):
charts for train and eval metrics
a model artifact containing the saved UNet files
Key reference docs
🤗 Diffusers (SDXL pipeline and loading/saving)
- SDXL pipeline API (components, inputs, conditioning)
- How from_pretrained() and save_pretrained() work for pipelines
- SDXL usage overview
🤗 Diffusers (Schedulers and training-time noise)
- DDPMScheduler API (forward diffusion via add_noise)
- Example training walkthrough showing add_noise
- Diffusion fundamentals (a good explainer for why we add noise and then predict it)
PyTorch AMP (autocast + GradScaler)
- Mixed precision package documentation
- Practical AMP examples
Weights & Biases (logging and Artifacts)
- Artifacts overview
- Artifact API reference
- How to construct and log an artifact