|
|
"""CASA layers""" |
|
|
|
|
|
import bisect |
|
|
from dataclasses import dataclass |
|
|
from itertools import accumulate |
|
|
from typing import TYPE_CHECKING, Callable, Literal, Sequence, TypedDict, overload |
|
|
from typing import cast as type_cast |
|
|
|
|
|
import torch |
|
|
from transformers.configuration_utils import PretrainedConfig |
|
|
|
|
|
from .utils import StreamingModule, StreamingState, delta_w_factory |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from transformers.configuration_utils import PretrainedConfig |
|
|
|
|
|
try: |
|
|
from flash_attn import flash_attn_varlen_func |
|
|
except ImportError: |
|
|
flash_attn_varlen_func = None |
|
|
|
|
|
|
|
|
WindowsComputeKwargs = TypedDict(
    "WindowsComputeKwargs",
    {
        "num_post_image_tokens": int,
        "num_pre_image_tokens": int,
        # Optional per-sample lengths of squashed (packed) samples, read via kwargs.get()
        # in get_sample_lengths_from_insertion_points.
        "squashed_samples_lengths": list[list[int]] | None,
    },
    total=False,
)
|
|
|
|
|
|
|
|
def __split_n_merge__( |
|
|
x: torch.Tensor, |
|
|
sample_lengths: list[int], |
|
|
padding_side: Literal["left", "right"] = "right", |
|
|
pad_value: int | float | bool = 0, |
|
|
) -> torch.Tensor: |
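    """Split a flattened (sum(S_i), ...) tensor into per-sample chunks of lengths
    ``sample_lengths`` along dim 0, pad each chunk on the given side up to the
    longest length, and stack them into a (B, max(S_i), ...) batch."""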
|
|
max_sample_length = max(sample_lengths) |
|
|
pad_tuple = tuple(0 for _ in range((x.ndim - 1) * 2)) |
|
|
return torch.stack( |
|
|
[ |
|
|
torch.nn.functional.pad( |
|
|
_x, |
|
|
pad_tuple + (0, max_sample_length - _x.shape[0]) |
|
|
if padding_side == "right" |
|
|
else pad_tuple + (max_sample_length - _x.shape[0], 0), |
|
|
value=pad_value, |
|
|
) |
|
|
for _x in torch.split(x, sample_lengths, dim=0) |
|
|
], |
|
|
dim=0, |
|
|
) |
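

# Illustrative sketch added for documentation only (the helper name below is not part
# of the original module): it shows how __split_n_merge__ turns a flattened,
# variable-length batch back into a padded (B, S_max, D) tensor.
def _split_n_merge_example() -> torch.Tensor:
    x = torch.arange(10, dtype=torch.float32).reshape(5, 2)  # 3 + 2 tokens, dim 2
    batched = __split_n_merge__(x, sample_lengths=[3, 2], padding_side="right")
    assert batched.shape == (2, 3, 2)  # second sample right-padded with zeros
    return batched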
|
|
|
|
|
|
|
|
@overload |
|
|
def insert_image_tokens( |
|
|
inputs_embeds: torch.Tensor, |
|
|
image_embeds: torch.Tensor | Sequence[torch.Tensor], |
|
|
image_embeds_insertion_points: list[torch.Tensor], |
|
|
recover_batch_dim: Literal[True], |
|
|
attention_mask: torch.Tensor | None = None, |
|
|
padding_side: Literal["left", "right"] = "right", |
|
|
keep_only_attended: bool = False, |
|
|
pad_output: int | float | bool = 0.0, |
|
|
) -> tuple[ |
|
|
torch.Tensor, |
|
|
None, |
|
|
torch.Tensor | None, |
|
|
torch.Tensor, |
|
|
]: ... |
|
|
@overload |
|
|
def insert_image_tokens( |
|
|
inputs_embeds: torch.Tensor, |
|
|
image_embeds: torch.Tensor | Sequence[torch.Tensor], |
|
|
image_embeds_insertion_points: list[torch.Tensor], |
|
|
recover_batch_dim: Literal[False], |
|
|
attention_mask: torch.Tensor | None = None, |
|
|
padding_side: Literal["left", "right"] = "right", |
|
|
keep_only_attended: bool = False, |
|
|
pad_output: int | float | bool = 0.0, |
|
|
) -> tuple[ |
|
|
torch.Tensor, |
|
|
list[int], |
|
|
torch.Tensor | None, |
|
|
torch.Tensor, |
|
|
]: ... |
|
|
def insert_image_tokens( |
|
|
inputs_embeds: torch.Tensor, |
|
|
image_embeds: torch.Tensor | Sequence[torch.Tensor], |
|
|
image_embeds_insertion_points: list[torch.Tensor], |
|
|
recover_batch_dim: bool = True, |
|
|
attention_mask: torch.Tensor | None = None, |
|
|
padding_side: Literal["left", "right"] = "right", |
|
|
keep_only_attended: bool = False, |
|
|
pad_output: int | float | bool = 0.0, |
|
|
) -> tuple[ |
|
|
    torch.Tensor,
    list[int] | None,
    torch.Tensor | None,
    torch.Tensor,
|
|
]: |
|
|
""" |
|
|
Insert image embeddings into text embeddings |
|
|
|
|
|
Args: |
|
|
inputs_embeds (torch.Tensor): (B, S, D) input token embeddings. |
|
|
image_embeds (torch.Tensor | list[torch.Tensor]): (N_images, Nt, D) | List[(Nt, D)] image token embeddings. |
|
|
image_embeds_insertion_points (list[torch.Tensor]): Insertion indices. |
|
|
attention_mask (torch.Tensor, optional): (B, S) attention mask. |
|
|
padding_side (Literal["left", "right"]): Padding scheme. Controls behavior for padded images. |
|
|
return_indices (bool): Whether to return gather indices or the fused sequence directly. |
|
|
keep_only_attended: This is only applicable when recover_batch_dim is False; whether to |
|
|
remove any non-attended tokens in the whole array. In this case, the attention |
|
|
mask returned is **still the original one**, so we can remember which indices have been |
|
|
removed |
|
|
Returns: |
|
|
output (torch.Tensor): (B, S + Ni * Nt) gather indices or (B, S + Ni * Nt, D) fused sequence |
|
|
image_embeds (torch.Tensor): (B, Ni * Nt) image embeds, padded and batch if input was a list |
|
|
attention_mask (torch.Tensor): Same shape, 1 for real tokens, 0 for image and text padding. |
|
|
image_tokens_mask (torch.Tensor): (B, S + Ni * Nt, 1), marks image token positions. |
|
|
""" |
|
|
if isinstance(image_embeds, list) and len(image_embeds) == 0: |
|
|
batch_size, text_seq_length, token_dim = inputs_embeds.shape |
|
|
if recover_batch_dim: |
|
|
return ( |
|
|
inputs_embeds, |
|
|
None, |
|
|
attention_mask, |
|
|
torch.zeros((batch_size, text_seq_length, 1), dtype=torch.bool), |
|
|
) |
|
|
else: |
|
|
flattened_seq_length = inputs_embeds.shape[0] * inputs_embeds.shape[1] |
|
|
return ( |
|
|
torch.reshape(inputs_embeds, (flattened_seq_length, inputs_embeds.shape[2])), |
|
|
[text_seq_length] * inputs_embeds.shape[0], |
|
|
attention_mask.flatten() if attention_mask is not None else None, |
|
|
torch.zeros((flattened_seq_length, 1), dtype=torch.bool), |
|
|
) |
|
|
|
|
|
|
|
|
if isinstance(image_embeds, torch.Tensor): |
|
|
assert inputs_embeds.shape[-1] == image_embeds.shape[-1] |
|
|
else: |
|
|
assert all(inputs_embeds.shape[-1] == _x.shape[-1] for _x in image_embeds) |
|
|
|
|
|
batch_size, text_seq_length, token_dim = inputs_embeds.shape |
|
|
image_seq_length = [x.shape[0] for x in image_embeds] |
|
|
|
|
|
|
|
|
insertion_offset = [] |
|
|
counter, offset_from_text, offset_from_image = 0, 0, 0 |
|
|
for sample in image_embeds_insertion_points: |
|
|
for pt in sample: |
|
|
insertion_offset.append(pt + offset_from_image + offset_from_text) |
|
|
offset_from_image += image_seq_length[counter] |
|
|
counter += 1 |
|
|
offset_from_text += text_seq_length |
|
|
image_insert_positions = [ |
|
|
x for idx, pt in enumerate(insertion_offset) for x in range(pt, pt + image_seq_length[idx]) |
|
|
] |
|
|
|
|
|
|
|
|
if isinstance(image_embeds, list): |
|
|
image_embeds = torch.cat(image_embeds, dim=0) |
|
|
else: |
|
|
image_embeds = type_cast(torch.Tensor, image_embeds) |
|
|
image_embeds = torch.reshape(image_embeds, (-1, token_dim)) |
|
|
|
|
|
|
|
|
inputs_embeds = torch.reshape(inputs_embeds, (-1, token_dim)) |
|
|
flattened_seq_length = inputs_embeds.shape[0] + sum(image_seq_length) |
|
|
text_insert_positions = sorted( |
|
|
set(range(flattened_seq_length)).difference(set(image_insert_positions)) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
output = torch.empty( |
|
|
(flattened_seq_length, token_dim), |
|
|
device=inputs_embeds.device, |
|
|
dtype=inputs_embeds.dtype, |
|
|
) |
|
|
txt_positions_tensor = torch.Tensor(text_insert_positions).to( |
|
|
dtype=torch.long, device=inputs_embeds.device |
|
|
) |
|
|
output.scatter_(0, txt_positions_tensor[:, None].expand(-1, token_dim), inputs_embeds) |
|
|
attention_mask_new: torch.Tensor | None = None |
|
|
if attention_mask is not None: |
|
|
attention_mask_new = torch.ones( |
|
|
(flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device |
|
|
) |
|
|
attention_mask_new.scatter_( |
|
|
0, txt_positions_tensor, attention_mask.flatten().to(torch.bool) |
|
|
) |
|
|
|
|
|
|
|
|
image_tokens_mask = torch.zeros( |
|
|
(flattened_seq_length,), dtype=torch.bool, device=inputs_embeds.device |
|
|
) |
|
|
img_positions_tensor = torch.Tensor(image_insert_positions).to( |
|
|
device=inputs_embeds.device, dtype=torch.long |
|
|
) |
|
|
output.scatter_(0, img_positions_tensor[:, None].expand(-1, token_dim), image_embeds) |
|
|
image_tokens_mask.scatter_(0, img_positions_tensor, True) |
|
|
|
|
|
|
|
|
|
|
|
sample_lengths = [] |
|
|
counter = 0 |
|
|
for sample_idx, pts in enumerate(image_embeds_insertion_points): |
|
|
num_image_tokens = 0 |
|
|
for _ in pts: |
|
|
num_image_tokens += image_seq_length[counter] |
|
|
counter += 1 |
|
|
if keep_only_attended and attention_mask is not None: |
|
|
attended_seq_length = torch.sum(attention_mask[sample_idx]).cpu().item() |
|
|
sample_lengths.append(attended_seq_length + num_image_tokens) |
|
|
else: |
|
|
sample_lengths.append(text_seq_length + num_image_tokens) |
|
|
|
|
|
|
|
|
|
|
|
if not recover_batch_dim: |
|
|
if keep_only_attended and attention_mask_new is not None: |
|
|
output = output[attention_mask_new] |
|
|
image_tokens_mask = image_tokens_mask[attention_mask_new] |
|
|
return output, sample_lengths, attention_mask_new, image_tokens_mask[..., None] |
|
|
|
|
|
|
|
|
|
|
|
if all(x == sample_lengths[0] for x in sample_lengths): |
|
|
output = torch.reshape(output, (batch_size, sample_lengths[0], token_dim)) |
|
|
image_tokens_mask = torch.reshape(image_tokens_mask, (batch_size, sample_lengths[0], 1)) |
|
|
if attention_mask_new is not None: |
|
|
attention_mask_new = torch.reshape(attention_mask_new, (batch_size, sample_lengths[0])) |
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
output = __split_n_merge__(output, sample_lengths, padding_side, pad_value=pad_output) |
|
|
|
|
|
image_tokens_mask = __split_n_merge__( |
|
|
image_tokens_mask, sample_lengths, padding_side, True |
|
|
)[:, :, None] |
|
|
if attention_mask_new is not None: |
|
|
attention_mask_new = __split_n_merge__( |
|
|
attention_mask_new, sample_lengths, padding_side, 0 |
|
|
) |
|
|
|
|
|
return output, sample_lengths, attention_mask_new, image_tokens_mask |
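

# Illustrative sketch added for documentation only (the helper name below is not part
# of the original module): a single 4-token text sample with one 2-token image inserted
# at position 1 yields a fused (1, 6, D) sequence, with the image positions flagged in
# image_tokens_mask.
def _insert_image_tokens_example() -> tuple[torch.Tensor, torch.Tensor]:
    inputs_embeds = torch.zeros(1, 4, 8)
    image_embeds = [torch.ones(2, 8)]
    insertion_points = [torch.tensor([1])]
    fused, _, _, image_tokens_mask = insert_image_tokens(
        inputs_embeds,
        image_embeds,
        insertion_points,
        recover_batch_dim=True,
    )
    assert fused.shape == (1, 6, 8) and image_tokens_mask.shape == (1, 6, 1)
    return fused, image_tokens_mask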
|
|
|
|
|
|
|
|
def get_sample_lengths_from_insertion_points( |
|
|
image_embeds_insertion_points: list[torch.Tensor], |
|
|
image_embeds: torch.Tensor | list[torch.Tensor] | None, |
|
|
total_seq_len: int | None = None, |
|
|
attention_mask: torch.Tensor | None = None, |
|
|
**kwargs: WindowsComputeKwargs, |
|
|
) -> tuple[list[tuple[int, bool]], list[int]]: |
|
|
"""Compute sample lengths as if each image insertion point defines a |
|
|
new document (ex document ID) |
|
|
""" |
|
|
num_post_image_tokens = type_cast(int, kwargs.get("num_post_image_tokens", 0)) |
|
|
num_pre_image_tokens = type_cast(int, kwargs.get("num_pre_image_tokens", 0)) |
|
|
squashed_samples_lengths = type_cast( |
|
|
list[list[int]] | None, kwargs.get("squashed_samples_lengths", None) |
|
|
) |
|
|
if squashed_samples_lengths is not None: |
|
|
assert len(squashed_samples_lengths) == len(image_embeds_insertion_points) |
|
|
|
|
|
def __insert_next_sample__( |
|
|
batch_idx: int, insrt_pt: int, last_insrt_pt: int, end_of_batch_sample: bool = False |
|
|
) -> None: |
|
|
nonlocal attention_mask |
|
|
nonlocal text_sample_lengths, full_sample_lengths |
|
|
nonlocal cum_samples_lengths, current_image_offset |
|
|
nonlocal last_image_idx, current_image_idx, current_length |
|
|
|
|
|
|
|
|
start_pt = bisect.bisect_left(cum_samples_lengths, last_insrt_pt) |
|
|
added_sample = False |
|
|
for end_of_sample in cum_samples_lengths[start_pt:]: |
|
|
|
|
|
end_of_sample = min(end_of_sample, insrt_pt) |
|
|
|
|
|
|
|
|
current_length = end_of_sample - last_insrt_pt |
|
|
if attention_mask is not None: |
|
|
current_length -= int( |
|
|
torch.sum(~attention_mask[batch_idx, last_insrt_pt:end_of_sample]).item() |
|
|
) |
|
|
if current_length > 0: |
|
|
added_sample = True |
|
|
text_sample_lengths.append( |
|
|
(current_length, end_of_batch_sample and insrt_pt == end_of_sample) |
|
|
) |
|
|
|
|
|
if current_image_idx > 0 and image_embeds is not None: |
|
|
images_in_sample = [ |
|
|
img_idx |
|
|
for img_idx in range(last_image_idx, current_image_idx) |
|
|
if img_idx < len(image_embeds_insertion_points[batch_idx]) |
|
|
and last_insrt_pt |
|
|
<= image_embeds_insertion_points[batch_idx][img_idx] |
|
|
< end_of_sample |
|
|
] |
|
|
if len(images_in_sample) > 0: |
|
|
num_image_tokens = sum( |
|
|
_x.shape[0] |
|
|
for _x in image_embeds[ |
|
|
current_image_offset + images_in_sample[0] : current_image_offset |
|
|
+ images_in_sample[-1] |
|
|
+ 1 |
|
|
] |
|
|
) |
|
|
current_length += num_image_tokens |
|
|
full_sample_lengths.append(current_length) |
|
|
|
|
|
|
|
|
last_insrt_pt = end_of_sample |
|
|
if end_of_sample == insrt_pt: |
|
|
break |
|
|
|
|
|
|
|
|
        if end_of_batch_sample:
            assert added_sample, (
                "No text tokens were added for the final segment of this batch sample;"
                " this edge case is not supported"
            )
            text_sample_lengths[-1] = (text_sample_lengths[-1][0], True)
|
|
|
|
|
current_image_offset = 0 |
|
|
text_sample_lengths, full_sample_lengths = [], [] |
|
|
cum_samples_lengths: list[int] = [] |
|
|
current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0 |
|
|
for batch_idx, pts in enumerate(image_embeds_insertion_points): |
|
|
if squashed_samples_lengths is not None: |
|
|
cum_samples_lengths = list(accumulate(squashed_samples_lengths[batch_idx])) |
|
|
else: |
|
|
assert total_seq_len is not None |
|
|
cum_samples_lengths = [total_seq_len] |
|
|
|
|
|
for current_image_idx, insrt_pt in enumerate(pts.cpu().tolist()): |
|
|
|
|
|
|
|
|
if current_image_idx >= 1 and insrt_pt == ( |
|
|
image_embeds_insertion_points[batch_idx][current_image_idx - 1] |
|
|
+ num_pre_image_tokens |
|
|
+ num_post_image_tokens |
|
|
): |
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
insrt_pt -= num_pre_image_tokens |
|
|
|
|
|
|
|
|
if insrt_pt > last_insrt_pt: |
|
|
__insert_next_sample__( |
|
|
batch_idx, insrt_pt, last_insrt_pt, end_of_batch_sample=False |
|
|
) |
|
|
last_image_idx = current_image_idx |
|
|
last_insrt_pt = insrt_pt |
|
|
|
|
|
|
|
|
current_image_idx += 1 |
|
|
if cum_samples_lengths[-1] > last_insrt_pt: |
|
|
__insert_next_sample__( |
|
|
batch_idx, cum_samples_lengths[-1], last_insrt_pt, end_of_batch_sample=True |
|
|
) |
|
|
current_length, last_insrt_pt, last_image_idx, current_image_idx = 0, 0, 0, 0 |
|
|
current_image_offset += len(pts) |
|
|
|
|
|
|
|
|
assert sum(_x[1] for _x in text_sample_lengths) == len(image_embeds_insertion_points), ( |
|
|
f"Number of eob markers ({sum(_x[1] for _x in text_sample_lengths)}) differs" |
|
|
f" from original batch size ({len(image_embeds_insertion_points)})" |
|
|
) |
|
|
return text_sample_lengths, full_sample_lengths |
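

# Illustrative sketch added for documentation only (the helper name below is not part
# of the original module): one sample of 6 text tokens with a single 3-token image
# inserted at position 2 is split into a 2-token document before the image and a
# 4-token document containing it (7 tokens once the image tokens are counted).
def _get_sample_lengths_example() -> tuple[list[tuple[int, bool]], list[int]]:
    text_lengths, full_lengths = get_sample_lengths_from_insertion_points(
        image_embeds_insertion_points=[torch.tensor([2])],
        image_embeds=[torch.zeros(3, 8)],
        total_seq_len=6,
    )
    assert text_lengths == [(2, False), (4, True)]
    assert full_lengths == [2, 7]
    return text_lengths, full_lengths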
|
|
|
|
|
|
|
|
class CASAAttentionHandler: |
|
|
def __init__( |
|
|
self, |
|
|
inputs_embeds: torch.Tensor, |
|
|
image_embeds: torch.Tensor | list[torch.Tensor], |
|
|
image_embeds_insertion_points: list[torch.Tensor], |
|
|
attention_mask: torch.Tensor | None = None, |
|
|
rope_fn: Callable | None = None, |
|
|
windows: Literal["batch", "squashed", "images", "turn_based"] = "images", |
|
|
use_asymetric_q_kv: bool = True, |
|
|
casa_windows_info: None | dict = None, |
|
|
): |
|
|
"""Initialize the structure holding the query buffer for CASA attention layers |
|
|
(ie the **flattened** text+image inserted tokens). |
|
|
Note that this structure is shared across all casa layers, and it gets updated |
|
|
with the current hidden states at every layer; this is merely a buffer to keep |
|
|
scatter_ operations in-plae as much as possible |
|
|
|
|
|
In this module, the embeddings related values (image_tokens_mask, |
|
|
text_sample_lengths etc) are stored under the assumption of a tensor |
|
|
which is *flatened* and *witout padding tokens* |
|
|
Only the attention mask is kept as-is (text-only, batched, padded) to |
|
|
be able to recover original shapes when needed |
|
|
""" |
|
|
super().__init__() |
|
|
assert windows == "images" |
|
|
|
|
|
|
|
|
if attention_mask is None: |
|
|
text_sample_lengths = [(_x.shape[0], True) for _x in inputs_embeds] |
|
|
else: |
|
|
text_sample_lengths = [(int(torch.sum(_x).item()), True) for _x in attention_mask] |
|
|
( |
|
|
full_inputs_embeds, |
|
|
full_sample_lengths, |
|
|
|
|
|
|
|
|
_, |
|
|
self.image_tokens_mask, |
|
|
) = insert_image_tokens( |
|
|
inputs_embeds=inputs_embeds, |
|
|
image_embeds=image_embeds, |
|
|
image_embeds_insertion_points=image_embeds_insertion_points, |
|
|
attention_mask=attention_mask, |
|
|
recover_batch_dim=False, |
|
|
keep_only_attended=attention_mask is not None, |
|
|
) |
|
|
assert self.image_tokens_mask.ndim == 2 |
|
|
self.image_embeds = image_embeds |
|
|
self.image_embeds_insertion_points = image_embeds_insertion_points |
|
|
self.attention_mask = None if attention_mask is None else attention_mask.bool() |
|
|
self.use_asymetric_qkv = use_asymetric_q_kv |
|
|
|
|
|
if self.attention_mask is not None: |
|
|
self.use_asymetric_qkv = True |
|
|
|
|
|
|
|
|
assert casa_windows_info is not None |
|
|
text_sample_lengths, full_sample_lengths = get_sample_lengths_from_insertion_points( |
|
|
image_embeds_insertion_points=image_embeds_insertion_points, |
|
|
image_embeds=image_embeds, |
|
|
total_seq_len=inputs_embeds.shape[1], |
|
|
attention_mask=self.attention_mask, |
|
|
**casa_windows_info, |
|
|
) |
|
|
|
|
|
|
|
|
self.text_sample_lengths = [(int(s), eob) for s, eob in text_sample_lengths if s > 0] |
|
|
self.full_sample_lengths = [int(s) for s in full_sample_lengths if s > 0] |
|
|
|
|
|
assert len(self.text_sample_lengths) == len(self.full_sample_lengths), ( |
|
|
f"Sanity check failed; text sample lengths {len(self.text_sample_lengths)}" |
|
|
f" != full sample lengths {len(self.full_sample_lengths)}" |
|
|
) |
|
|
if self.attention_mask is None: |
|
|
num_unpadded_text_tokens = inputs_embeds.shape[0] * inputs_embeds.shape[1] |
|
|
else: |
|
|
num_unpadded_text_tokens = int( |
|
|
torch.sum(type_cast(torch.Tensor, attention_mask)).item() |
|
|
) |
|
|
        assert sum(_x[0] for _x in self.text_sample_lengths) == num_unpadded_text_tokens, (
            f"Sanity check failed; text sample lengths {sum(_x[0] for _x in self.text_sample_lengths)}"
            f" != number of unpadded text tokens {num_unpadded_text_tokens}"
        )
|
|
assert sum(self.full_sample_lengths) == full_inputs_embeds.shape[0], ( |
|
|
f"Sanity check failed; sample lengths {sum(self.full_sample_lengths)} != {full_inputs_embeds.shape[0]}" |
|
|
) |
|
|
|
|
|
|
|
|
self.max_seqlen_q = max(self.text_sample_lengths)[0] |
|
|
self.cu_seqlens_q = self.get_cu_seqlens( |
|
|
[x[0] for x in self.text_sample_lengths], device=inputs_embeds.device |
|
|
) |
|
|
|
|
|
self.max_seqlen_kv = max(self.full_sample_lengths) |
|
|
self.cu_seqlens_kv = self.get_cu_seqlens( |
|
|
self.full_sample_lengths, device=inputs_embeds.device |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
self.current_doc_lengths = self.full_sample_lengths |
|
|
|
|
|
|
|
|
self.position_embeds = None |
|
|
self.rope_fn = rope_fn |
|
|
if self.rope_fn is not None: |
|
|
self.position_embeds = self.compute_position_embeddings( |
|
|
self.rope_fn, full_sample_lengths, dummy_for_dtype_and_device=full_inputs_embeds |
|
|
) |
|
|
|
|
|
@property |
|
|
def batch_lengths(self) -> list[int]: |
|
|
"""Return a (batch_size,) list of integers containing the |
|
|
number of (non-padded) text tokens for each sample in the batch""" |
|
|
bls = [0] |
|
|
for ln, eob in self.text_sample_lengths: |
|
|
bls[-1] += ln |
|
|
if eob: |
|
|
bls.append(0) |
|
|
return bls[:-1] |
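
    # For example (illustrative): text_sample_lengths == [(2, False), (4, True), (5, True)]
    # gives batch_lengths == [6, 5]; window lengths are summed until an end-of-batch marker.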
|
|
|
|
|
@property |
|
|
def full_batch_lengths(self) -> list[int]: |
|
|
"""Same as batch_lengths for text+image tokens""" |
|
|
bls = [0] |
|
|
for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths): |
|
|
bls[-1] += ln |
|
|
if eob: |
|
|
bls.append(0) |
|
|
return bls[:-1] |
|
|
|
|
|
def get_cu_seqlens( |
|
|
self, sample_lengths: list[int], device: torch.device | None |
|
|
) -> torch.Tensor: |
|
|
"""Update cu_seqlengths according to the given sample_lengths""" |
|
|
return torch.Tensor(list(accumulate(sample_lengths, initial=0))).to( |
|
|
dtype=torch.int32, device=device |
|
|
) |
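
    # For example (illustrative): sample_lengths == [2, 7] gives cu_seqlens == [0, 2, 9],
    # the cumulative window boundaries expected by flash_attn_varlen_func.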
|
|
|
|
|
def compute_position_embeddings( |
|
|
self, |
|
|
rope_fn: Callable, |
|
|
sample_lengths: list[int], |
|
|
dummy_for_dtype_and_device: torch.Tensor, |
|
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Compute info required for position embeddings. Can be override e.g. for Qwen""" |
|
|
|
|
|
|
|
|
|
|
|
position_ids = torch.cat([torch.arange(0, lg) for lg in sample_lengths], dim=0) |
|
|
return rope_fn( |
|
|
dummy_for_dtype_and_device, |
|
|
position_ids.to(dummy_for_dtype_and_device.device)[None, ...], |
|
|
) |
|
|
|
|
|
def get_position_embedding( |
|
|
self, |
|
|
key: Literal["q", "kv"], |
|
|
num_queries: int = 0, |
|
|
) -> tuple[torch.Tensor, torch.Tensor] | None: |
|
|
if self.position_embeds is None: |
|
|
return None |
|
|
cos, sin = self.position_embeds |
|
|
bls = self.full_batch_lengths |
|
|
|
|
|
if key == "q" and self.use_asymetric_qkv: |
|
|
bls = self.batch_lengths |
|
|
cos, sin = cos[:, ~self.image_tokens_mask[:, 0]], sin[:, ~self.image_tokens_mask[:, 0]] |
|
|
elif key not in {"q", "kv"}: |
|
|
raise ValueError(f"Unknow for position embedding {key}") |
|
|
|
|
|
|
|
|
if num_queries == 0: |
|
|
return cos, sin |
|
|
|
|
|
cos = [x[:, -num_queries:] for x in torch.split(cos, bls, dim=1)] |
|
|
sin = [x[:, -num_queries:] for x in torch.split(sin, bls, dim=1)] |
|
|
return torch.cat(cos, dim=1), torch.cat(sin, dim=1) |
|
|
|
|
|
def get_full_embeds( |
|
|
self, hidden_states: torch.Tensor, norm_fn: Callable | None |
|
|
) -> torch.Tensor: |
|
|
"""Update attended hidden states in the current query buffer |
|
|
|
|
|
:param hidden_states: (b, s, d) Tensor input to the CASA attention layer" |
|
|
""" |
|
|
assert self.image_embeds is not None |
|
|
return insert_image_tokens( |
|
|
inputs_embeds=hidden_states, |
|
|
image_embeds=self.image_embeds |
|
|
if norm_fn is None |
|
|
else norm_fn(self.image_embeds) |
|
|
if isinstance(self.image_embeds, torch.Tensor) |
|
|
else [norm_fn(_x) for _x in self.image_embeds], |
|
|
image_embeds_insertion_points=self.image_embeds_insertion_points, |
|
|
attention_mask=self.attention_mask, |
|
|
recover_batch_dim=False, |
|
|
keep_only_attended=self.attention_mask is not None, |
|
|
)[0][None, :, :] |
|
|
|
|
|
def recover_text_embeds( |
|
|
self, |
|
|
hidden_states_out: torch.Tensor, |
|
|
hidden_states_in: torch.Tensor, |
|
|
update_image_embeddings: bool = False, |
|
|
) -> torch.Tensor: |
|
|
"""Returns text embeddings from the query buffer, including non-attended tokens at inference""" |
|
|
if update_image_embeddings and not self.use_asymetric_qkv: |
|
|
raise NotImplementedError("Implement image embeddings updates for asymetric QKV") |
|
|
|
|
|
if not self.use_asymetric_qkv: |
|
|
hidden_states_out = hidden_states_out[~self.image_tokens_mask[:, 0]] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.attention_mask is None: |
|
|
return hidden_states_out |
|
|
|
|
|
|
|
|
|
|
|
num_queries = hidden_states_in.shape[1] |
|
|
|
|
|
|
|
|
|
|
|
if hidden_states_out.shape[0] != hidden_states_in.shape[0] * hidden_states_in.shape[1]: |
|
|
s = torch.split(hidden_states_out, self.batch_lengths, dim=0) |
|
|
assert max(_s.shape[0] for _s in s) <= num_queries |
|
|
s = [ |
|
|
torch.nn.functional.pad(_s, (0, 0, num_queries - _s.shape[0], 0), value=0) |
|
|
for _s in s |
|
|
] |
|
|
return torch.where( |
|
|
self.attention_mask[:, -num_queries:, None], |
|
|
torch.stack(s), |
|
|
hidden_states_in, |
|
|
) |
|
|
|
|
|
|
|
|
return hidden_states_out |
|
|
|
|
|
def extend(self, num_tokens: int, offset: int = 0): |
|
|
"""Extend all necessary values of the Handler for infenrece |
|
|
Note: this implementation curently assumes a single conversation at a time |
|
|
(otherwise image tokens mask would have to change) and that tokens added are |
|
|
attended to""" |
|
|
|
|
|
self.image_embeds = None |
|
|
|
|
|
|
|
|
if self.attention_mask is not None: |
|
|
self.attention_mask = torch.nn.functional.pad( |
|
|
self.attention_mask, (0, num_tokens), value=1 |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
imtokmask = torch.split(self.image_tokens_mask, self.full_batch_lengths, dim=0) |
|
|
imtokmask = [torch.nn.functional.pad(x, (0, 0, 0, num_tokens), value=0) for x in imtokmask] |
|
|
self.image_tokens_mask = torch.cat(imtokmask, dim=0) |
|
|
|
|
|
|
|
|
|
|
|
for idx, (ln, is_eob) in enumerate(self.text_sample_lengths): |
|
|
if is_eob: |
|
|
self.text_sample_lengths[idx] = (num_tokens + ln, is_eob) |
|
|
self.full_sample_lengths[idx] += num_tokens |
|
|
|
|
|
|
|
|
|
|
|
if offset == 0: |
|
|
self.max_seqlen_q = max(self.text_sample_lengths)[0] |
|
|
self.cu_seqlens_q = self.get_cu_seqlens( |
|
|
[x[0] for x in self.text_sample_lengths], device=self.cu_seqlens_q.device |
|
|
) |
|
|
|
|
|
self.max_seqlen_kv = max(self.full_sample_lengths) |
|
|
self.cu_seqlens_kv = self.get_cu_seqlens( |
|
|
self.full_sample_lengths, device=self.cu_seqlens_kv.device |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
self.max_seqlen_q = num_tokens |
|
|
self.cu_seqlens_q = self.get_cu_seqlens( |
|
|
[num_tokens for (_, eob) in self.text_sample_lengths if eob], |
|
|
device=self.cu_seqlens_q.device, |
|
|
) |
|
|
|
|
|
final_doc_lengths = [ |
|
|
ln |
|
|
for (_, eob), ln in zip(self.text_sample_lengths, self.full_sample_lengths) |
|
|
if eob |
|
|
] |
|
|
self.current_doc_lengths = final_doc_lengths |
|
|
self.max_seqlen_kv = max(self.current_doc_lengths) |
|
|
self.cu_seqlens_kv = self.get_cu_seqlens( |
|
|
final_doc_lengths, |
|
|
device=self.cu_seqlens_kv.device, |
|
|
) |
|
|
|
|
|
if self.rope_fn is not None and self.position_embeds is not None: |
|
|
self.position_embeds = self.compute_position_embeddings( |
|
|
self.rope_fn, |
|
|
self.full_sample_lengths, |
|
|
dummy_for_dtype_and_device=self.position_embeds[0], |
|
|
) |
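

# Illustrative sketch added for documentation only (the helper name below is not part
# of the original module): building a handler for a single 6-token sample with one
# 3-token image inserted at position 2. Passing casa_windows_info={} simply means no
# pre/post image tokens and no squashed samples.
def _casa_handler_example() -> CASAAttentionHandler:
    handler = CASAAttentionHandler(
        inputs_embeds=torch.zeros(1, 6, 8),
        image_embeds=[torch.zeros(3, 8)],
        image_embeds_insertion_points=[torch.tensor([2])],
        attention_mask=None,
        rope_fn=None,
        casa_windows_info={},
    )
    # Two "documents": 2 text tokens before the image, 4 after (7 once image tokens are added).
    assert handler.cu_seqlens_q.tolist() == [0, 2, 6]
    assert handler.cu_seqlens_kv.tolist() == [0, 2, 9]
    return handler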
|
|
|
|
|
|
|
|
@dataclass |
|
|
class CASAAttentionStreamingState(StreamingState): |
|
|
"""Streaming State for CASA Atention module. Keep the hidden""" |
|
|
|
|
|
    k: torch.Tensor | None = None
    v: torch.Tensor | None = None
    recover_batched_trims: list[int] | None = None
    casa_handler: CASAAttentionHandler | None = None
|
|
|
|
|
def maybe_get_casa_handler( |
|
|
self, |
|
|
casa_handler: CASAAttentionHandler | None, |
|
|
is_first_casa_layer: bool = False, |
|
|
num_queries: int = -1, |
|
|
) -> CASAAttentionHandler | None: |
|
|
|
|
|
if self.casa_handler is None: |
|
|
self.casa_handler = casa_handler |
|
|
|
|
|
|
|
|
if self.casa_handler is not None and self.offset > 0 and is_first_casa_layer: |
|
|
|
|
|
self.casa_handler.extend(num_queries, offset=self.offset) |
|
|
return self.casa_handler |
|
|
|
|
|
def __recover_batched_kv__(self, states: torch.Tensor) -> torch.Tensor: |
|
|
"""Recover batched key/value states with left padding""" |
|
|
s = torch.split(states, self.casa_handler.full_batch_lengths, dim=1) |
|
|
mlen = max(_s.shape[1] for _s in s) |
|
|
|
|
|
if self.recover_batched_trims is None: |
|
|
self.recover_batched_trims = [mlen - _s.shape[1] for _s in s] |
|
|
s = [torch.nn.functional.pad(_s, (0, 0, 0, 0, mlen - _s.shape[1], 0), value=0) for _s in s] |
|
|
return torch.cat(s, dim=0) |
|
|
|
|
|
def __get_flattened_kv__( |
|
|
self, k: torch.Tensor | None = None, v: torch.Tensor | None = None |
|
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Flattened and remove padding to act with flash_attn_func |
|
|
""" |
|
|
k = self.k if k is None else k |
|
|
v = self.v if v is None else v |
|
|
assert k is not None and v is not None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(self.casa_handler.current_doc_lengths) == len(k): |
|
|
k = torch.cat( |
|
|
[ |
|
|
_k[self.recover_batched_trims[idx] :][-doc_len:] |
|
|
for idx, _k, doc_len in zip( |
|
|
range(len(k)), k, self.casa_handler.current_doc_lengths |
|
|
) |
|
|
] |
|
|
) |
|
|
v = torch.cat( |
|
|
[ |
|
|
_v[self.recover_batched_trims[idx] :][-doc_len:] |
|
|
for idx, _v, doc_len in zip( |
|
|
range(len(k)), v, self.casa_handler.current_doc_lengths |
|
|
) |
|
|
] |
|
|
) |
|
|
return k[None, ...], v[None, ...] |
|
|
|
|
|
k = torch.cat([_k[self.recover_batched_trims[idx] :] for idx, _k in enumerate(k)]) |
|
|
v = torch.cat([_v[self.recover_batched_trims[idx] :] for idx, _v in enumerate(v)]) |
|
|
return k[None, ...], v[None, ...] |
|
|
|
|
|
def extend_kv( |
|
|
self, key_states: torch.Tensor, value_states: torch.Tensor |
|
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Extend KV Cache while keep |
|
|
""" |
|
|
assert self.casa_handler is not None |
|
|
if self.k is None and self.v is None: |
|
|
|
|
|
self.k = self.__recover_batched_kv__(key_states) |
|
|
self.v = self.__recover_batched_kv__(value_states) |
|
|
return self.__get_flattened_kv__() |
|
|
if self.k is not None and self.v is not None: |
|
|
|
|
|
|
|
|
rshp = (self.k.shape[0], -1, self.k.shape[2], self.k.shape[3]) |
|
|
self.k = torch.cat([self.k, key_states.reshape(rshp)], dim=1) |
|
|
self.v = torch.cat([self.v, value_states.reshape(rshp)], dim=1) |
|
|
return self.__get_flattened_kv__() |
|
|
|
|
|
raise ValueError("Impossible configuration (k and v updates are desynchronized )") |
|
|
|
|
|
|
|
|
class CASAAttention(StreamingModule[CASAAttentionStreamingState]): |
|
|
def __init__( |
|
|
self, |
|
|
config: "PretrainedConfig", |
|
|
layer_idx: int | None, |
|
|
self_attn: torch.nn.Module | None = None, |
|
|
input_layernorm_fn: Callable[[torch.Tensor], torch.Tensor] | None = None, |
|
|
): |
|
|
super().__init__(CASAAttentionStreamingState) |
|
|
self.head_dim = config.head_dim |
|
|
self.config = config |
|
|
|
|
|
self.is_first_casa_layer = layer_idx == (min(config.xa_layers) if config.xa_layers else 0) |
|
|
self.use_delta_w = config.casa_delta_w |
|
|
|
|
|
self.q_proj_casa = self.init_from_config_proj("q", config) |
|
|
self.k_proj_casa = self.init_from_config_proj("k", config) |
|
|
self.v_proj_casa = self.init_from_config_proj("v", config) |
|
|
self.o_proj_casa = self.init_from_config_proj("o", config) |
|
|
|
|
|
|
|
|
self.override_q_proj: Callable[[torch.Tensor], torch.Tensor] | None = None |
|
|
self.override_k_proj: Callable[[torch.Tensor], torch.Tensor] | None = None |
|
|
self.override_v_proj: Callable[[torch.Tensor], torch.Tensor] | None = None |
|
|
self.override_o_proj: Callable[[torch.Tensor], torch.Tensor] | None = None |
|
|
|
|
|
if config.casa_delta_w: |
|
|
assert self_attn is not None |
|
|
self.set_delta_w(self_attn) |
|
|
|
|
|
|
|
|
self.norm_fn: Callable | None = None |
|
|
if config.xa_norm_on_images: |
|
|
assert input_layernorm_fn is not None |
|
|
self.norm_fn = input_layernorm_fn |
|
|
|
|
|
def init_from_mha(self, self_attn: torch.nn.Module): |
|
|
assert self_attn is not None |
|
|
with torch.no_grad(): |
|
|
assert hasattr(self_attn, "q_proj") |
|
|
for key in ["q", "k", "v", "o"]: |
|
|
src = type_cast(torch.nn.Linear, getattr(self_attn, f"{key}_proj")) |
|
|
tgt = type_cast(torch.nn.Linear, getattr(self, f"{key}_proj_casa")) |
|
|
tgt.weight.copy_(src.weight) |
|
|
if tgt.bias is not None and src.bias is not None: |
|
|
tgt.bias.copy_(src.bias) |
|
|
|
|
|
def set_delta_w(self, self_attn: torch.nn.Module): |
|
|
"""Delta w setup""" |
|
|
self.override_q_proj = delta_w_factory( |
|
|
self.q_proj_casa, type_cast(torch.nn.Linear, self_attn.q_proj) |
|
|
) |
|
|
self.override_k_proj = delta_w_factory( |
|
|
self.k_proj_casa, type_cast(torch.nn.Linear, self_attn.k_proj) |
|
|
) |
|
|
self.override_v_proj = delta_w_factory( |
|
|
self.v_proj_casa, type_cast(torch.nn.Linear, self_attn.v_proj) |
|
|
) |
|
|
self.override_o_proj = delta_w_factory( |
|
|
self.o_proj_casa, type_cast(torch.nn.Linear, self_attn.o_proj) |
|
|
) |
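
        # The CASA projections are zero-initialized below so that, at initialization,
        # the delta-W overrides reduce to the base self-attention projections.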
|
|
|
|
|
with torch.no_grad(): |
|
|
torch.nn.init.zeros_(self.q_proj_casa.weight) |
|
|
torch.nn.init.zeros_(self.k_proj_casa.weight) |
|
|
torch.nn.init.zeros_(self.v_proj_casa.weight) |
|
|
torch.nn.init.zeros_(self.o_proj_casa.weight) |
|
|
if self.q_proj_casa.bias is not None: |
|
|
torch.nn.init.zeros_(self.q_proj_casa.bias) |
|
|
if self.k_proj_casa.bias is not None: |
|
|
torch.nn.init.zeros_(self.k_proj_casa.bias) |
|
|
if self.v_proj_casa.bias is not None: |
|
|
torch.nn.init.zeros_(self.v_proj_casa.bias) |
|
|
if self.o_proj_casa.bias is not None: |
|
|
torch.nn.init.zeros_(self.o_proj_casa.bias) |
|
|
|
|
|
def init_from_config_proj( |
|
|
self, key: Literal["q", "o", "k", "v"], config: PretrainedConfig |
|
|
) -> torch.nn.Linear: |
|
|
"""Initialize the Linear proj in this module""" |
|
|
raise NotImplementedError("Abastract class.") |
|
|
|
|
|
def apply_position_embeddings( |
|
|
self, |
|
|
key: Literal["q", "kv"], |
|
|
x: torch.Tensor, |
|
|
casa_handler: CASAAttentionHandler | None, |
|
|
num_queries: int = 0, |
|
|
unsqueeze_dim: int = 1, |
|
|
) -> torch.Tensor: |
|
|
"""Apply position embeddings to query and key states""" |
|
|
raise NotImplementedError("Abastract class.") |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
casa_handler: CASAAttentionHandler | None, |
|
|
) -> torch.Tensor | None: |
|
|
"""Generic forward for CASA uses for instance in `helium1_attention`""" |
|
|
og_dtype = hidden_states.dtype |
|
|
if self.is_streaming: |
|
|
casa_handler = self.streaming_state.maybe_get_casa_handler( |
|
|
casa_handler, |
|
|
is_first_casa_layer=self.is_first_casa_layer, |
|
|
num_queries=hidden_states.shape[1], |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if casa_handler is None: |
|
|
return None |
|
|
|
|
|
if self.is_streaming: |
|
|
assert casa_handler.use_asymetric_qkv, ( |
|
|
"You should set `use_asymetric_qkv` to True during inference" |
|
|
) |
|
|
|
|
|
og_shape = hidden_states.shape |
|
|
|
|
|
|
|
|
if casa_handler.use_asymetric_qkv: |
|
|
q_inputs = hidden_states.flatten(0, 1)[None, ...] |
|
|
if casa_handler.attention_mask is not None: |
|
|
q_inputs = q_inputs[:, casa_handler.attention_mask[:, -og_shape[1] :].flatten()] |
|
|
else: |
|
|
q_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn) |
|
|
|
|
|
|
|
|
if not self.is_streaming or self.streaming_state.offset == 0: |
|
|
kv_inputs = casa_handler.get_full_embeds(hidden_states, norm_fn=self.norm_fn) |
|
|
else: |
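
            # During streaming decode (offset > 0), only the newly added tokens are
            # projected here; the full key/value windows are recovered from the
            # streaming cache in extend_kv() below.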
|
|
|
|
|
|
|
|
kv_inputs = q_inputs |
|
|
|
|
|
|
|
|
bs, total_seq_len = kv_inputs.shape[:2] |
|
|
hidden_shape_q = (bs, q_inputs.shape[1], -1, self.head_dim) |
|
|
hidden_shape_kv = (bs, total_seq_len, -1, self.head_dim) |
|
|
|
|
|
if self.override_q_proj is None: |
|
|
query_states = self.q_proj_casa(q_inputs).view(*hidden_shape_q) |
|
|
else: |
|
|
query_states = self.override_q_proj(q_inputs).view(*hidden_shape_q) |
|
|
|
|
|
if self.override_k_proj is None: |
|
|
key_states = self.k_proj_casa(kv_inputs).view(*hidden_shape_kv) |
|
|
else: |
|
|
key_states = self.override_k_proj(kv_inputs).view(*hidden_shape_kv) |
|
|
|
|
|
if self.override_v_proj is None: |
|
|
value_states = self.v_proj_casa(kv_inputs).view(*hidden_shape_kv) |
|
|
else: |
|
|
value_states = self.override_v_proj(kv_inputs).view(*hidden_shape_kv) |
|
|
|
|
|
|
|
|
num_queries = 0 |
|
|
if self.streaming and self.streaming_state.offset > 0: |
|
|
num_queries = og_shape[1] |
|
|
|
|
|
query_states = self.apply_position_embeddings( |
|
|
"q", query_states, num_queries=num_queries, casa_handler=casa_handler |
|
|
) |
|
|
key_states = self.apply_position_embeddings( |
|
|
"kv", key_states, num_queries=num_queries, casa_handler=casa_handler |
|
|
) |
|
|
assert flash_attn_varlen_func is not None, ( |
|
|
"flash_attention is not installed but required for block-wise attention" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.is_streaming: |
|
|
key_states, value_states = self.streaming_state.extend_kv( |
|
|
key_states=key_states, value_states=value_states |
|
|
) |
|
|
if casa_handler.use_asymetric_qkv: |
|
|
cu_seqlens_q = casa_handler.cu_seqlens_q |
|
|
max_seqlen_q = casa_handler.max_seqlen_q |
|
|
else: |
|
|
cu_seqlens_q = casa_handler.cu_seqlens_kv |
|
|
max_seqlen_q = casa_handler.max_seqlen_kv |
|
|
assert cu_seqlens_q[-1] == query_states.shape[1], ( |
|
|
f"{cu_seqlens_q[-1]} != {query_states.shape[1]}" |
|
|
) |
|
|
assert casa_handler.cu_seqlens_kv[-1] == key_states.shape[1], ( |
|
|
f"{casa_handler.cu_seqlens_kv[-1]} != {key_states.shape[1]}" |
|
|
) |
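
        # flash_attn_varlen_func attends to each window delimited by cu_seqlens
        # independently, so the text queries of one document only see the keys/values
        # (text + image tokens) of that same document.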
|
|
|
|
|
attn_output: torch.Tensor = flash_attn_varlen_func( |
|
|
query_states[0].to(torch.bfloat16), |
|
|
key_states[0].to(torch.bfloat16), |
|
|
value_states[0].to(torch.bfloat16), |
|
|
cu_seqlens_q=cu_seqlens_q, |
|
|
cu_seqlens_k=casa_handler.cu_seqlens_kv, |
|
|
max_seqlen_q=max_seqlen_q, |
|
|
max_seqlen_k=casa_handler.max_seqlen_kv, |
|
|
dropout_p=0.0, |
|
|
|
|
|
causal=True, |
|
|
).to(og_dtype) |
|
|
|
|
|
attn_output = attn_output.reshape(hidden_shape_q[1], -1).contiguous() |
|
|
if self.override_o_proj is None: |
|
|
attn_output = self.o_proj_casa(attn_output) |
|
|
else: |
|
|
attn_output = self.override_o_proj(attn_output) |
|
|
|
|
|
attn_output = casa_handler.recover_text_embeds( |
|
|
attn_output, hidden_states, update_image_embeddings=self.config.xa_update_image_embeds |
|
|
) |
|
|
attn_output = attn_output.reshape(og_shape) |
|
|
|
|
|
if self.is_streaming: |
|
|
self.streaming_state.offset += attn_output.shape[1] |
|
|
return attn_output |
|
|
|