bpiyush committed
Commit 138a0f7 · verified · 1 Parent(s): a0c6980

Upload tarsier/modeling_tarsier.py with huggingface_hub

Files changed (1)
  1. tarsier/modeling_tarsier.py +779 -0
tarsier/modeling_tarsier.py ADDED
@@ -0,0 +1,779 @@
# Copyright (2024) Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# copy and modify from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
""" PyTorch Llava model."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import math

import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.generation import GenerationMixin
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers.models.auto import AutoModel, AutoModelForCausalLM, CONFIG_MAPPING
from transformers.configuration_utils import PretrainedConfig


logger = logging.get_logger(__name__)

LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json",
}


class LlavaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to
    instantiate a Llava model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Llava-9B.

    e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`LlavaVisionConfig`, *optional*):
            Custom vision config or dict.
        text_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        image_token_index (`int`, *optional*, defaults to 32000):
            The image token index to encode the image prompt.
        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The activation function used by the multimodal projector.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the CLIP backbone.
        vision_feature_layer (`int`, *optional*, defaults to -2):
            The index of the layer to select the vision feature.
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`~LlavaForConditionalGeneration`].

    Example:

    ```python
    >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig

    >>> # Initializing a CLIP-vision config
    >>> vision_config = CLIPVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a Llava llava-1.5-7b style configuration
    >>> configuration = LlavaConfig(vision_config, text_config)

    >>> # Initializing a model from the llava-1.5-7b style configuration
    >>> model = LlavaForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "llava"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        vocab_size=32000,
        image_newline_idx=32002,
        image_new_idx=32003,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer
        self.vocab_size = vocab_size
        self.image_newline_idx = image_newline_idx
        self.image_new_idx = image_new_idx

        self.vision_config = vision_config

        if isinstance(self.vision_config, dict):
            vision_config["model_type"] = (
                vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
            )
            self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            self.vision_config = CONFIG_MAPPING["clip_vision_model"](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )
        self.vocab_size = self.vocab_size

        self.text_config = text_config

        if isinstance(self.text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
            self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
            self.vocab_size = self.text_config.vocab_size
        elif text_config is None:
            self.text_config = CONFIG_MAPPING["llama"]()

        super().__init__(**kwargs)


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlavaConfig"

LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "llava-hf/llava-1.5-7b-hf",
    "llava-hf/llava-1.5-13b-hf",
    "llava-hf/bakLlava-v1-hf",
    # See all Llava models at https://huggingface.co/models?filter=llava
]


class Llava3DPositionalEncoding(nn.Module):
    def __init__(self, num_pos, dim) -> None:
        super().__init__()
        dim1, dim2, dim3 = self.split_dim(dim)
        frame_position_encodings = self.create_sinusoidal_positions(num_pos, dim1)
        height_position_encodings = self.create_sinusoidal_positions(num_pos, dim2)
        width_position_encodings = self.create_sinusoidal_positions(num_pos, dim3)

        self.register_buffer('frame_position_encodings', frame_position_encodings, persistent=False)
        self.register_buffer('height_position_encodings', height_position_encodings, persistent=False)
        self.register_buffer('width_position_encodings', width_position_encodings, persistent=False)

    def split_dim(self, dim):
        dim1 = dim // 3
        if dim1 % 2 != 0:
            dim1 -= 1

        dim2 = dim // 3
        if dim2 % 2 != 0:
            dim2 -= 1

        dim3 = dim - dim1 - dim2
        return dim1, dim2, dim3

    def create_sinusoidal_positions(self, num_pos: int, dim: int) -> torch.Tensor:
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
        sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
        return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)

    def forward(self, frame_position_ids, height_position_ids, width_position_ids):
        frame_position_embeds = F.embedding(frame_position_ids, self.frame_position_encodings)
        height_position_embeds = F.embedding(height_position_ids, self.height_position_encodings)
        width_position_embeds = F.embedding(width_position_ids, self.width_position_encodings)

        return torch.cat([frame_position_embeds, height_position_embeds, width_position_embeds], dim=-1)

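# NOTE (editor's illustrative sketch, not part of the upstream file): `Llava3DPositionalEncoding`
# splits the embedding dimension into three roughly equal parts and builds a separate sinusoidal
# table for the frame, height and width axes of a video token grid. Assuming a hypothetical
# 8-frame clip laid out on a 24 x 24 patch grid and an embedding size of 1536, it could be
# queried like this:
#
#     pos_enc = Llava3DPositionalEncoding(num_pos=64, dim=1536)
#     T, H, W = 8, 24, 24
#     frame_ids = torch.arange(T).repeat_interleave(H * W)         # (T*H*W,)
#     height_ids = torch.arange(H).repeat_interleave(W).repeat(T)  # (T*H*W,)
#     width_ids = torch.arange(W).repeat(T * H)                    # (T*H*W,)
#     embeds = pos_enc(frame_ids, height_ids, width_ids)           # (T*H*W, 1536)
#
# The grid construction above is only an assumption for illustration; the module itself just
# expects three equally shaped tensors of integer position ids.
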
@dataclass
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
class LlavaCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Llava causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
            sequence_length, hidden_size)`.

            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    llm_attn_mask: Optional[Tuple[torch.FloatTensor]] = None


class LlavaMultiModalProjector(nn.Module):
    def __init__(self, config: LlavaConfig):
        super().__init__()

        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


TARSIER_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    TARSIER_START_DOCSTRING,
)
class TarsierPreTrainedModel(PreTrainedModel):
    config_class = LlavaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlavaVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        # important: this ported version of Llava isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self):
        """
        Retrieve language_model's attribute to check whether the model supports
        SDPA or not.
        """
        return self.language_model._supports_sdpa


TARSIER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
            [`CLIPImageProcessor`] for processing images).
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    """The LLAVA model which consists of a vision backbone and a language model.""",
    TARSIER_INPUTS_DOCSTRING,
)
class TarsierForConditionalGeneration(TarsierPreTrainedModel, GenerationMixin):
    def __init__(self, config: LlavaConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config, trust_remote_code=True)
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.vocab_size = config.vocab_size

        use_flash_attn = True
        attn_implementation = "flash_attention_2"
        # If the GPU is not compatible with Flash Attention 2, fall back to eager attention.
        from transformers.utils import is_flash_attn_2_available
        if use_flash_attn and not is_flash_attn_2_available():
            use_flash_attn = False
            attn_implementation = "eager"
            print("Flash Attention 2 is not available on this GPU, falling back to Eager.")

        # Get torch_dtype from config if available, otherwise use bfloat16 as default for Flash Attention
        torch_dtype = getattr(config, 'torch_dtype', None)
        if torch_dtype is None and attn_implementation == "flash_attention_2":
            torch_dtype = torch.bfloat16

        language_model_kwargs = {"attn_implementation": attn_implementation}
        if torch_dtype is not None:
            language_model_kwargs["torch_dtype"] = torch_dtype

        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, **language_model_kwargs
        )
        image_newline_idx = torch.tensor([config.image_newline_idx], dtype=torch.long)
        image_new_idx = torch.tensor([config.image_new_idx], dtype=torch.long)
        self.register_buffer('image_newline_idx', image_newline_idx, persistent=False)
        self.register_buffer('image_new_idx', image_new_idx, persistent=False)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()
        # Tie weights to avoid warning about untied weights
        self.tie_weights()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        num_images, num_image_patches, embed_dim = image_features.shape

        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
        # 1. Create a mask to know where special image tokens are
        special_image_token_mask = input_ids == self.config.image_token_index
        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
        batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged image-text sequence.
        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels, position_ids

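    # NOTE (editor's illustrative sketch, not part of the upstream file): a tiny worked example of
    # the merge above. Suppose one image expands to 3 feature vectors and the prompt is
    # ["A", "<image>", "B"], so `special_image_token_mask = [0, 1, 0]`:
    #
    #     new_token_positions = cumsum(mask * (3 - 1) + 1) - 1 = cumsum([1, 3, 1]) - 1 = [0, 3, 4]
    #     max_embed_dim = 1 * (3 - 1) + 3 = 5
    #
    # Text token "A" is written to slot 0 and "B" to slot 4, while `image_to_overwrite` marks the
    # still-empty slots 1-3, which receive the 3 image features: [A, img0, img1, img2, B].
    # The image feature count of 3 is purely hypothetical; with the default CLIP tower each frame
    # yields 601 features after `add_split_tokens` below.
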
    def add_split_tokens(self, image_features):
        num_images, num_image_patches, embed_dim = image_features.shape
        num_height_patches, num_width_patches = int(math.sqrt(num_image_patches)), int(math.sqrt(num_image_patches))

        # add image_newline
        image_newline = self.get_input_embeddings()(self.image_newline_idx).squeeze()
        image_features = image_features.view(num_images, num_height_patches, num_width_patches, embed_dim)
        image_features = torch.cat([
            image_features,
            image_newline.expand((num_images, num_height_patches, 1, embed_dim)).to(device=image_features.device)
        ], dim=2)
        num_image_patches += num_height_patches
        image_features = image_features.view(num_images, num_image_patches, embed_dim)

        # add image_new
        image_new = self.get_input_embeddings()(self.image_new_idx).squeeze()
        image_features = torch.cat([
            image_features,
            image_new.expand((num_images, 1, embed_dim)).to(device=image_features.device)
        ], dim=1)

        return image_features

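    # NOTE (editor's illustrative sketch, not part of the upstream file): shape bookkeeping for
    # `add_split_tokens`, assuming the default CLIP-L/14 tower at 336px (576 patches per frame):
    #
    #     (num_images, 576, D)  --view-->                        (num_images, 24, 24, D)
    #                           --cat image_newline per row-->   (num_images, 24, 25, D)
    #                           --view-->                        (num_images, 600, D)
    #                           --cat one image_new token-->     (num_images, 601, D)
    #
    # The appended embeddings are simply the input-embedding rows of `image_newline_idx` and
    # `image_new_idx`, so the language model sees explicit row and frame separators.
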
    @add_start_docstrings_to_model_forward(TARSIER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        image_features = None
        if inputs_embeds is None:
            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # 2. Merge text and images
            if pixel_values is not None and input_ids.shape[1] != 1:
                pixel_values = pixel_values.to(dtype=self.vision_tower.dtype)
                image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
                # This is not memory efficient at all: output_hidden_states=True will keep all the hidden states.
                selected_image_feature = image_outputs.hidden_states[vision_feature_layer]

                if vision_feature_select_strategy == "default":
                    selected_image_feature = selected_image_feature[:, 1:]
                elif vision_feature_select_strategy == "full":
                    selected_image_feature = selected_image_feature
                else:
                    raise ValueError(
                        f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
                    )

                image_features = self.multi_modal_projector(selected_image_feature)

                special_image_token_mask = input_ids == self.config.image_token_index
                num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)

                image_features = self.add_split_tokens(image_features)

                if sum(num_special_image_tokens) > 0:
                    # print(f'num_special_image_tokens: {num_special_image_tokens}')
                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                        image_features, inputs_embeds, input_ids, attention_mask, labels
                    )
                else:
                    inputs_embeds = image_features.sum(dim=(0, 1))[None, None, :] * 0. + inputs_embeds

                if labels is None:
                    labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
            else:
                # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
                # generation with cache
                if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    # Get the target length
                    target_seqlen = first_layer_past_key_value.shape[-1] + 1
                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], target_seqlen),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    extended_attention_mask[batch_index, non_attended_tokens] = 0

                    valid_indices = torch.ones_like(attention_mask)
                    valid_indices[:, 0] = target_seqlen - extended_attention_mask.sum(dim=-1)
                    valid_indices = torch.cumsum(valid_indices, dim=-1)
                    extended_attention_mask = extended_attention_mask.scatter(1, valid_indices, attention_mask)
                    attention_mask = extended_attention_mask
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # use_rmpad=kwargs.get("use_rmpad", False),
            return_dict=return_dict,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            llm_attn_mask=attention_mask
        )

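    # NOTE (editor's illustrative sketch, not part of the upstream file): a minimal forward call,
    # assuming `input_ids` contains one `config.image_token_index` placeholder per video frame and
    # `pixel_values` stacks the corresponding frames:
    #
    #     out = model(
    #         input_ids=input_ids,            # (batch, seq_len) with image placeholder tokens
    #         pixel_values=pixel_values,      # (num_frames, 3, 336, 336) for the default vision config
    #         attention_mask=attention_mask,  # (batch, seq_len)
    #         labels=labels,                  # optional; ignored positions use config.ignore_index
    #     )
    #     out.logits                          # (batch, merged_seq_len, vocab_size)
    #
    # `merged_seq_len` is the text length plus the spliced-in image features, as produced by
    # `_merge_input_ids_with_image_features`.
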
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            elif self.config.image_token_index in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
            # older attention values, as their corresponding values are not part of the input.
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
            }
        )
        return model_inputs

    def _reorder_cache(self, *args, **kwargs):
        return self.language_model._reorder_cache(*args, **kwargs)
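
# NOTE (editor's illustrative sketch, not part of the upstream file): this module is meant to be
# shipped alongside a checkpoint and loaded with remote code enabled. Assuming a placeholder
# repo id, loading could look like:
#
#     from transformers import AutoConfig
#
#     config = AutoConfig.from_pretrained("<tarsier-checkpoint>", trust_remote_code=True)
#     model = TarsierForConditionalGeneration.from_pretrained(
#         "<tarsier-checkpoint>",
#         config=config,
#         torch_dtype=torch.bfloat16,
#         device_map="auto",
#     )
#     generate_ids = model.generate(input_ids=input_ids, pixel_values=pixel_values, max_new_tokens=128)
#
# The repo id and input preparation are assumptions for illustration; prompts must contain
# `config.image_token_index` tokens where the frame features should be inserted.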