Text Generation
Transformers
Safetensors
English
idefics2
code
custom_code
text-generation-inference
Instructions to use HuggingFaceM4/idefics2_raven_finetuned with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use HuggingFaceM4/idefics2_raven_finetuned with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="HuggingFaceM4/idefics2_raven_finetuned", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoProcessor, AutoModelForCausalLM

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2_raven_finetuned", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("HuggingFaceM4/idefics2_raven_finetuned", trust_remote_code=True)
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use HuggingFaceM4/idefics2_raven_finetuned with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "HuggingFaceM4/idefics2_raven_finetuned"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "HuggingFaceM4/idefics2_raven_finetuned",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker
docker model run hf.co/HuggingFaceM4/idefics2_raven_finetuned
- SGLang
How to use HuggingFaceM4/idefics2_raven_finetuned with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "HuggingFaceM4/idefics2_raven_finetuned" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "HuggingFaceM4/idefics2_raven_finetuned",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker images
```bash
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "HuggingFaceM4/idefics2_raven_finetuned" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "HuggingFaceM4/idefics2_raven_finetuned",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
- Docker Model Runner
How to use HuggingFaceM4/idefics2_raven_finetuned with Docker Model Runner:
docker model run hf.co/HuggingFaceM4/idefics2_raven_finetuned
| # coding=utf-8 | |
| # Copyright 2024 the HuggingFace Inc. team. All rights reserved. | |
| # | |
| # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX | |
| # and OPT implementations in this library. It has been modified from its | |
| # original forms to accommodate minor architectural differences compared | |
| # to GPT-NeoX and OPT used by the Meta AI team that trained the model. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import repeat | |
| from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask | |
| from transformers.utils import is_flash_attn_2_available | |
| from transformers.utils import logging | |
| from .common import MLP, RMSNorm | |
| if is_flash_attn_2_available(): | |
| from flash_attn import flash_attn_func, flash_attn_varlen_func | |
| from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa | |
| logger = logging.get_logger(__name__) | |
| # Copied from transformers.models.llama.modeling_llama._get_unpad_data | |
| def _get_unpad_data(attention_mask): | |
| seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) | |
| indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() | |
| max_seqlen_in_batch = seqlens_in_batch.max().item() | |
| cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) | |
| return ( | |
| indices, | |
| cu_seqlens, | |
| max_seqlen_in_batch, | |
| ) | |
# Adapted from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate each key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)``: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When ``n_rep == 1`` the input tensor itself is returned unchanged.
    """
    # Unpack first so a non-4D input is rejected even when no repetition is needed.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    widened = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return widened.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class PerceiverAttention(nn.Module):
    """Eager (matmul + softmax) attention used inside each Perceiver resampler layer.

    Queries are projected from `latents`; keys and values are projected from `context` and `latents`
    concatenated along the sequence axis, so every latent attends to all context positions and all latents.
    """

    def __init__(self, config) -> None:
        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`

        Args:
            config: model config. Reads `hidden_size` and, from `config.perceiver_config`, `resampler_n_heads`,
                `resampler_head_dim`, `num_key_value_heads`, `qk_layer_norms_perceiver` and `attention_dropout`.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.perceiver_config.resampler_n_heads
        self.head_dim = config.perceiver_config.resampler_head_dim
        self.num_key_value_heads = config.perceiver_config.num_key_value_heads
        # Number of query heads that share each key/value head (grouped-query attention).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
        self.attention_dropout = config.perceiver_config.attention_dropout

        if self.qk_layer_norms:
            # RMSNorm over head_dim, applied to queries and keys before the dot product.
            self.q_layer_norm = RMSNorm(self.head_dim)
            self.k_layer_norm = RMSNorm(self.head_dim)

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Latents attend to the whole (context + latents) sequence: no causal masking.
        self.is_causal = False

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!

        Args:
            latents: Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
            context: Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
            attention_mask: optional additive mask of shape [bsz, 1, n_latents, kv_seq_len], added to the scores
                before the softmax (kv_seq_len = seq + n_latents + any cached length).
            position_ids: unused; no rotary embedding is applied.
            past_key_value: optional (key, value) tensors prepended to the freshly projected keys/values.
            output_attentions: if True, the attention probabilities are returned.
            use_cache: if True, the (key, value) tensors (including any prepended cache) are returned.

        Returns:
            `(attn_output, attn_weights, past_key_value)` where `attn_output` has shape [bsz, n_latents, hidden_size],
            `attn_weights` is None unless `output_attentions`, and `past_key_value` is None unless `use_cache`.

        Raises:
            ValueError: if the attention scores, the mask or the attention output have an unexpected shape.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        # Keys/values see [context, latents]; queries come from the latents only.
        query_states = self.q_proj(latents).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = (
            self.k_proj(torch.cat([context, latents], dim=-2))
            .view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim)
            .transpose(1, 2)
        )
        value_states = (
            self.v_proj(torch.cat([context, latents], dim=-2))
            .view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim)
            .transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if self.qk_layer_norms:
            query_states = self.q_layer_norm(query_states)
            key_states = self.k_layer_norm(key_states)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        # Apply the configured attention dropout (a no-op in eval mode), matching the `dropout_p` the
        # flash-attention path passes to its kernel during training.
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
class PerceiverFlashAttention2(PerceiverAttention):
    """
    Flash-attention 2 variant of `PerceiverAttention`. It reuses the parent's weights unchanged (q/k/v/o
    projections and the optional q/k RMSNorms); only the forward pass differs. The forward pass calls the public
    flash-attn API and un-pads the inputs whenever a padding mask is provided.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Flash-attention forward pass with the same contract as `PerceiverAttention.forward`, except for the mask
        and the returned attention weights.

        Args:
            latents: Tensor of shape [bsz, n_latents, embed_dim] (the queries).
            context: Tensor of shape [bsz, seq, embed_dim]; keys/values are built from [context, latents].
            attention_mask: optional 2D padding mask of shape [bsz, kv_seq_len] (1 = keep, 0 = padding), not the
                additive 4D mask used by the eager module.
            position_ids: unused; no rotary embedding is applied.
            past_key_value: optional (key, value) tensors prepended to the freshly projected keys/values.
            output_attentions: flash attention never materializes attention probabilities, so the second return
                value is always None.
            use_cache: if True, the (key, value) tensors (including any prepended cache) are returned.

        Returns:
            `(attn_output, None, past_key_value)` with `attn_output` of shape [bsz, n_latents, hidden_size].
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
        #   Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
        query_states = self.q_proj(latents)
        key_states = self.k_proj(torch.cat([context, latents], dim=-2))
        value_states = self.v_proj(torch.cat([context, latents], dim=-2))

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # Activate cache slicing only when the config defines a non-None `sliding_window`; comparing an int
            # against None would raise TypeError.
            sliding_window = getattr(self.config, "sliding_window", None)
            if sliding_window is not None and kv_seq_len > sliding_window:
                slicing_tokens = kv_seq_len - sliding_window

                past_key = past_key_value[0]
                past_value = past_key_value[1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != sliding_window - 1:
                    raise ValueError(
                        "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1,"
                        f" head_dim`), got {past_key.shape}"
                    )

                past_key_value = (past_key, past_value)

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # Mirror the eager path: normalise queries/keys per head when the config enables it. Without this the
        # trained q/k norm weights would be silently ignored whenever flash attention is selected.
        if self.qk_layer_norms:
            query_states = self.q_layer_norm(query_states)
            key_states = self.k_layer_norm(key_states)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                "The input hidden states seems to be silently casted in float32, this might be related to the fact"
                " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the (batch, seq, heads, head_dim) layout expected by Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=False,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        if output_attentions:
            logger.warning_once(
                "PerceiverFlashAttention2 does not compute attention weights; `output_attentions=True` returns None."
            )
        # Always bind the name: previously `output_attentions=True` reached the return with `attn_weights`
        # unassigned and raised UnboundLocalError.
        attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        """
        Calls the forward method of Flash Attention. When an attention mask is given, the inputs are first unpadded.
        The varlen kernel is then run, and the output is padded back to `(batch_size, query_length, ...)`.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API, laid out as (batch, q_len, heads, head_dim)
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API, laid out as (batch, kv_len, heads, head_dim)
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API, laid out as (batch, kv_len, heads, head_dim)
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Number of query positions; used to re-pad the unpadded output.
            dropout (`float`, *optional*):
                Attention dropout probability
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
            use_sliding_windows (`bool`, *optional*):
                Whether to activate sliding window attention.
        """
        # Any provided mask triggers the unpad/varlen path (even if it contains no padding).
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            if not use_sliding_windows:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=self.is_causal,
                )
            else:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=self.is_causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            if not use_sliding_windows:
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=self.is_causal,
                )
            else:
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=self.is_causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Unpad q/k/v for the varlen kernel.

        Returns:
            `(query, key, value, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            # Newer flash-attn releases return extra trailing values from `unpad_input`; keep only the first four.
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
class PerceiverLayer(nn.Module):
    """One Perceiver resampler block.

    The block runs pre-norm attention from the latents over [context, latents], then a pre-norm MLP. Each step
    is wrapped in a residual connection on the latents.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.hidden_act = config.perceiver_config.hidden_act
        self.n_latents = config.perceiver_config.resampler_n_latents
        self.depth = config.perceiver_config.resampler_depth
        self.rms_norm_eps = config.rms_norm_eps

        self.intermediate_size = self.hidden_size * 4
        self.input_latents_norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        self.input_context_norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        # Eager attention unless the config explicitly opts into flash-attention 2.
        self.self_attn = (
            PerceiverAttention(config)
            if not getattr(config, "_flash_attn_2_enabled", False)
            else PerceiverFlashAttention2(config)
        )
        self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        self.mlp = MLP(
            activation=self.hidden_act,
            input_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            output_size=self.hidden_size,
        )

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            latents (`torch.FloatTensor`): input to the layer of shape `(batch, n_latents, embed_dim)`
            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask over `context` of size
                `(batch, seq_len)` where padding elements are indicated by 0. It is extended with ones for the
                latent positions before being handed to the attention module.
            position_ids (`torch.LongTensor`, *optional*): forwarded to the attention module (which ignores it).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail. The flash-attention module always returns None here.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states,
                forwarded to the attention module. NOTE: the mask built here only covers context + latents, so the
                eager module rejects a non-empty cache with a mask-size ValueError.

        Returns:
            `(latents,)`, followed by the attention weights if `output_attentions` and by the present key/value
            tuple if `use_cache`.
        """
        residual = latents

        latents = self.input_latents_norm(latents)
        context = self.input_context_norm(context)

        if attention_mask is None:
            attention_mask = torch.ones(
                size=(context.size(0), context.size(1)),
                dtype=torch.bool,
                device=context.device,
            )
        # Keys/values are [context, latents], so latent positions are always attendable.
        attention_mask = torch.cat(
            [
                attention_mask,
                torch.ones(
                    size=(attention_mask.size(0), latents.size(1)),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
            ],
            dim=-1,
        )

        # Choose the mask format from the attention module actually built in __init__. Reading
        # `self.config._flash_attn_2_enabled` directly raised AttributeError for configs lacking that attribute,
        # even though __init__ treats a missing attribute as "eager".
        if isinstance(self.self_attn, PerceiverFlashAttention2):
            # flash-attn consumes the 2D padding mask directly.
            self_attn_mask = attention_mask
        else:
            # Eager attention expects an additive mask of shape (batch, 1, n_latents, kv_seq_len).
            self_attn_mask = _prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=latents.size(1))

        latents, self_attn_weights, present_key_value = self.self_attn(
            latents=latents,
            context=context,
            attention_mask=self_attn_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        latents = residual + latents
        residual = latents

        latents = self.post_attention_layernorm(latents)
        latents = self.mlp(latents)
        latents = residual + latents

        outputs = (latents,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
class PerceiverResampler(nn.Module):
    """Stack of `PerceiverLayer` blocks that compresses a variable-length context into fixed learned latents."""

    def __init__(
        self,
        config,
    ) -> None:
        """
        Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
        MAE). It performs `depth` blocks of cross-attention with a fixed set of `n_latents` learned latents, then
        returns a Tensor of shape [bsz, n_latents, hidden_size].

        All hyper-parameters are read from `config`:
            - `config.hidden_size`: dimensionality of the context embeddings fed in and of the latents returned.
            - `config.perceiver_config.resampler_depth`: number of `PerceiverLayer` blocks (typically shallow).
            - `config.perceiver_config.resampler_n_latents`: number of latent embeddings the input sequence is
              resampled ("compressed") to.
            - `config.perceiver_config.hidden_act`: activation name of the MLP inside each layer.
            - `config.rms_norm_eps`: epsilon of the final RMSNorm.
        The number of heads and the head dimension are read by each layer's attention module from
        `config.perceiver_config`.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.hidden_act = config.perceiver_config.hidden_act
        self.n_latents = config.perceiver_config.resampler_n_latents
        self.depth = config.perceiver_config.resampler_depth
        self.rms_norm_eps = config.rms_norm_eps

        # Create Latents for Perceiver (shared across the batch; initialised to ones, learned during training)
        self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size))

        # Create Transformer Blocks
        self.layers = nn.ModuleList([PerceiverLayer(config) for _ in range(self.depth)])
        self.norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps)

    def forward(
        self,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Resample `context` into `n_latents` embeddings.

        Args:
            context: Tensor of shape [bsz, seq, hidden_size] to resample.
            attention_mask: optional [bsz, seq] padding mask over `context` (0 marks padding). Each layer extends
                it with ones for the latent positions; when None, every context position is attended.

        Returns:
            Tensor of shape [bsz, n_latents, hidden_size] (final RMSNorm applied).
        """
        # Broadcast the learned latents to one copy per batch element.
        latents = repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])

        for perceiver_layer in self.layers:
            layer_outputs = perceiver_layer(
                latents,
                context,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
            )

            latents = layer_outputs[0]

        latents = self.norm(latents)

        return latents