from typing import Optional

import torch
import torch.nn as nn

from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
    T5Tokenizer,
)

# T5 encoder-only checkpoint that stands in for SDXL's two CLIP text encoders.
T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"

class LinearWithDtype(nn.Linear):
    """nn.Linear with a .dtype property, since DiffusionPipeline reads
    .dtype from registered nn.Module components and plain nn.Linear
    does not expose one."""

    @property
    def dtype(self):
        return self.weight.dtype

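# Quick illustration: unlike a plain nn.Linear, the subclass answers dtype
# queries the way diffusers' ModelMixin-based components do.
#
#   >>> LinearWithDtype(4096, 2048).dtype
#   torch.float32
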
class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
    _expected_modules = [
        "vae", "unet", "scheduler", "tokenizer",
        "image_encoder", "feature_extractor",
        "t5_encoder", "t5_projection", "t5_pooled_projection",
    ]

    _optional_components = [
        "image_encoder", "feature_extractor",
        "t5_encoder", "t5_projection", "t5_pooled_projection",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        tokenizer: T5Tokenizer,
        t5_encoder=None,
        t5_projection=None,
        t5_pooled_projection=None,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        # Bypass StableDiffusionXLPipeline.__init__, which requires the two
        # CLIP text encoders; initialize the bare DiffusionPipeline instead.
        DiffusionPipeline.__init__(self)

        if t5_encoder is None:
            self.t5_encoder = T5EncoderModel.from_pretrained(
                T5_NAME, torch_dtype=unet.dtype
            )
        else:
            self.t5_encoder = t5_encoder

        # Project T5's 4096-dim hidden states down to the 2048-dim token
        # embeddings the SDXL UNet expects for cross-attention.
        if t5_projection is None:
            self.t5_projection = LinearWithDtype(4096, 2048)
        else:
            self.t5_projection = t5_projection
        self.t5_projection.to(dtype=unet.dtype)

        # Project the mean-pooled T5 state to the 1280-dim pooled embedding
        # used by SDXL's added-conditioning path.
        if t5_pooled_projection is None:
            self.t5_pooled_projection = LinearWithDtype(4096, 1280)
        else:
            self.t5_pooled_projection = t5_pooled_projection
        self.t5_pooled_projection.to(dtype=unet.dtype)

        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            t5_encoder=self.t5_encoder,
            t5_projection=self.t5_projection,
            t5_pooled_projection=self.t5_pooled_projection,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        self.default_sample_size = (
            self.unet.config.sample_size
            if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
            else 128
        )

        self.watermark = None

        # The parent class's CLIP text encoders are unused; null them so
        # inherited code paths see a pipeline without CLIP conditioning.
        self.text_encoder = self.text_encoder_2 = None

    def encode_prompt(
        self,
        prompt,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: str | None = None,
        **_,
    ):
        """
        Returns
        -------
        prompt_embeds                 : Tensor [B, T, 2048]
        negative_prompt_embeds        : Tensor [B, T, 2048] | None
        pooled_prompt_embeds          : Tensor [B, 1280]
        negative_pooled_prompt_embeds : Tensor [B, 1280] | None

        where B = batch * num_images_per_prompt
        """

        def _tok(text: str):
            # Pad to the model's max length so positive and negative prompts
            # always produce sequences of the same length.
            tok_out = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            ).to(self.device)
            return tok_out.input_ids, tok_out.attention_mask

        # Positive prompt: T5 hidden states -> projected token embeddings,
        # mean-pooled hidden state -> projected pooled embedding.
        ids, mask = _tok(prompt)
        h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state
        tok_pos = self.t5_projection(h_pos)
        pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1))

        # One copy per requested image along the batch dimension.
        tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
        pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)

        if do_classifier_free_guidance:
            # The negative prompt (empty string when none is given) goes
            # through the same T5 + projection path.
            neg_text = "" if negative_prompt is None else negative_prompt
            ids_n, mask_n = _tok(neg_text)
            h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state
            tok_neg = self.t5_projection(h_neg)
            pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))

            tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
            pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)
        else:
            tok_neg = pool_neg = None

        return tok_pos, tok_neg, pool_pos, pool_neg
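
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the pipeline). Assumptions:
# a stock SDXL checkpoint supplies vae/unet/scheduler, the T5 tokenizer loads
# from T5_NAME, and a CUDA device is available. Note that freshly constructed
# projection layers are randomly initialized, so outputs are only meaningful
# once trained projection weights are passed in.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    base = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    pipe = StableDiffusionXL_T5Pipeline(
        vae=base.vae,
        unet=base.unet,
        scheduler=base.scheduler,
        tokenizer=T5Tokenizer.from_pretrained(T5_NAME),
    ).to("cuda")
    image = pipe("a photo of an astronaut riding a horse").images[0]
    image.save("astronaut.png")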