REPA-E / SD VAE on Apple Silicon (Mac)
Run REPA-E and SD-family VAEs on a Mac through PyTorch + MPS first.
That gives you the Apple GPU path with the fewest moving parts.
This repo uses diffusers AutoencoderKL checkpoints directly:
- stabilityai/sd-vae-ft-mse
- REPA-E/e2e-sdvae-hf
- REPA-E/e2e-flux-vae
REPA-E’s Hugging Face releases are already diffusers-compatible, so you do not need a custom wrapper just to encode/decode images.
REPA-E’s own repo and model cards are the reference points for these weights and their intended use.
References
- REPA-E main repo: End2End-Diffusion/REPA-E
- REPA-E T2I training/integration repo: End2End-Diffusion/fuse-dit
- HF model cards: REPA-E/e2e-sdvae-hf, REPA-E/e2e-flux-vae
What works on Mac
PyTorch’s mps device is the supported Apple GPU backend.
On macOS 12.3+ with an MPS-capable Apple device, moving tensors and models to "mps" runs them on the GPU.
For REPA-E specifically, the current Hugging Face releases load directly through diffusers.AutoencoderKL.from_pretrained(...), and the model cards list diffusers>=0.33.0 and torch>=2.3.1 as the required baseline.
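Before installing the full stack, it is worth confirming which backend PyTorch will actually pick up. A minimal sketch (same preference order the inference script below uses: MPS, then CUDA, then CPU):

```python
import torch

# Prefer Apple's MPS backend, then CUDA, then fall back to CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"selected device: {device}")
```

On an Apple Silicon Mac with a recent macOS and PyTorch build, this should print `selected device: mps`.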
Install
1) Create an env
```shell
python3 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip
```
2) Install PyTorch
Use the current PyTorch install instructions for macOS from the official site. Then install the VAE stack:
```shell
pip install "torch>=2.3.1" torchvision
pip install "diffusers>=0.33.0" transformers accelerate safetensors pillow requests huggingface_hub
```
Inference: encode + decode an image
Save this as infer_vae.py.
```python
from pathlib import Path
import argparse

import numpy as np
from PIL import Image
import torch
from diffusers import AutoencoderKL


def pick_device() -> str:
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def load_image(image_path: str, size: int | None = None) -> torch.Tensor:
    image = Image.open(image_path).convert("RGB")
    if size is not None:
        image = image.resize((size, size), Image.LANCZOS)
    arr = np.array(image).astype(np.float32)
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
    tensor = tensor / 127.5 - 1.0  # uint8 [0, 255] -> float [-1, 1]
    return tensor


def save_image(tensor: torch.Tensor, out_path: str) -> None:
    x = tensor.detach().float().cpu().clamp(-1, 1)
    x = ((x + 1.0) * 127.5).round().to(torch.uint8)  # [-1, 1] -> uint8 [0, 255]
    x = x.squeeze(0).permute(1, 2, 0).numpy()
    Image.fromarray(x).save(out_path)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="REPA-E/e2e-sdvae-hf",
        help="VAE model id. Examples: stabilityai/sd-vae-ft-mse, REPA-E/e2e-sdvae-hf, REPA-E/e2e-flux-vae",
    )
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--outdir", type=str, default="outputs", help="Output folder")
    parser.add_argument("--resize", type=int, default=None, help="Optional square resize")
    parser.add_argument("--sample-latents", action="store_true",
                        help="Use latent_dist.sample() instead of mode()")
    args = parser.parse_args()

    device = pick_device()
    dtype = torch.float16 if device in {"mps", "cuda"} else torch.float32

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    image = load_image(args.image, args.resize).to(device=device, dtype=dtype)
    vae = AutoencoderKL.from_pretrained(args.model).to(device=device, dtype=dtype).eval()

    with torch.inference_mode():
        enc = vae.encode(image)
        latents = enc.latent_dist.sample() if args.sample_latents else enc.latent_dist.mode()
        recon = vae.decode(latents).sample

    latents_cpu = latents.detach().float().cpu()
    torch.save(latents_cpu, outdir / "latents.pt")
    save_image(recon, outdir / "reconstruction.png")

    print(f"device={device}")
    print(f"input_shape={tuple(image.shape)}")
    print(f"latents_shape={tuple(latents.shape)}")
    print(f"saved: {outdir / 'latents.pt'}")
    print(f"saved: {outdir / 'reconstruction.png'}")


if __name__ == "__main__":
    main()
```
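The normalization pair in load_image and save_image is a plain affine map between uint8 pixel values and the VAE's expected [-1, 1] input range. In scalar form (helper names here are illustrative, not part of the script):

```python
def to_model_range(u8: int) -> float:
    # uint8 [0, 255] -> float [-1, 1], matching load_image
    return u8 / 127.5 - 1.0


def to_pixel(x: float) -> int:
    # float [-1, 1] -> uint8 [0, 255], matching save_image (clamp, scale, round)
    x = max(-1.0, min(1.0, x))
    return round((x + 1.0) * 127.5)


print(to_model_range(0), to_model_range(255))  # -1.0 1.0
print(to_pixel(-1.0), to_pixel(1.0))           # 0 255
```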
Run it:
```shell
python infer_vae.py \
  --model REPA-E/e2e-sdvae-hf \
  --image assets/example.png \
  --outdir outputs_sdvae \
  --resize 512
```
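The latents_shape printed by the script is predictable: SD-family AutoencoderKL models downsample spatially by 8x, with 4 latent channels for the SD VAEs (the Flux VAE uses 16). A small sanity-check sketch, assuming those factors:

```python
def expected_latent_shape(height: int, width: int,
                          latent_channels: int = 4, factor: int = 8) -> tuple:
    # Batch of 1; spatial 8x downsampling; channel count depends on VAE family.
    return (1, latent_channels, height // factor, width // factor)


print(expected_latent_shape(512, 512))      # (1, 4, 64, 64) for SD VAEs
print(expected_latent_shape(512, 512, 16))  # Flux-VAE channel count
```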