File size: 7,157 Bytes
8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 58ef10f 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 a583976 8f28c81 58ef10f 8f28c81 58ef10f 8f28c81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
# Copyright Philip Brown, ppbrown@github
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Hugging Face Hub id of the T5 encoder used for text conditioning.
# Note: for now the intent is to use this encoder as-is, with zero changes.
# The pipeline therefore deliberately does not store the T5 encoder weights
# itself (they are not unique to this model); when no encoder is passed in,
# it is loaded by this hub id, which takes advantage of the huggingface hub
# cache.
T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
from transformers import T5Tokenizer, T5EncoderModel
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from typing import Optional
import torch.nn as nn, torch, types
import torch.nn as nn
class LinearWithDtype(nn.Linear):
    """A ``torch.nn.Linear`` layer that also reports a ``dtype``.

    The reported dtype is simply that of the layer's ``weight`` parameter,
    exposed as a read-only property (presumably so code that reads
    ``module.dtype`` can treat this layer like a diffusers model -- confirm).
    """

    @property
    def dtype(self):
        """The dtype of this layer's ``weight`` tensor."""
        weight_tensor = self.weight
        return weight_tensor.dtype
class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
    """SDXL pipeline conditioned on a T5 encoder instead of SDXL's CLIP text encoders.

    T5 hidden states (4096-d) are mapped through two linear projections:
    4096 => 2048 for the per-token prompt embeddings and 4096 => 1280 for the
    pooled prompt embedding, which are the widths ``encode_prompt`` hands to
    the SDXL ``__call__`` machinery.
    """

    _expected_modules = [
        "vae", "unet", "scheduler", "tokenizer",
        "image_encoder", "feature_extractor",
        "t5_encoder", "t5_projection", "t5_pooled_projection",
    ]
    _optional_components = [
        "image_encoder", "feature_extractor",
        "t5_encoder", "t5_projection", "t5_pooled_projection",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        tokenizer: T5Tokenizer,
        t5_encoder=None,
        t5_projection=None,
        t5_pooled_projection=None,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        """Assemble the pipeline, creating any missing T5 components.

        Args:
            vae, unet, scheduler: standard SDXL components.
            tokenizer: tokenizer used for the T5 encoder.
            t5_encoder: T5 encoder; if None it is loaded from the hub id
                ``T5_NAME`` in ``unet.dtype``.
            t5_projection: 4096 => 2048 token projection; if None a fresh
                (untrained) ``LinearWithDtype`` is created.
            t5_pooled_projection: 4096 => 1280 pooled projection; if None a
                fresh (untrained) ``LinearWithDtype`` is created.
            image_encoder, feature_extractor: optional IP-adapter components.
            force_zeros_for_empty_prompt: stored in the pipeline config.
                NOTE(review): ``encode_prompt`` below does not consult it.
            add_watermarker: accepted for signature compatibility; ignored
                (no watermarker is ever installed).
        """
        # Initialise the base pipeline directly, skipping
        # StableDiffusionXLPipeline.__init__, whose signature requires the
        # CLIP text encoders this pipeline replaces.
        DiffusionPipeline.__init__(self)

        if t5_encoder is None:
            # Weights are not shipped with the pipeline; rely on the hub cache.
            t5_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
        # ----- T5 4096 => 2048 dim projection (per-token embeddings) -----
        if t5_projection is None:
            t5_projection = LinearWithDtype(4096, 2048)  # trainable
        t5_projection.to(dtype=unet.dtype)
        # ----- T5 4096 => 1280 dim projection (pooled embedding) -----
        if t5_pooled_projection is None:
            t5_pooled_projection = LinearWithDtype(4096, 1280)  # trainable
        t5_pooled_projection.to(dtype=unet.dtype)

        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            t5_encoder=t5_encoder,
            t5_projection=t5_projection,
            t5_pooled_projection=t5_pooled_projection,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = (
            self.unet.config.sample_size
            if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
            else 128
        )
        self.watermark = None
        # Parts of the original SDXL class complain if these attributes are
        # not at least PRESENT.
        self.text_encoder = self.text_encoder_2 = None

    def _t5_encode(self, text):
        """Tokenize *text*, run it through T5, and apply both projections.

        Returns ``(token_embeds, pooled_embeds)`` with one row per input
        string, before any per-image repetition.
        """
        tok_out = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
        ).to(self.device)
        hidden = self.t5_encoder(
            tok_out.input_ids, attention_mask=tok_out.attention_mask
        ).last_hidden_state  # [b, T, 4096]
        token_embeds = self.t5_projection(hidden)  # [b, T, 2048]
        # NOTE: the mean runs over all T positions, padding included (the
        # attention mask is not applied). Kept as-is because the pooled
        # projection's weights presumably depend on this exact pooling.
        pooled_embeds = self.t5_pooled_projection(hidden.mean(dim=1))  # [b, 1280]
        return token_embeds, pooled_embeds

    # ------------------------------------------------------------------------
    # Encode a text prompt.
    # Uses the 4096 => 2048 projection for the standard embeds, and
    # 4096 => 1280 for the pooled embeds, because that's what the unet
    # requires. Returns exactly four values in the order SDXL's __call__
    # expects.
    # ------------------------------------------------------------------------
    def encode_prompt(
        self,
        prompt,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: str | list[str] | None = None,
        *,
        prompt_embeds: torch.Tensor | None = None,
        negative_prompt_embeds: torch.Tensor | None = None,
        pooled_prompt_embeds: torch.Tensor | None = None,
        negative_pooled_prompt_embeds: torch.Tensor | None = None,
        **_,
    ):
        """Encode prompts into SDXL-shaped embeddings via T5.

        Args:
            prompt: a string or list of strings. Ignored when
                ``prompt_embeds`` is supplied.
            num_images_per_prompt: each embedding row is repeated this many
                times (``repeat_interleave`` on dim 0).
            do_classifier_free_guidance: when False, both negative outputs
                are None.
            negative_prompt: a string (broadcast to every prompt) or a list
                with one entry per prompt; None means the empty string.
            prompt_embeds, pooled_prompt_embeds: optional precomputed
                positive embeddings (both must be given together).
            negative_prompt_embeds, negative_pooled_prompt_embeds: optional
                precomputed negative embeddings (both must be given together).
            **_: other SDXL ``encode_prompt`` arguments (``prompt_2``,
                ``device``, ``lora_scale``, ...) are accepted and ignored.

        Returns
        -------
        prompt_embeds                 : Tensor [B, T, 2048]
        negative_prompt_embeds        : Tensor [B, T, 2048] | None
        pooled_prompt_embeds          : Tensor [B, 1280]
        negative_pooled_prompt_embeds : Tensor [B, 1280] | None
        where B = batch * num_images_per_prompt

        Raises:
            ValueError: if a precomputed embedding is missing its pooled
                partner, or a negative prompt list's length differs from the
                prompt batch size.

        NOTE(review): ``self.config.force_zeros_for_empty_prompt`` is not
        consulted; an absent negative prompt is encoded as "".
        """
        # ---------- positive stream -------------------------------------
        if prompt_embeds is None:
            prompt_embeds, pooled_prompt_embeds = self._t5_encode(prompt)
        elif pooled_prompt_embeds is None:
            raise ValueError(
                "`pooled_prompt_embeds` must be provided together with `prompt_embeds`."
            )
        # Number of distinct prompts, before per-image repetition.
        batch_size = prompt_embeds.shape[0]

        # ---------- negative / CFG stream --------------------------------
        neg_embeds = neg_pooled = None
        if do_classifier_free_guidance:
            if negative_prompt_embeds is None:
                neg_text = "" if negative_prompt is None else negative_prompt
                if isinstance(neg_text, str):
                    # Broadcast so the negative batch matches the positive one;
                    # otherwise CFG concatenates mismatched batch sizes.
                    neg_text = [neg_text] * batch_size
                elif len(neg_text) != batch_size:
                    raise ValueError(
                        f"`negative_prompt` has batch size {len(neg_text)}, "
                        f"but `prompt` has batch size {batch_size}."
                    )
                neg_embeds, neg_pooled = self._t5_encode(neg_text)
            elif negative_pooled_prompt_embeds is None:
                raise ValueError(
                    "`negative_pooled_prompt_embeds` must be provided together "
                    "with `negative_prompt_embeds`."
                )
            else:
                neg_embeds = negative_prompt_embeds
                neg_pooled = negative_pooled_prompt_embeds

        # ---------- expand for multiple images per prompt -----------------
        def _expand(t):
            return None if t is None else t.repeat_interleave(num_images_per_prompt, 0)

        return (
            _expand(prompt_embeds),
            _expand(neg_embeds),
            _expand(pooled_prompt_embeds),
            _expand(neg_pooled),
        )
|