Practical SDXL UNet Fine-Tuning with Code (Local Dataset)

Community Article · Published December 5, 2025

Introduction

In this article we explore how to use a local dataset to train the UNet of a text-to-image SDXL model. The model is deewaiREALCN, and it has two important branches:

  • main: contains the original model and can be used in Spaces
  • training: contains some extra files and a sample dataset for fine-tuning.

1) Goal and scope

Goal: Adapt an existing SDXL model to your data by training the UNet denoiser while keeping the VAE and text encoders frozen.

Why UNet-only finetuning: The UNet is the module that learns to denoise latents conditioned on text. Training just the UNet is a common, efficient approach to specialize a diffusion model without fully retraining the encoders or VAE. It also reduces compute and risk of destabilizing the full pipeline.

2) Repository and data layout assumptions

The script assumes this structure:

  • training/data/train/images/

  • training/data/train/captions.jsonl

  • training/data/val/images/

  • training/data/val/captions.jsonl

  • training/output/ for saved weights and checkpoints

  • a local SDXL diffusers repo at REPO_ROOT (the script loads the model from the local folder via StableDiffusionXLPipeline.from_pretrained(MODEL_ROOT, ...))

Each captions.jsonl line is expected to be JSON like:

{"file_name": "0001.png", "text": "a caption ..."}

3) Step-by-step pipeline

Step A: Parse CLI args and set the training plan

You configure:

  • device selection (auto/cuda/cpu)

  • precision (bf16/fp16/fp32)

  • training steps

  • gradient accumulation

  • LR scheduler type (constant/cosine) plus warmup

  • eval cadence and checkpoint steps

Why: This makes experiments repeatable and supports stable training recipes (warmup, clipping, eval, and so on).
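As a rough illustration, the argument parsing might look like the sketch below; the flag names and defaults are illustrative, not necessarily the exact ones in the script.

import argparse

def parse_args():
    p = argparse.ArgumentParser(description="SDXL UNet fine-tuning")
    p.add_argument("--device", choices=["auto", "cuda", "cpu"], default="auto")
    p.add_argument("--precision", choices=["bf16", "fp16", "fp32"], default="bf16")
    p.add_argument("--max_train_steps", type=int, default=1000)
    p.add_argument("--grad_accum", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-5)
    p.add_argument("--lr_scheduler", choices=["constant", "cosine"], default="constant")
    p.add_argument("--lr_warmup_steps", type=int, default=100)
    p.add_argument("--eval_every", type=int, default=100)
    p.add_argument("--checkpoint_steps", type=int, nargs="*", default=[600, 680, 750])
    return p.parse_args()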

Step B: Device and precision setup (AMP choices)

  • If CPU: force fp32 and disable autocast.

  • If CUDA + bf16: enable autocast (no GradScaler).

  • If CUDA + fp16: enable autocast and enable GradScaler.

  • If CUDA + fp32: disable autocast.

Why: Mixed precision can reduce memory and increase throughput. For fp16, dynamic loss scaling via GradScaler is standard to avoid underflow and training instability.
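A minimal sketch of that policy using standard PyTorch AMP; the helper name and return values are assumptions, not the script's exact code.

import torch

def setup_amp(device_arg, precision):
    if device_arg == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = device_arg
    if device == "cpu":
        return device, torch.float32, False, None                        # fp32, no autocast
    if precision == "bf16":
        return device, torch.bfloat16, True, None                        # autocast, no GradScaler
    if precision == "fp16":
        return device, torch.float16, True, torch.cuda.amp.GradScaler()  # autocast + GradScaler
    return device, torch.float32, False, None                            # fp32, no autocast

device, dtype, use_autocast, scaler = setup_amp(args.device, args.precision)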

Step C: Load the SDXL pipeline from a local model folder

pipe = StableDiffusionXLPipeline.from_pretrained(MODEL_ROOT, ...)
pipe.to(device)

Why: Diffusers pipelines provide a consistent way to load and save model components. You are using the pipeline primarily as a container for UNet, VAE, tokenizers, and text encoders.

Step D: Freeze the VAE and text encoders, train only the UNet

pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)
pipe.text_encoder_2.requires_grad_(False)
pipe.unet.train()

Why: Focus learning capacity on denoising behavior for your domain while keeping the latent space and text embedding spaces stable.

Step E: Build the dataset and image preprocessing

Your CaptionDataset:

  • loads images and captions

  • resizes to (image_size, image_size)

  • normalizes to roughly [-1, 1] by Normalize([0.5],[0.5])

Why: Latent diffusion pipelines expect normalized image tensors; consistent preprocessing is critical for stable VAE encoding and training.
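A sketch of such a dataset, assuming the captions.jsonl format shown earlier; class and field names may differ slightly from the original script.

import json
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CaptionDataset(Dataset):
    def __init__(self, root, image_size=1024):
        root = Path(root)
        self.image_dir = root / "images"
        # One JSON object per line: {"file_name": "...", "text": "..."}
        lines = (root / "captions.jsonl").read_text().splitlines()
        self.items = [json.loads(line) for line in lines if line.strip()]
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),               # scales pixels to [0, 1]
            transforms.Normalize([0.5], [0.5]),  # shifts to roughly [-1, 1]
        ])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        image = Image.open(self.image_dir / item["file_name"]).convert("RGB")
        return {"pixel_values": self.transform(image), "text": item["text"]}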

Step F: Encode text prompts into SDXL conditioning

Function encode_text():

  • tokenizes prompts with both SDXL tokenizers (tokenizer, tokenizer_2)

  • runs both text encoders with output_hidden_states=True

  • uses penultimate hidden states (common SD practice) and concatenates them

  • constructs pooled embeddings via text_projection

Why: SDXL conditioning uses both text encoders, plus pooled text embeddings. This matches SDXL’s expected conditioning format.
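A hedged sketch of this encoding path; it mirrors what diffusers does internally for SDXL, and the exact function in the script may differ.

import torch

def encode_text(pipe, prompts, device):
    embeds = []
    pooled = None
    for tokenizer, text_encoder in [(pipe.tokenizer, pipe.text_encoder),
                                    (pipe.tokenizer_2, pipe.text_encoder_2)]:
        tokens = tokenizer(prompts, padding="max_length",
                           max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt").input_ids.to(device)
        out = text_encoder(tokens, output_hidden_states=True)
        embeds.append(out.hidden_states[-2])   # penultimate hidden state
        pooled = out[0]                        # after the loop: pooled output of text_encoder_2 (text_projection)
    prompt_embeds = torch.cat(embeds, dim=-1)  # concatenate along the feature dimension
    return prompt_embeds, pooled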

Step G: Convert images to latents with the VAE

Function prepare_latents():

  • encodes images with the frozen VAE

  • samples from the latent distribution and scales by vae.config.scaling_factor

Why: Stable Diffusion-style models operate in latent space for efficiency.
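A small sketch of that step, assuming the pixel tensors come from the dataset sketched above.

import torch

@torch.no_grad()
def prepare_latents(pipe, pixel_values, dtype):
    # pixel_values: (batch, 3, H, W), normalized to roughly [-1, 1]
    pixel_values = pixel_values.to(device=pipe.vae.device, dtype=pipe.vae.dtype)
    posterior = pipe.vae.encode(pixel_values).latent_dist
    latents = posterior.sample() * pipe.vae.config.scaling_factor
    return latents.to(dtype)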

Step H: Add noise with a DDPM scheduler (training-time corruption)

You instantiate DDPMScheduler(...) and do:

  • sample random timesteps

  • sample Gaussian noise

  • create noisy_latents = scheduler.add_noise(latents, noise, timesteps)

Why: Diffusion training learns denoising from partially noised samples. DDPM scheduling and add_noise are the standard mechanics for training-time forward diffusion.
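For example, assuming the scheduler config ships with the local model folder (MODEL_ROOT) under a scheduler/ subfolder, the training-time corruption could look like this:

import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler.from_pretrained(MODEL_ROOT, subfolder="scheduler")

def add_training_noise(latents):
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                              (latents.shape[0],), device=latents.device).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    return noisy_latents, noise, timesteps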

Step I: Provide SDXL “added conditions” (time ids)

You compute add_time_ids = pipe._get_add_time_ids(...), repeat it for the batch, and pass it as:

added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}

Why: SDXL UNet expects additional conditioning for resolution and related metadata (the “added conditions” pathway).
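A hedged sketch of what those added conditions contain: the time ids pack original size, crop coordinates, and target size, which is what pipe._get_add_time_ids builds internally. The helper below is illustrative.

import torch

def build_added_cond_kwargs(pooled_prompt_embeds, image_size, batch_size, device, dtype):
    # (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
    add_time_ids = torch.tensor([[image_size, image_size, 0, 0, image_size, image_size]],
                                device=device, dtype=dtype).repeat(batch_size, 1)
    return {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}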

Step J: UNet forward pass and loss (epsilon prediction)

You run:

  • model_pred = unet(noisy_latents, timesteps, ...)

  • loss: MSE(model_pred, noise) (prediction type epsilon)

Why: With epsilon prediction, the UNet learns to predict the exact noise that was added, and MSE is the standard objective.
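Put together, the forward pass and objective are roughly as follows (names follow the sketches above):

import torch.nn.functional as F

def compute_loss(pipe, noisy_latents, timesteps, prompt_embeds, added_cond_kwargs, noise):
    model_pred = pipe.unet(noisy_latents, timesteps,
                           encoder_hidden_states=prompt_embeds,
                           added_cond_kwargs=added_cond_kwargs).sample
    # Epsilon prediction: the target is exactly the noise that was added.
    return F.mse_loss(model_pred.float(), noise.float())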

Step K: Gradient accumulation, clipping, and optimizer step

  • divide loss by grad_accum

  • backward (scaled when fp16)

Then, every grad_accum micro-steps:

  • unscale (if fp16) then clip gradients

  • compute grad norm for logging

  • optimizer.step(), scheduler.step(), zero_grad()

Why:

  • Accumulation simulates a larger batch size when VRAM is limited.

  • Gradient clipping helps stabilize loss spikes.

  • Unscaling before clipping is important for fp16 scaling correctness.
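A sketch of one accumulation cycle, assuming loss, optimizer, lr_scheduler, scaler (None unless fp16), and a micro-step counter step already exist:

import torch

loss = loss / grad_accum
if scaler is not None:
    scaler.scale(loss).backward()
else:
    loss.backward()

if (step + 1) % grad_accum == 0:
    if scaler is not None:
        scaler.unscale_(optimizer)  # unscale before clipping for correct fp16 behavior
    grad_norm = torch.nn.utils.clip_grad_norm_(pipe.unet.parameters(), max_norm=1.0)
    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad(set_to_none=True)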

Step L: LR warmup and (optional) cosine decay

make_lr_scheduler() produces a LambdaLR:

  • warmup ramps LR from near 0 to base LR over lr_warmup_steps

  • then constant or cosine schedule

Why: Warmup helps avoid early divergence; cosine decay can improve convergence later.
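A sketch of such a scheduler built on LambdaLR; the function name matches the article, but the internals are an assumption.

import math
from torch.optim.lr_scheduler import LambdaLR

def make_lr_scheduler(optimizer, kind, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / max(1, warmup_steps)       # linear warmup from ~0 to base LR
        if kind == "constant":
            return 1.0
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay toward 0
    return LambdaLR(optimizer, lr_lambda)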

Step M: Periodic evaluation (validation loss)

Every eval_every optimizer steps:

  • switch UNet to eval

  • compute the same denoising MSE on val batches

  • log eval/loss to W&B

Why: Validation loss helps detect overfitting and choose an early stopping point (for example, when eval loss rises while train loss continues to fall).
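A sketch of that evaluation loop, assuming a loss_fn helper that applies the same training objective to a validation batch:

import torch
import wandb

@torch.no_grad()
def evaluate(pipe, val_loader, loss_fn, global_step):
    pipe.unet.eval()
    losses = [loss_fn(batch).item() for batch in val_loader]
    pipe.unet.train()
    eval_loss = sum(losses) / max(1, len(losses))
    wandb.log({"eval/loss": eval_loss}, step=global_step)
    return eval_loss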

Step N: Checkpointing and final saving

  • Mid-train: save at configured optimizer steps (example: 600, 680, 750)

  • End: save the UNet to training/output/ as:

  • fp32 safetensors

  • fp16 variant

Why: Mid checkpoints let you pick the best step retrospectively using eval curves. Diffusers save_pretrained format makes it easy to reload later.
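The final save can be as simple as the sketch below, which relies on diffusers' save_pretrained and its variant argument for the fp16 copy.

import torch

def save_unet(pipe, output_dir):
    unet = pipe.unet
    unet.to(torch.float32).save_pretrained(output_dir)                  # diffusion_pytorch_model.safetensors (fp32)
    unet.to(torch.float16).save_pretrained(output_dir, variant="fp16")  # diffusion_pytorch_model.fp16.safetensors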


4) Experiment tracking in W&B (what is logged and why)

You initialize W&B with config (hyperparameters, device, dtype), then log:

  • train/loss, train/lr, train/grad_norm

  • optional GPU memory stats

  • eval/loss

  • final artifacts: UNet weight files as a W&B Artifact

Why: W&B provides run comparison, metric charts, and structured model/version tracking via Artifacts.
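A sketch of the logging calls, reusing the args, loss, lr_scheduler, grad_norm, and global_step names from the earlier sketches; the project and artifact names are illustrative.

import wandb

run = wandb.init(project="sdxl-unet-finetune",
                 config={"precision": args.precision, "lr": args.lr,
                         "grad_accum": args.grad_accum,
                         "max_train_steps": args.max_train_steps})

# inside the training loop
wandb.log({"train/loss": loss.item(),
           "train/lr": lr_scheduler.get_last_lr()[0],
           "train/grad_norm": float(grad_norm)}, step=global_step)

# after the final save: upload the saved UNet as a model artifact
artifact = wandb.Artifact("sdxl-unet", type="model")
artifact.add_dir("training/output")
run.log_artifact(artifact)
run.finish()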

5) Outputs you get

Local outputs:

  • training/output/

  • diffusion_pytorch_model.safetensors (fp32)

  • diffusion_pytorch_model.fp16.safetensors (fp16 variant)

  • training/output/checkpoints/checkpoint-/... (if enabled)

Remote outputs (W&B):

  • charts for train and eval metrics

  • a model artifact containing the saved UNet files

6) Key reference docs

🤗 Diffusers (SDXL pipeline and loading/saving)

🤗 Diffusers (Schedulers and training-time noise)

PyTorch AMP (autocast + GradScaler)

Weights & Biases (logging and Artifacts)

