"""CASA layers"""
import bisect
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Callable, Literal, Sequence, TypedDict, overload
from typing import cast as type_cast
import torch
from transformers.configuration_utils import PretrainedConfig
from .utils import StreamingModule, StreamingState, delta_w_factory
if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig
try:
from flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None # type: ignore
WindowsComputeKwargs = TypedDict(
"WindowsComputeKwargs",
{
"num_post_image_tokens": int,
"num_pre_image_tokens": int,
},
total=False,
)
def __split_n_merge__(
x: torch.Tensor,
sample_lengths: list[int],
padding_side: Literal["left", "right"] = "right",
pad_value: int | float | bool = 0,
) -> torch.Tensor:
max_sample_length = max(sample_lengths)
pad_tuple = tuple(0 for _ in range((x.ndim - 1) * 2))
return torch.stack(
[
torch.nn.functional.pad(
_x,
pad_tuple + (0, max_sample_length - _x.shape[0])
if padding_side == "right"
else pad_tuple + (max_sample_length - _x.shape[0], 0),
value=pad_value,
)
for _x in torch.split(x, sample_lengths, dim=0)
],
dim=0,
)
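# Illustrative example (kept as a comment; hypothetical tensors, not part of the module):
# __split_n_merge__ splits a flattened (sum(lengths), D) tensor into per-sample chunks
# and pads them to a common length before stacking, e.g.
#
#   x = torch.arange(5, dtype=torch.float32)[:, None]        # (5, 1), sample lengths [2, 3]
#   __split_n_merge__(x, [2, 3], padding_side="right")
#   # -> tensor([[[0.], [1.], [0.]],
#   #            [[2.], [3.], [4.]]])                         # shape (2, 3, 1), padded with 0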
@overload
def insert_image_tokens(
inputs_embeds: torch.Tensor,
image_embeds: torch.Tensor | Sequence[torch.Tensor],
image_embeds_insertion_points: list[torch.Tensor],
recover_batch_dim: Literal[True],
attention_mask: torch.Tensor | None = None,
padding_side: Literal["left", "right"] = "right",
keep_only_attended: bool = False,
pad_output: int | float | bool = 0.0,
) -> tuple[
torch.Tensor,
None,
torch.Tensor | None,
torch.Tensor,
]: ...
@overload
def insert_image_tokens(
inputs_embeds: torch.Tensor,
image_embeds: torch.Tensor | Sequence[torch.Tensor],
image_embeds_insertion_points: list[torch.Tensor],
recover_batch_dim: Literal[False],
attention_mask: torch.Tensor | None = None,
padding_side: Literal["left", "right"] = "right",
keep_only_attended: bool = False,
pad_output: int | float | bool = 0.0,
) -> tuple[
torch.Tensor,
list[int],
torch.Tensor | None,
torch.Tensor,
]: ...
def insert_image_tokens(
inputs_embeds: torch.Tensor,
image_embeds: torch.Tensor | Sequence[torch.Tensor],
image_embeds_insertion_points: list[torch.Tensor],
recover_batch_dim: bool = True,
attention_mask: torch.Tensor | None = None,
padding_side: Literal["left", "right"] = "right",
keep_only_attended: bool = False,
pad_output: int | float | bool = 0.0,
) -> tuple[
    torch.Tensor,
    list[int] | None,
    torch.Tensor | None,
    torch.Tensor,
]:
"""
Insert image embeddings into text embeddings
Args:
inputs_embeds (torch.Tensor): (B, S, D) input token embeddings.
image_embeds (torch.Tensor | list[torch.Tensor]): (N_images, Nt, D) | List[(Nt, D)] image token embeddings.
image_embeds_insertion_points (list[torch.Tensor]): Insertion indices.
attention_mask (torch.Tensor, optional): (B, S) attention mask.
padding_side (Literal["left", "right"]): Padding scheme. Controls behavior for padded images.
return_indices (bool): Whether to return gather indices or the fused sequence directly.
keep_only_attended: This is only applicable when recover_batch_dim is False; whether to
remove any non-attended tokens in the whole array. In this case, the attention
mask returned is **still the original one**, so we can remember which indices have been
removed
Returns:
output (torch.Tensor): (B, S + Ni * Nt) gather indices or (B, S + Ni * Nt, D) fused sequence
image_embeds (torch.Tensor): (B, Ni * Nt) image embeds, padded and batch if input was a list
attention_mask (torch.Tensor): Same shape, 1 for real tokens, 0 for image and text padding.
image_tokens_mask (torch.Tensor): (B, S + Ni * Nt, 1), marks image token positions.
"""
if isinstance(image_embeds, list) and len(image_embeds) == 0:
batch_size, text_seq_length, token_dim = inputs_embeds.shape
if recover_batch_dim:
return (
inputs_embeds,
None,
attention_mask,
torch.zeros((batch_size, text_seq_length, 1), dtype=torch.bool),
)
else:
flattened_seq_length = inputs_embeds.shape[0] * inputs_embeds.shape[1]
return (
torch.reshape(inputs_embeds, (flattened_seq_length, inputs_embeds.shape[2])),
[text_seq_length] * inputs_embeds.shape[0],
attention_mask.flatten() if attention_mask is not None else None,
torch.zeros((flattened_seq_length, 1), dtype=torch.bool),
)
# Sanity checks
if isinstance(image_embeds, torch.Tensor):
assert inputs_embeds.shape[-1] == image_embeds.shape[-1]
else:
assert all(inputs_embeds.shape[-1] == _x.shape[-1] for _x in image_embeds)
batch_size, text_seq_length, token_dim = inputs_embeds.shape
image_seq_length = [x.shape[0] for x in image_embeds]
# Flatten insertion points
insertion_offset = []
counter, offset_from_text, offset_from_image = 0, 0, 0
for sample in image_embeds_insertion_points:
for pt in sample:
insertion_offset.append(pt + offset_from_image + offset_from_text)
offset_from_image += image_seq_length[counter]
counter += 1
offset_from_text += text_seq_length
image_insert_positions = [
x for idx, pt in enumerate(insertion_offset) for x in range(pt, pt + image_seq_length[idx])
]
# Flatten image embeds
if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0)
else:
image_embeds = type_cast(torch.Tensor, image_embeds)
image_embeds = torch.reshape(image_embeds, (-1, token_dim))
# Flatten text embeds across batch dim (B x S, D)
inputs_embeds = torch.reshape(inputs_embeds, (-1, token_dim))
flattened_seq_length = inputs_embeds.shape[0] + sum(image_seq_length)
text_insert_positions = sorted(
set(range(flattened_seq_length)).difference(set(image_insert_positions))
)
# Scatter image embeds in the flattened dict
# scatter text related stuff
output = torch.empty(
(flattened_seq_length, token_dim),
device=inputs_embeds.device,
dtype=inputs_embeds.dtype,
)
txt_positions_tensor = torch.Tensor(text_insert_positions).to(
dtype=torch.long, device=inputs_embeds.device
)
output.scatter_(0, txt_positions_tensor[:, None].expand(-1, token_dim), inputs_embeds)
attention_mask_new: torch.Tensor | None = None
if attention_mask is not None:
attention_mask_new = torch.ones(
(flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask_new.scatter_(
0, txt_positions_tensor, attention_mask.flatten().to(torch.bool)
)
# scatter image related stuff
image_tokens_mask = torch.zeros(
(flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device
)
img_positions_tensor = torch.Tensor(image_insert_positions).to(
device=inputs_embeds.device, dtype=torch.long
)
output.scatter_(0, img_positions_tensor[:, None].expand(-1, token_dim), image_embeds)
image_tokens_mask.scatter_(0, img_positions_tensor, True)
# Compute expected sample length, taking into account the real batch
# i.e. recover the batch dimension of image embeddings
sample_lengths = []
counter = 0
for sample_idx, pts in enumerate(image_embeds_insertion_points):
num_image_tokens = 0
for _ in pts:
num_image_tokens += image_seq_length[counter]
counter += 1
if keep_only_attended and attention_mask is not None:
attended_seq_length = torch.sum(attention_mask[sample_idx]).cpu().item()
sample_lengths.append(attended_seq_length + num_image_tokens)
else:
sample_lengths.append(text_seq_length + num_image_tokens)
    # For CASA attention, we can keep everything flattened and return
    # the sample_lengths for the blockwise attention
if not recover_batch_dim:
if keep_only_attended and attention_mask_new is not None:
output = output[attention_mask_new]
image_tokens_mask = image_tokens_mask[attention_mask_new]
return output, sample_lengths, attention_mask_new, image_tokens_mask[..., None]
# Otherwise, time to (pad) and reshape
# Easy case: everything has the same length
if all(x == sample_lengths[0] for x in sample_lengths):
output = torch.reshape(output, (batch_size, sample_lengths[0], token_dim))
image_tokens_mask = torch.reshape(image_tokens_mask, (batch_size, sample_lengths[0], 1))
if attention_mask_new is not None:
attention_mask_new = torch.reshape(attention_mask_new, (batch_size, sample_lengths[0]))
# if there is any size mismatch we break into a
# list and pad again
else:
# split and merge
output = __split_n_merge__(output, sample_lengths, padding_side, pad_value=pad_output)
# note that the extra padding tokens are also marked as image tokens to be removed later
image_tokens_mask = __split_n_merge__(
image_tokens_mask, sample_lengths, padding_side, True
)[:, :, None]
if attention_mask_new is not None:
attention_mask_new = __split_n_merge__(
attention_mask_new, sample_lengths, padding_side, 0
)
# Return
return output, sample_lengths, attention_mask_new, image_tokens_mask
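# Illustrative example (kept as a comment; hypothetical tensors): inserting a single
# 2-token image at position 1 of a 4-token text sequence.
#
#   txt = torch.zeros(1, 4, 2)                        # (B=1, S=4, D=2) text embeddings
#   img = [torch.ones(2, 2)]                          # one image with 2 tokens
#   out, lengths, _, img_mask = insert_image_tokens(
#       txt, img, [torch.tensor([1])], recover_batch_dim=True
#   )
#   # out has shape (1, 6, 2): text token 0, the two image tokens, then text tokens 1..3
#   # lengths == [6]; img_mask[0, :, 0] == [False, True, True, False, False, False]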
def get_sample_lengths_from_insertion_points(
image_embeds_insertion_points: list[torch.Tensor],
image_embeds: torch.Tensor | list[torch.Tensor] | None,
total_seq_len: int | None = None,
attention_mask: torch.Tensor | None = None,
**kwargs: WindowsComputeKwargs,
) -> tuple[list[tuple[int, bool]], list[int]]:
"""Compute sample lengths as if each image insertion point defines a
new document (ex document ID)
"""
num_post_image_tokens = type_cast(int, kwargs.get("num_post_image_tokens", 0))
num_pre_image_tokens = type_cast(int, kwargs.get("num_pre_image_tokens", 0))
squashed_samples_lengths = type_cast(
list[list[int]] | None, kwargs.get("squashed_samples_lengths", None)
)
if squashed_samples_lengths is not None:
assert len(squashed_samples_lengths) == len(image_embeds_insertion_points)
def __insert_next_sample__(
batch_idx: int, insrt_pt: int, last_insrt_pt: int, end_of_batch_sample: bool = False
) -> None:
nonlocal attention_mask
nonlocal text_sample_lengths, full_sample_lengths
nonlocal cum_samples_lengths, current_image_offset
nonlocal last_image_idx, current_image_idx, current_length
# Add the sample between [last_insrt_pt, insrt_pt] with breaks in
# between any squashed samples we find on the way
start_pt = bisect.bisect_left(cum_samples_lengths, last_insrt_pt)
added_sample = False
for end_of_sample in cum_samples_lengths[start_pt:]:
# we will break the loop at the end when end_of_sample = insrt_pt
end_of_sample = min(end_of_sample, insrt_pt)
# Add between [last_insrt_pt, end_of_sample]
current_length = end_of_sample - last_insrt_pt
if attention_mask is not None:
current_length -= int(
torch.sum(~attention_mask[batch_idx, last_insrt_pt:end_of_sample]).item()
)
if current_length > 0:
added_sample = True
text_sample_lengths.append(
(current_length, end_of_batch_sample and insrt_pt == end_of_sample)
)
# add image tokens to current_length
if current_image_idx > 0 and image_embeds is not None:
images_in_sample = [
img_idx
for img_idx in range(last_image_idx, current_image_idx)
if img_idx < len(image_embeds_insertion_points[batch_idx])
and last_insrt_pt
<= image_embeds_insertion_points[batch_idx][img_idx]
< end_of_sample
]
if len(images_in_sample) > 0:
num_image_tokens = sum(
_x.shape[0]
for _x in image_embeds[
current_image_offset + images_in_sample[0] : current_image_offset
+ images_in_sample[-1]
+ 1
]
)
current_length += num_image_tokens
full_sample_lengths.append(current_length)
# prepare for next loop
last_insrt_pt = end_of_sample
if end_of_sample == insrt_pt:
break
# End of loop: Catching weird use case where we may end up on a span
# full of padding tokens which will not get added due to current_length > 0
if end_of_batch_sample:
assert added_sample, "Weird edge case. Don't do that, thank you"
text_sample_lengths[-1] = (text_sample_lengths[-1][0], True)
current_image_offset = 0
text_sample_lengths, full_sample_lengths = [], []
cum_samples_lengths: list[int] = []
current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
for batch_idx, pts in enumerate(image_embeds_insertion_points):
if squashed_samples_lengths is not None:
cum_samples_lengths = list(accumulate(squashed_samples_lengths[batch_idx]))
else:
assert total_seq_len is not None
cum_samples_lengths = [total_seq_len]
for current_image_idx, insrt_pt in enumerate(pts.cpu().tolist()):
            # check if the images are consecutive, in which case we want
            # them to belong to the same window
if current_image_idx >= 1 and insrt_pt == (
image_embeds_insertion_points[batch_idx][current_image_idx - 1]
+ num_pre_image_tokens
+ num_post_image_tokens
):
continue
            # Otherwise, we found a new sample.
            # Not very important, but for completeness: by design the insertion points come
            # *after* the pre-image tokens, whereas for the document-ID mask it is more
            # consistent to have them correspond to the same image
insrt_pt -= num_pre_image_tokens
# Update text and full sample lengths
if insrt_pt > last_insrt_pt:
__insert_next_sample__(
batch_idx, insrt_pt, last_insrt_pt, end_of_batch_sample=False
)
last_image_idx = current_image_idx
last_insrt_pt = insrt_pt
# End of batch: add sample in progress and reset
current_image_idx += 1
if cum_samples_lengths[-1] > last_insrt_pt:
__insert_next_sample__(
batch_idx, cum_samples_lengths[-1], last_insrt_pt, end_of_batch_sample=True
)
current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0
current_image_offset += len(pts)
    # Sanity check that the is_eob markers are correctly placed
assert sum(_x[1] for _x in text_sample_lengths) == len(image_embeds_insertion_points), (
f"Number of eob markers ({sum(_x[1] for _x in text_sample_lengths)}) differs"
f" from original batch size ({len(image_embeds_insertion_points)})"
)
return text_sample_lengths, full_sample_lengths
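# Illustrative example (kept as a comment; hypothetical shapes): one sample of 10 text
# tokens with two 3-token images inserted at positions 2 and 6 yields three "documents":
#
#   text, full = get_sample_lengths_from_insertion_points(
#       image_embeds_insertion_points=[torch.tensor([2, 6])],
#       image_embeds=[torch.zeros(3, 8), torch.zeros(3, 8)],
#       total_seq_len=10,
#   )
#   # text == [(2, False), (4, False), (4, True)]    (text length, is end-of-batch-sample)
#   # full == [2, 7, 7]                              text + image tokens per document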
class CASAAttentionHandler:
def __init__(
self,
inputs_embeds: torch.Tensor,
image_embeds: torch.Tensor | list[torch.Tensor],
image_embeds_insertion_points: list[torch.Tensor],
attention_mask: torch.Tensor | None = None,
rope_fn: Callable | None = None,
windows: Literal["batch", "squashed", "images", "turn_based"] = "images",
use_asymetric_q_kv: bool = True,
casa_windows_info: None | dict = None,
):
"""Initialize the structure holding the query buffer for CASA attention layers
(ie the **flattened** text+image inserted tokens).
Note that this structure is shared across all casa layers, and it gets updated
with the current hidden states at every layer; this is merely a buffer to keep
scatter_ operations in-plae as much as possible
In this module, the embeddings related values (image_tokens_mask,
text_sample_lengths etc) are stored under the assumption of a tensor
which is *flatened* and *witout padding tokens*
Only the attention mask is kept as-is (text-only, batched, padded) to
be able to recover original shapes when needed
"""
super().__init__()
assert windows == "images" # for inference code release
        # Note 1: Unless overridden, text/full_sample_lengths are defined such that one
        # document = one sample in the batch
if attention_mask is None:
text_sample_lengths = [(_x.shape[0], True) for _x in inputs_embeds]
else:
text_sample_lengths = [(int(torch.sum(_x).item()), True) for _x in attention_mask]
(
full_inputs_embeds,
full_sample_lengths,
# Full attention mask is only needed at inference to
# flatten the KV-Cache and remove padding tokens
_,
self.image_tokens_mask,
) = insert_image_tokens(
inputs_embeds=inputs_embeds,
image_embeds=image_embeds,
image_embeds_insertion_points=image_embeds_insertion_points,
attention_mask=attention_mask,
recover_batch_dim=False,
keep_only_attended=attention_mask is not None,
)
assert self.image_tokens_mask.ndim == 2
self.image_embeds = image_embeds
self.image_embeds_insertion_points = image_embeds_insertion_points
self.attention_mask = None if attention_mask is None else attention_mask.bool()
self.use_asymetric_qkv = use_asymetric_q_kv
        # At inference, we have to use asymmetric QKV for efficiency
if self.attention_mask is not None:
self.use_asymetric_qkv = True
# Build CASA windows
assert casa_windows_info is not None
text_sample_lengths, full_sample_lengths = get_sample_lengths_from_insertion_points(
image_embeds_insertion_points=image_embeds_insertion_points,
image_embeds=image_embeds,
total_seq_len=inputs_embeds.shape[1],
attention_mask=self.attention_mask,
**casa_windows_info, # pyright: ignore
)
# Sanity checks on the sample lengths
self.text_sample_lengths = [(int(s), eob) for s, eob in text_sample_lengths if s > 0]
self.full_sample_lengths = [int(s) for s in full_sample_lengths if s > 0]
assert len(self.text_sample_lengths) == len(self.full_sample_lengths), (
f"Sanity check failed; text sample lengths {len(self.text_sample_lengths)}"
f" != full sample lengths {len(self.full_sample_lengths)}"
)
if self.attention_mask is None:
num_unpadded_text_tokens = inputs_embeds.shape[0] * inputs_embeds.shape[1]
else:
num_unpadded_text_tokens = int(
torch.sum(type_cast(torch.Tensor, attention_mask)).item()
)
        assert sum(_x[0] for _x in self.text_sample_lengths) == num_unpadded_text_tokens, (
            f"Sanity check failed; text sample lengths {sum(_x[0] for _x in self.text_sample_lengths)}"
            f" != number of unpadded text tokens ({num_unpadded_text_tokens})"
        )
assert sum(self.full_sample_lengths) == full_inputs_embeds.shape[0], (
f"Sanity check failed; sample lengths {sum(self.full_sample_lengths)} != {full_inputs_embeds.shape[0]}"
)
# Finally we can compute cu_seqlen based on sample lengths
self.max_seqlen_q = max(self.text_sample_lengths)[0]
self.cu_seqlens_q = self.get_cu_seqlens(
[x[0] for x in self.text_sample_lengths], device=inputs_embeds.device
)
self.max_seqlen_kv = max(self.full_sample_lengths)
self.cu_seqlens_kv = self.get_cu_seqlens(
self.full_sample_lengths, device=inputs_embeds.device
)
# For inference: We save the length of the current document
# to trim the KV cache appropriately
self.current_doc_lengths = self.full_sample_lengths
# Precompute position embeddings
self.position_embeds = None
self.rope_fn = rope_fn
if self.rope_fn is not None:
self.position_embeds = self.compute_position_embeddings(
self.rope_fn, full_sample_lengths, dummy_for_dtype_and_device=full_inputs_embeds
)
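    # Illustrative construction (kept as a comment; hypothetical shapes, no RoPE), reused
    # by the comments on the properties below: one sample of 10 text tokens with two
    # 3-token images inserted at positions 2 and 6.
    #
    #   handler = CASAAttentionHandler(
    #       inputs_embeds=torch.zeros(1, 10, 8),
    #       image_embeds=[torch.zeros(3, 8), torch.zeros(3, 8)],
    #       image_embeds_insertion_points=[torch.tensor([2, 6])],
    #       casa_windows_info={},
    #   )
    #   # handler.text_sample_lengths == [(2, False), (4, False), (4, True)]
    #   # handler.full_sample_lengths == [2, 7, 7]
    #   # handler.cu_seqlens_q.tolist()  == [0, 2, 6, 10]    handler.max_seqlen_q  == 4
    #   # handler.cu_seqlens_kv.tolist() == [0, 2, 9, 16]    handler.max_seqlen_kv == 7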
@property
def batch_lengths(self) -> list[int]:
"""Return a (batch_size,) list of integers containing the
number of (non-padded) text tokens for each sample in the batch"""
bls = [0]
for ln, eob in self.text_sample_lengths:
bls[-1] += ln
if eob:
bls.append(0)
return bls[:-1]
@property
def full_batch_lengths(self) -> list[int]:
"""Same as batch_lengths for text+image tokens"""
bls = [0]
for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths):
bls[-1] += ln
if eob:
bls.append(0)
return bls[:-1]
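    # Continuing the hypothetical handler above, with text_sample_lengths
    # == [(2, False), (4, False), (4, True)] and full_sample_lengths == [2, 7, 7]:
    #   handler.batch_lengths      == [10]   # text tokens per batch sample
    #   handler.full_batch_lengths == [16]   # text + image tokens per batch sample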
def get_cu_seqlens(
self, sample_lengths: list[int], device: torch.device | None
) -> torch.Tensor:
"""Update cu_seqlengths according to the given sample_lengths"""
return torch.Tensor(list(accumulate(sample_lengths, initial=0))).to(
dtype=torch.int32, device=device
)
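    # e.g. get_cu_seqlens([2, 4, 4], device=None) -> tensor([0, 2, 6, 10], dtype=torch.int32),
    # the cumulative-offsets format expected by flash_attn_varlen_func.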
def compute_position_embeddings(
self,
rope_fn: Callable,
sample_lengths: list[int],
dummy_for_dtype_and_device: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute info required for position embeddings. Can be override e.g. for Qwen"""
# option 1: Standard range
# position_ids = torch.arange(0, full_inputs_embeds.shape[0])
# option 2: Follows document boundary
position_ids = torch.cat([torch.arange(0, lg) for lg in sample_lengths], dim=0)
return rope_fn(
dummy_for_dtype_and_device,
position_ids.to(dummy_for_dtype_and_device.device)[None, ...],
)
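    # For example, with sample_lengths == [2, 7, 7], the position ids restart at each
    # document boundary: [0, 1, 0, 1, ..., 6, 0, 1, ..., 6].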
def get_position_embedding(
self,
key: Literal["q", "kv"],
num_queries: int = 0,
) -> tuple[torch.Tensor, torch.Tensor] | None:
if self.position_embeds is None:
return None
cos, sin = self.position_embeds
bls = self.full_batch_lengths
# For Q, we only want the text-only posembeds
if key == "q" and self.use_asymetric_qkv:
bls = self.batch_lengths
cos, sin = cos[:, ~self.image_tokens_mask[:, 0]], sin[:, ~self.image_tokens_mask[:, 0]]
elif key not in {"q", "kv"}:
raise ValueError(f"Unknow for position embedding {key}")
# Easy case: training or first step at inference: we use all the posembeds
if num_queries == 0:
return cos, sin
# If num queries is given, we need to trim for *every sample in the batch*
cos = [x[:, -num_queries:] for x in torch.split(cos, bls, dim=1)]
sin = [x[:, -num_queries:] for x in torch.split(sin, bls, dim=1)]
return torch.cat(cos, dim=1), torch.cat(sin, dim=1)
def get_full_embeds(
self, hidden_states: torch.Tensor, norm_fn: Callable | None
) -> torch.Tensor:
"""Update attended hidden states in the current query buffer
:param hidden_states: (b, s, d) Tensor input to the CASA attention layer"
"""
assert self.image_embeds is not None
return insert_image_tokens(
inputs_embeds=hidden_states,
image_embeds=self.image_embeds
if norm_fn is None
else norm_fn(self.image_embeds)
if isinstance(self.image_embeds, torch.Tensor)
else [norm_fn(_x) for _x in self.image_embeds],
image_embeds_insertion_points=self.image_embeds_insertion_points,
attention_mask=self.attention_mask,
recover_batch_dim=False,
keep_only_attended=self.attention_mask is not None,
)[0][None, :, :]
def recover_text_embeds(
self,
hidden_states_out: torch.Tensor,
hidden_states_in: torch.Tensor,
update_image_embeddings: bool = False,
) -> torch.Tensor:
"""Returns text embeddings from the query buffer, including non-attended tokens at inference"""
        if update_image_embeddings and not self.use_asymetric_qkv:
            raise NotImplementedError("Implement image embedding updates for asymmetric QKV")
        # Remove image tokens in the symmetric case
        if not self.use_asymetric_qkv:
            hidden_states_out = hidden_states_out[~self.image_tokens_mask[:, 0]]
        # If there's no attention mask, we are in the right-padded case
        # (keep_only_attended = False) and we can directly return the query
        # outputs (which don't contain the image tokens)
if self.attention_mask is None:
return hidden_states_out
# Otherwise, we need to "scatter" back only the text-attended tokens to the original
# hidden states, which contain the paddings
num_queries = hidden_states_in.shape[1]
        # Case 1: the padded hidden_states_in is larger than hidden_states_out;
        # we re-batch and pad hidden_states_out before doing the scattering
if hidden_states_out.shape[0] != hidden_states_in.shape[0] * hidden_states_in.shape[1]:
s = torch.split(hidden_states_out, self.batch_lengths, dim=0)
assert max(_s.shape[0] for _s in s) <= num_queries # sanity check
s = [
torch.nn.functional.pad(_s, (0, 0, num_queries - _s.shape[0], 0), value=0)
for _s in s
]
return torch.where(
self.attention_mask[:, -num_queries:, None],
torch.stack(s),
hidden_states_in,
)
        # If both have the same shape, it means hidden_states_in contained no padding,
        # so we can directly return hidden_states_out
return hidden_states_out
def extend(self, num_tokens: int, offset: int = 0):
"""Extend all necessary values of the Handler for infenrece
Note: this implementation curently assumes a single conversation at a time
(otherwise image tokens mask would have to change) and that tokens added are
attended to"""
# image embeds is inserted in the first step and stored in the KV cache
self.image_embeds = None
# Update attention mask (non-flattened) (assumes all new tokens are attended to)
if self.attention_mask is not None:
self.attention_mask = torch.nn.functional.pad(
self.attention_mask, (0, num_tokens), value=1
)
# Update image token mask (assumes only one image/conversation
# is started at once so that we always extend by zero)
# Note that the mask is stored flattened to avoid padding so we have to
# do something a bit ugly and inefficient here
imtokmask = torch.split(self.image_tokens_mask, self.full_batch_lengths, dim=0)
imtokmask = [torch.nn.functional.pad(x, (0, 0, 0, num_tokens), value=0) for x in imtokmask]
self.image_tokens_mask = torch.cat(imtokmask, dim=0)
# Recompute cumulative document lengths after assigning the new
# number of tokens to each sample in the batch
for idx, (ln, is_eob) in enumerate(self.text_sample_lengths):
if is_eob:
self.text_sample_lengths[idx] = (num_tokens + ln, is_eob)
self.full_sample_lengths[idx] += num_tokens
        # Recompute cu_seqlens
        # First step: technically this never occurs, but we keep it for completeness
if offset == 0:
self.max_seqlen_q = max(self.text_sample_lengths)[0]
self.cu_seqlens_q = self.get_cu_seqlens(
[x[0] for x in self.text_sample_lengths], device=self.cu_seqlens_q.device
)
self.max_seqlen_kv = max(self.full_sample_lengths)
self.cu_seqlens_kv = self.get_cu_seqlens(
self.full_sample_lengths, device=self.cu_seqlens_kv.device
)
        # Step > 0: the annoying part is that, since flash_attn_varlen does not accept
        # 0-length documents, we need to remove documents from the KV cache once they're past
        # their windows. In our current setting, this means we only want to keep the latest
        # documents
else:
self.max_seqlen_q = num_tokens
self.cu_seqlens_q = self.get_cu_seqlens(
[num_tokens for (_, eob) in self.text_sample_lengths if eob],
device=self.cu_seqlens_q.device,
)
final_doc_lengths = [
ln
for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths)
if eob
]
self.current_doc_lengths = final_doc_lengths
self.max_seqlen_kv = max(self.current_doc_lengths)
self.cu_seqlens_kv = self.get_cu_seqlens(
final_doc_lengths,
device=self.cu_seqlens_kv.device,
)
# Update position embeddings
if self.rope_fn is not None and self.position_embeds is not None:
self.position_embeds = self.compute_position_embeddings(
self.rope_fn,
self.full_sample_lengths,
dummy_for_dtype_and_device=self.position_embeds[0],
)
@dataclass
class CASAAttentionStreamingState(StreamingState):
"""Streaming State for CASA Atention module. Keep the hidden"""
k: torch.Tensor = None # pyright: ignore[reportAssignmentType]
v: torch.Tensor = None # pyright: ignore[reportAssignmentType]
recover_batched_trims: list[int] = None # pyright: ignore[reportAssignmentType]
casa_handler: CASAAttentionHandler = None # pyright: ignore[reportAssignmentType]
def maybe_get_casa_handler(
self,
casa_handler: CASAAttentionHandler | None,
is_first_casa_layer: bool = False,
num_queries: int = -1,
) -> CASAAttentionHandler | None:
# Set given Casa Handler the first time we reach this
if self.casa_handler is None:
self.casa_handler = casa_handler # pyright: ignore
        # Subsequent calls: we need to extend shapes to accommodate the new tokens;
        # however, because the CASA handler is shared across layers, we only need to do it once
if self.casa_handler is not None and self.offset > 0 and is_first_casa_layer:
# since CasaHandler is shared, we only use its extend step once
self.casa_handler.extend(num_queries, offset=self.offset)
return self.casa_handler
def __recover_batched_kv__(self, states: torch.Tensor) -> torch.Tensor:
"""Recover batched key/value states with left padding"""
s = torch.split(states, self.casa_handler.full_batch_lengths, dim=1)
mlen = max(_s.shape[1] for _s in s)
# Remember the added padding so that we can re-flatten KV later
if self.recover_batched_trims is None:
self.recover_batched_trims = [mlen - _s.shape[1] for _s in s]
s = [torch.nn.functional.pad(_s, (0, 0, 0, 0, mlen - _s.shape[1], 0), value=0) for _s in s]
return torch.cat(s, dim=0)
def __get_flattened_kv__(
self, k: torch.Tensor | None = None, v: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Flattened and remove padding to act with flash_attn_func
"""
k = self.k if k is None else k
v = self.v if v is None else v
assert k is not None and v is not None
        # Since every sample in the batch contributes at least one document,
        # we can use this to check whether we are in streaming mode with dropped docs.
        # If so, we should trim the KV cache accordingly
if len(self.casa_handler.current_doc_lengths) == len(k):
k = torch.cat(
[
_k[self.recover_batched_trims[idx] :][-doc_len:]
for idx, _k, doc_len in zip(
range(len(k)), k, self.casa_handler.current_doc_lengths
)
]
)
v = torch.cat(
[
_v[self.recover_batched_trims[idx] :][-doc_len:]
for idx, _v, doc_len in zip(
range(len(k)), v, self.casa_handler.current_doc_lengths
)
]
)
return k[None, ...], v[None, ...]
k = torch.cat([_k[self.recover_batched_trims[idx] :] for idx, _k in enumerate(k)])
v = torch.cat([_v[self.recover_batched_trims[idx] :] for idx, _v in enumerate(v)])
return k[None, ...], v[None, ...]
def extend_kv(
self, key_states: torch.Tensor, value_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Extend KV Cache while keep
"""
assert self.casa_handler is not None
if self.k is None and self.v is None:
# Init with batch-padded key and value states
self.k = self.__recover_batched_kv__(key_states)
self.v = self.__recover_batched_kv__(value_states)
return self.__get_flattened_kv__()
if self.k is not None and self.v is not None:
# this is during generation; normally there is no padding at this stage
# so we can directly reshape the flattened key states
rshp = (self.k.shape[0], -1, self.k.shape[2], self.k.shape[3])
self.k = torch.cat([self.k, key_states.reshape(rshp)], dim=1)
self.v = torch.cat([self.v, value_states.reshape(rshp)], dim=1)
return self.__get_flattened_kv__()
raise ValueError("Impossible configuration (k and v updates are desynchronized )")
class CASAAttention(StreamingModule[CASAAttentionStreamingState]):
def __init__(
self,
config: "PretrainedConfig",
layer_idx: int | None,
self_attn: torch.nn.Module | None = None,
input_layernorm_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
):
super().__init__(CASAAttentionStreamingState)
self.head_dim = config.head_dim
self.config = config
self.is_first_casa_layer = layer_idx == (min(config.xa_layers) if config.xa_layers else 0)
self.use_delta_w = config.casa_delta_w
self.q_proj_casa = self.init_from_config_proj("q", config)
self.k_proj_casa = self.init_from_config_proj("k", config)
self.v_proj_casa = self.init_from_config_proj("v", config)
self.o_proj_casa = self.init_from_config_proj("o", config)
# Delta_w
self.override_q_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
self.override_k_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
self.override_v_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
self.override_o_proj: Callable[[torch.Tensor], torch.Tensor] | None = None
if config.casa_delta_w:
assert self_attn is not None
self.set_delta_w(self_attn)
# Layer norm
self.norm_fn: Callable | None = None
if config.xa_norm_on_images:
assert input_layernorm_fn is not None
self.norm_fn = input_layernorm_fn
def init_from_mha(self, self_attn: torch.nn.Module):
assert self_attn is not None
with torch.no_grad():
assert hasattr(self_attn, "q_proj")
for key in ["q", "k", "v", "o"]:
src = type_cast(torch.nn.Linear, getattr(self_attn, f"{key}_proj"))
tgt = type_cast(torch.nn.Linear, getattr(self, f"{key}_proj_casa"))
tgt.weight.copy_(src.weight)
if tgt.bias is not None and src.bias is not None:
tgt.bias.copy_(src.bias)
def set_delta_w(self, self_attn: torch.nn.Module):
"""Delta w setup"""
self.override_q_proj = delta_w_factory(
self.q_proj_casa, type_cast(torch.nn.Linear, self_attn.q_proj)
)
self.override_k_proj = delta_w_factory(
self.k_proj_casa, type_cast(torch.nn.Linear, self_attn.k_proj)
)
self.override_v_proj = delta_w_factory(
self.v_proj_casa, type_cast(torch.nn.Linear, self_attn.v_proj)
)
self.override_o_proj = delta_w_factory(
self.o_proj_casa, type_cast(torch.nn.Linear, self_attn.o_proj)
)
with torch.no_grad():
torch.nn.init.zeros_(self.q_proj_casa.weight)
torch.nn.init.zeros_(self.k_proj_casa.weight)
torch.nn.init.zeros_(self.v_proj_casa.weight)
torch.nn.init.zeros_(self.o_proj_casa.weight)
if self.q_proj_casa.bias is not None:
torch.nn.init.zeros_(self.q_proj_casa.bias)
if self.k_proj_casa.bias is not None:
torch.nn.init.zeros_(self.k_proj_casa.bias)
if self.v_proj_casa.bias is not None:
torch.nn.init.zeros_(self.v_proj_casa.bias)
if self.o_proj_casa.bias is not None:
torch.nn.init.zeros_(self.o_proj_casa.bias)
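    # `delta_w_factory` comes from `.utils` (not shown here); presumably it wraps the
    # zero-initialized CASA projection as an additive delta on top of the frozen base
    # projection, along the lines of this hypothetical sketch:
    #
    #   def delta_w_factory(delta: torch.nn.Linear, base: torch.nn.Linear):
    #       def proj(x: torch.Tensor) -> torch.Tensor:
    #           # frozen base projection plus a learned residual
    #           return base(x) + delta(x)
    #       return proj
    #
    # Under that assumption, the zero initialization above makes the CASA projections
    # start out identical to the base self-attention projections.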
def init_from_config_proj(
self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig
) -> torch.nn.Linear:
"""Initialize the Linear proj in this module"""
raise NotImplementedError("Abastract class.")
def apply_position_embeddings(
self,
key: Literal["q", "kv"],
x: torch.Tensor, # (batch, seq_len, num_heads, head_dim)
casa_handler: CASAAttentionHandler | None,
num_queries: int = 0,
unsqueeze_dim: int = 1,
) -> torch.Tensor: # (batch, seq_len, num_heads, head_dim)
"""Apply position embeddings to query and key states"""
raise NotImplementedError("Abastract class.")
def forward(
self,
hidden_states: torch.Tensor,
casa_handler: CASAAttentionHandler | None,
) -> torch.Tensor | None:
"""Generic forward for CASA uses for instance in `helium1_attention`"""
og_dtype = hidden_states.dtype
if self.is_streaming:
casa_handler = self.streaming_state.maybe_get_casa_handler(
casa_handler,
is_first_casa_layer=self.is_first_casa_layer,
num_queries=hidden_states.shape[1],
)
# Case of text-only samples at training (or inference when no handler was cached)
# in this case we just skip CASA so we return None (no casa_update)
if casa_handler is None:
return None
if self.is_streaming:
assert casa_handler.use_asymetric_qkv, (
"You should set `use_asymetric_qkv` to True during inference"
)
og_shape = hidden_states.shape
# Build Q inputs
if casa_handler.use_asymetric_qkv:
q_inputs = hidden_states.flatten(0, 1)[None, ...]
if casa_handler.attention_mask is not None:
q_inputs = q_inputs[:, casa_handler.attention_mask[:, -og_shape[1] :].flatten()]
else:
q_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
# Case 1: Training or first inference step
if not self.is_streaming or self.streaming_state.offset == 0:
kv_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn)
else:
            # During streaming, the cached keys/values (which already include the image
            # embeddings) are merged in later, so for now we only process the incoming queries
kv_inputs = q_inputs
# Compute QKV for the blockwise attention
bs, total_seq_len = kv_inputs.shape[:2]
hidden_shape_q = (bs, q_inputs.shape[1], -1, self.head_dim)
hidden_shape_kv = (bs, total_seq_len, -1, self.head_dim)
if self.override_q_proj is None:
query_states = self.q_proj_casa(q_inputs).view(*hidden_shape_q)
else:
query_states = self.override_q_proj(q_inputs).view(*hidden_shape_q)
if self.override_k_proj is None:
key_states = self.k_proj_casa(kv_inputs).view(*hidden_shape_kv)
else:
key_states = self.override_k_proj(kv_inputs).view(*hidden_shape_kv)
if self.override_v_proj is None:
value_states = self.v_proj_casa(kv_inputs).view(*hidden_shape_kv)
else:
value_states = self.override_v_proj(kv_inputs).view(*hidden_shape_kv)
# Apply position embedding at the right offset
num_queries = 0
        if self.is_streaming and self.streaming_state.offset > 0:
num_queries = og_shape[1]
query_states = self.apply_position_embeddings(
"q", query_states, num_queries=num_queries, casa_handler=casa_handler
)
key_states = self.apply_position_embeddings(
"kv", key_states, num_queries=num_queries, casa_handler=casa_handler
)
assert flash_attn_varlen_func is not None, (
"flash_attention is not installed but required for block-wise attention"
)
        # Flash attention has a different efficient implementation for streaming.
        # In that case, the KV cache has to be batched and has been extended
        # to accommodate the shape of the new updates
if self.is_streaming:
key_states, value_states = self.streaming_state.extend_kv(
key_states=key_states, value_states=value_states
)
if casa_handler.use_asymetric_qkv:
cu_seqlens_q = casa_handler.cu_seqlens_q
max_seqlen_q = casa_handler.max_seqlen_q
else:
cu_seqlens_q = casa_handler.cu_seqlens_kv
max_seqlen_q = casa_handler.max_seqlen_kv
assert cu_seqlens_q[-1] == query_states.shape[1], (
f"{cu_seqlens_q[-1]} != {query_states.shape[1]}"
)
assert casa_handler.cu_seqlens_kv[-1] == key_states.shape[1], (
f"{casa_handler.cu_seqlens_kv[-1]} != {key_states.shape[1]}"
)
        # Block-wise (per-document) attention over the flattened sequences
attn_output: torch.Tensor = flash_attn_varlen_func(
query_states[0].to(torch.bfloat16),
key_states[0].to(torch.bfloat16),
value_states[0].to(torch.bfloat16),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=casa_handler.cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=casa_handler.max_seqlen_kv,
dropout_p=0.0,
# softmax_scale=None, # defaults to 1/sqrt(d)
causal=True,
).to(og_dtype)
attn_output = attn_output.reshape(hidden_shape_q[1], -1).contiguous()
if self.override_o_proj is None:
attn_output = self.o_proj_casa(attn_output)
else:
attn_output = self.override_o_proj(attn_output)
attn_output = casa_handler.recover_text_embeds(
attn_output, hidden_states, update_image_embeddings=self.config.xa_update_image_embeds
)
attn_output = attn_output.reshape(og_shape)
if self.is_streaming:
self.streaming_state.offset += attn_output.shape[1]
return attn_output
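# Note on the block-wise attention above: flash_attn_varlen_func treats the flattened
# inputs as a batch of variable-length sequences delimited by cu_seqlens, attending only
# within each [cu_seqlens[i], cu_seqlens[i + 1]) span, with `causal=True` applied per span.
# With cu_seqlens_q == cu_seqlens_kv == [0, 2, 6, 10], for instance, the three documents
# of lengths 2 / 4 / 4 attend causally to themselves and never to each other.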