import os
import random
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from termcolor import colored
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    LlavaConfig,
    LlamaForCausalLM,
)
from torchvision.transforms.v2 import ToPILImage

import decord
from decord import VideoReader

decord.bridge.set_bridge("torch")

from tarsier.modeling_tarsier import TarsierForConditionalGeneration
from tarsier.processor import Processor


# One-word summary prompts: the model is asked to compress its input into a
# single word, and the hidden state at that final position is used as the embedding.
EOL_PROMPTS = {
    'text': '<sent>\nSummary above sentence in one word:',
    'image': '<image>\nSummary above image in one word:',
    'video': '<video>\nSummary above video in one word:',
}

def transform_pixel_values(pixel_values: torch.Tensor | List[torch.Tensor]) -> torch.Tensor:
    """Normalize vision inputs to a batched 5D tensor of shape (B, T, C, H, W)."""
    if isinstance(pixel_values, list):
        pixel_values = torch.stack(pixel_values)
    if pixel_values.ndim == 4:
        # A batch of single images (B, C, H, W): add a time dimension of length 1.
        pixel_values = pixel_values.unsqueeze(1)
    elif pixel_values.ndim != 5:
        raise ValueError(f"pixel_values should be 4D or 5D, got {pixel_values.ndim}D")
    return pixel_values
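
# Shape check (illustrative sizes): a batch of 2 RGB images gains a frame axis,
# while an 8-frame video batch passes through unchanged.
#   transform_pixel_values(torch.zeros(2, 3, 224, 224)).shape     -> (2, 1, 3, 224, 224)
#   transform_pixel_values(torch.zeros(2, 8, 3, 224, 224)).shape  -> (2, 8, 3, 224, 224)

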
base_registry = {}

class BaseModel(metaclass=ABCMeta):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Auto-register concrete subclasses by their HF architecture name.
        if hasattr(cls, 'ARCHITECTURE'):
            base_registry[cls.ARCHITECTURE] = cls

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        load_llm: bool = False,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        **kwargs,
    ):
        print(colored(f'[ MODEL ] Loading {cls.__name__} from {model_name_or_path} [..............]', 'yellow'))
        return cls(model_name_or_path, load_llm=load_llm, device_map=device_map, **kwargs)
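
# The registry lets callers dispatch on a checkpoint's architecture string rather
# than hard-coding a wrapper class. A minimal sketch (assuming the config lists
# the same architecture name used as ARCHITECTURE below):
#   arch = LlavaConfig.from_pretrained(path).architectures[0]
#   model = base_registry[arch].from_pretrained(path)

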
class BaseModelForTARA(BaseModel):

    ARCHITECTURE = "TarsierForConditionalGeneration"
    LLM_CLASS = LlamaForCausalLM
    MLLM_CLASS = TarsierForConditionalGeneration

    @property
    def describe_prompt(self):
        return "Describe the video in detail."

    @property
    def text_eol_prompt(self):
        return f'USER: {EOL_PROMPTS["text"]} ASSISTANT: '

    @property
    def image_eol_prompt(self):
        return f'USER: {EOL_PROMPTS["image"]} ASSISTANT: '

    @property
    def video_eol_prompt(self):
        return f'USER: {EOL_PROMPTS["video"]} ASSISTANT: '

    @property
    def video_edit_eol_prompt(self):
        prompt = "Source video: <video>\nEdit instruction: <sent>\n"\
            "Look at the attached video carefully. The provided text is instruction to edit the video. "\
            "Imagine this edit instruction being applied to the provided video frame.\n"\
            "Summarize the resulting edited video in one word:"
        return f"USER: {prompt} ASSISTANT: "
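
    # Placeholder convention (as used by encode_text/encode_vision below):
    # '<sent>' is replaced with the raw text, and '<video>' is expanded to one
    # '<image>' token per sampled frame before tokenization.
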
    def __init__(
        self,
        model_name_or_path: str,
        load_llm: bool = False,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        **kwargs,
    ):
        MODEL_CLASS = self.LLM_CLASS if load_llm else self.MLLM_CLASS

        if load_llm:
            # Export (once) and load only the language-model weights.
            self.split_weights(model_name_or_path, model_name_or_path + '-llm')
            model_name_or_path += '-llm'
            model_config = None
            self.processor = AutoProcessor.from_pretrained(model_name_or_path, use_fast=False)
        else:
            model_config = LlavaConfig.from_pretrained(model_name_or_path)
            self.processor = Processor(model_name_or_path, max_n_frames=32)

        self.tokenizer = self.processor.tokenizer

        self.model = MODEL_CLASS.from_pretrained(
            model_name_or_path,
            config=model_config,
            torch_dtype=kwargs.get("torch_dtype", torch.bfloat16),
            device_map=device_map,
        )
        self.model.eval()

    def split_weights(self, mllm_path, llm_path):
        if os.path.exists(llm_path):
            print(f'{llm_path} already exists. Skipping weight split.')
            return
        print('Splitting LLM weights from the MLLM.')
        model = self.MLLM_CLASS.from_pretrained(mllm_path)
        llm = model.language_model
        processor = AutoProcessor.from_pretrained(mllm_path)
        tokenizer = AutoTokenizer.from_pretrained(mllm_path)
        llm.save_pretrained(llm_path)
        processor.save_pretrained(llm_path)
        tokenizer.save_pretrained(llm_path)
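
    # After splitting, `<model_name_or_path>-llm` contains a plain LlamaForCausalLM
    # checkpoint plus the processor/tokenizer files, loadable without the vision tower.

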
encoder_registry = {}

class EncodeMixin(metaclass=ABCMeta):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if hasattr(cls, 'ARCHITECTURE'):
            encoder_registry[cls.ARCHITECTURE] = cls

    @abstractmethod
    def encode_vision(self, pixel_values: torch.Tensor | List[torch.Tensor]) -> torch.Tensor:
        """
        Encodes vision data (images or videos) into a tensor representation.

        Args:
            pixel_values (torch.Tensor | List[torch.Tensor]): The input pixel values.
                - If a tensor, it should be of shape (B, C, H, W) for images or (B, T, C, H, W) for videos.
                - If a list, it will be stacked into a tensor.

        Returns:
            torch.Tensor: The encoded tensor representation of the input vision data.

        Raises:
            ValueError: If `pixel_values` is not 4D or 5D.

        Notes:
            - This method does not accept unbatched inputs.
            - `pixel_values` should be of type uint8.
        """
        raise NotImplementedError

    @abstractmethod
    def encode_text(self, text: str | List[str]) -> torch.Tensor:
        """
        Encodes the given text(s) into a tensor representation using the model.

        Args:
            text (str | List[str]): A single string or a list of strings to be encoded.

        Returns:
            torch.Tensor: The tensor representation of the encoded text(s).

        Notes:
            - Each input is wrapped in a one-word summary prompt before encoding.
            - A single string is treated as a list containing that string.
            - The output tensor contains the hidden state of the last token for each input.
        """
        raise NotImplementedError


class TARA(BaseModelForTARA, EncodeMixin):

    def encode_vision(self, pixel_values: torch.Tensor | List[torch.Tensor], edit_text: Optional[str] = None) -> torch.Tensor:
        pixel_values = transform_pixel_values(pixel_values)
        nframes = pixel_values.shape[1]

        if edit_text is not None:
            # Composed (video + edit instruction) embedding.
            prompt = self.video_edit_eol_prompt.replace('<sent>', edit_text)
        else:
            prompt = self.image_eol_prompt if nframes == 1 else self.video_eol_prompt

        to_image = ToPILImage()
        batched_frames = [[to_image(v) for v in batch] for batch in pixel_values]

        # Generate a single token so that `hidden_states` exposes the prompt's
        # final hidden state, which serves as the embedding.
        generate_kwargs = {
            "max_new_tokens": 1,
            "output_hidden_states": True,
            "return_dict_in_generate": True,
        }

        vision_embs = []
        for frames in batched_frames:
            input_prompt = prompt.replace("<video>", "<image>" * len(frames))
            input_ids = self.processor.get_text_inputs(input_prompt)
            frames = self.processor.get_pixel_values(frames)
            inputs = {
                "input_ids": input_ids,
                "pixel_values": frames,
            }
            inputs = {k: v.to(self.model.device) for k, v in inputs.items() if v is not None}
            outputs = self.model.generate(
                **inputs,
                **generate_kwargs,
            )
            # hidden_states[0]: states for the prompt pass; [-1]: last layer; [:, -1, :]: last token.
            vision_embs.append(outputs.hidden_states[0][-1][:, -1, :])

        vision_embs = torch.cat(vision_embs)
        return vision_embs
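
    # Example (sketch; the edit instruction is illustrative): embed the video as
    # if an edit had been applied, e.g. for composed retrieval.
    #   emb = model.encode_vision(frames, edit_text="make it black and white")
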
    def encode_text(self, text: str | List[str]) -> torch.Tensor:
        prompt = self.text_eol_prompt

        if isinstance(text, str):
            text = [text]
        prompts = [prompt.replace('<sent>', t) for t in text]

        generate_kwargs = {
            "max_new_tokens": 1,
            "output_hidden_states": True,
            "return_dict_in_generate": True,
        }

        text_embs = []
        for p in prompts:
            text_inputs = self.processor.get_text_inputs(p)
            inputs = {
                "input_ids": text_inputs,
            }
            inputs = {k: v.to(self.model.device) for k, v in inputs.items() if v is not None}
            outputs = self.model.generate(
                **inputs,
                **generate_kwargs,
            )
            text_embs.append(outputs.hidden_states[0][-1][:, -1, :])

        text_embs = torch.cat(text_embs)
        return text_embs
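
    # Example (sketch): text and vision embeddings come from the same last-token
    # pooling, so candidates can be ranked by cosine similarity.
    #   sims = torch.nn.functional.cosine_similarity(video_emb, text_embs)  # (N,)
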
    def describe(self, pixel_values: torch.Tensor | List[torch.Tensor]) -> List[str]:
        pixel_values = transform_pixel_values(pixel_values)
        to_image = ToPILImage()
        batched_frames = [[to_image(v) for v in batch] for batch in pixel_values]

        descriptions = []
        generate_kwargs = {
            "do_sample": False,  # greedy decoding; top_p/temperature are inert here
            "max_new_tokens": 2048,
            "top_p": 1,
            "temperature": 0,
            "use_cache": True,
        }

        for frames in batched_frames:
            text_inputs = f"<video>\n{self.describe_prompt}"
            text_inputs = self.processor.process_prompt(text_inputs, frames)
            text_inputs = self.processor.get_text_inputs(text_inputs)
            frames = self.processor.get_pixel_values(frames)
            inputs = {
                "input_ids": text_inputs,
                "pixel_values": frames,
            }
            inputs = {k: v.to(self.model.device) for k, v in inputs.items() if v is not None}
            outputs = self.model.generate(
                **inputs,
                **generate_kwargs,
            )
            # Decode only the newly generated tokens (skip the prompt prefix).
            output_text = self.processor.tokenizer.decode(
                outputs[0][inputs['input_ids'][0].shape[0]:], skip_special_tokens=True
            )
            descriptions.append(output_text)
        return descriptions
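
    # Example (sketch): one caption string per batch item.
    #   captions = model.describe(video_tensor)
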
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    if sample in ["rand", "middle"]:
        acc_samples = min(num_frames, vlen)
        # Split [0, vlen) into acc_samples equal bins and pick one index per bin.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except (ValueError, IndexError):
                # A bin can be empty when vlen < num_frames; fall back to a random subset.
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:
            # Pad by repeating the last index so exactly num_frames are returned.
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:
        # e.g. sample="fps2" draws frames at 2 fps from a video decoded at input_fps.
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError(f"Unknown sampling strategy: {sample}")
    return frame_indices
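
# Worked example: num_frames=4, vlen=100, sample="middle" gives bin edges
# [0, 25, 50, 75, 100] and midpoint indices [12, 37, 62, 87]. With sample="fps2"
# on a 10 s clip at input_fps=30, frames are taken every 0.5 s starting at 0.25 s.

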
def read_frames_decord(
    video_path, num_frames, sample='middle', fix_start=None,
    max_num_frames=-1, trimmed30=False, height=-1, width=-1
):
    decord.bridge.set_bridge('torch')

    num_threads = 1
    video_reader = VideoReader(video_path, num_threads=num_threads, height=height, width=width)
    try:
        vlen = len(video_reader)
        fps = video_reader.get_avg_fps()
        duration = vlen / float(fps)

        # Optionally restrict sampling to the first 30 seconds.
        if trimmed30 and duration > 30:
            duration = 30
            vlen = int(30 * float(fps))

        frame_indices = get_frame_indices(
            num_frames, vlen, sample=sample, fix_start=fix_start,
            input_fps=fps, max_num_frames=max_num_frames
        )

        frames = video_reader.get_batch(frame_indices)  # (T, H, W, C) with the torch bridge
        if not isinstance(frames, torch.Tensor):
            frames = torch.from_numpy(frames.asnumpy())
        frames = frames.permute(0, 3, 1, 2)  # -> (T, C, H, W)
        return frames
    finally:
        # Release the reader's file handle even if decoding fails.
        del video_reader
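
# Example (sketch; the path is illustrative): 16 uint8 frames as a (16, C, H, W) tensor.
#   frames = read_frames_decord("clip.mp4", num_frames=16, sample="middle")

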
def read_image_decord(image_path):
    # Despite the name, images are loaded with PIL rather than decord.
    image = PIL.Image.open(image_path)
    image = image.convert('RGB')
    image = np.array(image)
    image = image.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
    image = torch.from_numpy(image)
    image = image.unsqueeze(0)  # add a batch axis -> (1, C, H, W)
    return image


def read_images_decord(image_paths):
    images = []
    for image_path in image_paths:
        image = read_image_decord(image_path)
        images.append(image)
    images = torch.cat(images)  # (N, C, H, W); requires matching spatial sizes
    return images
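
# Example (sketch; paths are illustrative): stack two same-sized images.
#   batch = read_images_decord(["a.jpg", "b.jpg"])  # (2, C, H, W)

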
if __name__ == "__main__":
    model = TARA.from_pretrained(
        "/work/piyush/experiments/CaRe/Tarsier-7b/final-10112025/nli_9000+ego_1000+subj_replaced-seed_42/merged_checkpoint",
        device_map='auto',
        torch_dtype=torch.bfloat16,
    )
    n_params = sum(p.numel() for p in model.model.parameters())
    print(f"Number of parameters: {round(n_params / 1e9, 3)}B")

    print(colored("Testing video encoding...", 'cyan'))
    video_path = "./assets/folding_paper.mp4"
    video_tensor = read_frames_decord(video_path, num_frames=16)
    video_tensor = video_tensor.unsqueeze(0)
    video_tensor = video_tensor.to(model.model.device)
    with torch.no_grad():
        video_emb = model.encode_vision(video_tensor).cpu().squeeze(0).float()
    print("Video shape:", video_tensor.shape)
    print("Video embedding shape:", video_emb.shape)

    print(colored("Testing text encoding...", 'cyan'))
    text = ['someone is folding a paper', 'cutting a paper', 'someone is unfolding a paper']
    with torch.no_grad():
        text_emb = model.encode_text(text).cpu().float()
    print("Text:", text)
    print("Text embedding shape:", text_emb.shape)