Upload demo_usage.py with huggingface_hub
demo_usage.py CHANGED (+34 -7)
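Summary of the change: the commit turns demo_usage.py into a six-step walkthrough. It silences warnings, renumbers the step banners to [1/6]–[6/6], prints separator lines between steps, adds video captioning via model.describe, appends a composed video retrieval example (a source video plus an edit text compared against a target video), and removes a leftover ipdb breakpoint.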
@@ -2,6 +2,9 @@ import torch
 from termcolor import colored
 from modeling_tara import TARA, read_frames_decord, read_images_decord
 
+import warnings
+warnings.filterwarnings("ignore")
+
 
 def main(model_path: str = "."):
     print(colored("="*60, 'yellow'))
@@ -9,7 +12,7 @@ def main(model_path: str = "."):
     print(colored("="*60, 'yellow'))
 
     # Load model from current directory
-    print(colored("\n[1/
+    print(colored("\n[1/6] Loading model...", 'cyan'))
     model = TARA.from_pretrained(
         model_path, # Load from current directory
         device_map='auto',
@@ -19,11 +22,11 @@ def main(model_path: str = "."):
     n_params = sum(p.numel() for p in model.model.parameters())
     print(colored(f"✓ Model loaded successfully!", 'green'))
     print(f"Number of parameters: {round(n_params/1e9, 3)}B")
+    print("-" * 100)
 
     # Encode a sample video
-    print(colored("\n[2/
+    print(colored("\n[2/6] Testing video encoding and captioning ...", 'cyan'))
     video_path = "./assets/folding_paper.mp4"
-
     try:
         video_tensor = read_frames_decord(video_path, num_frames=16)
         video_tensor = video_tensor.unsqueeze(0)
@@ -31,17 +34,22 @@ def main(model_path: str = "."):
 
         with torch.no_grad():
             video_emb = model.encode_vision(video_tensor).cpu().squeeze(0).float()
+
+        # Get caption for the video
+        video_caption = model.describe(video_tensor)[0]
 
         print(colored("✓ Video encoded successfully!", 'green'))
         print(f"Video shape: {video_tensor.shape}") # torch.Size([1, 16, 3, 240, 426])
         print(f"Video embedding shape: {video_emb.shape}") # torch.Size([4096])
+        print(colored(f"Video caption: {video_caption}", 'magenta'))
     except FileNotFoundError:
         print(colored(f"✗ Video file not found: {video_path}", 'red'))
         print(colored(" Please add a video file or update the path in demo_usage.py", 'yellow'))
         video_emb = None
+    print("-" * 100)
 
     # Encode sample texts
-    print(colored("\n[3/
+    print(colored("\n[3/6] Testing text encoding...", 'cyan'))
     text = ['someone is folding a paper', 'cutting a paper', 'someone is unfolding a paper']
     # NOTE: It can also take a single string
 
@@ -54,7 +62,7 @@ def main(model_path: str = "."):
 
     # Compute similarities if video was encoded
     if video_emb is not None:
-        print(colored("\n[
+        print(colored("\n[4/6] Computing video-text similarities...", 'cyan'))
         similarities = torch.cosine_similarity(
             video_emb.unsqueeze(0).unsqueeze(0), # [1, 1, 4096]
             text_emb.unsqueeze(0), # [1, 3, 4096]
@@ -68,6 +76,7 @@ def main(model_path: str = "."):
 
     # Negation example: a negation in text query should result
    # in retrieval of images without the neg. object in the query
+    print(colored("\n[5/6] Testing negation example...", 'cyan'))
     image_paths = [
         './assets/cat.png',
         './assets/dog+cat.png',
@@ -85,7 +94,7 @@ def main(model_path: str = "."):
     print("Text query: ", texts)
     sim = text_embs @ image_embs.t()
     print(f"Text-Image similarity: {sim}")
-    print("-" *
+    print("- " * 50)
 
     texts = ['an image of a cat and a dog together']
     with torch.no_grad():
@@ -95,9 +104,27 @@ def main(model_path: str = "."):
     sim = text_embs @ image_embs.t()
     print(f"Text-Image similarity: {sim}")
     print("-" * 100)
-    import ipdb; ipdb.set_trace()
 
 
+    # Composed video retrieval example
+    print(colored("\n[6/6] Testing composed video retrieval...", 'cyan'))
+    # source_video_path = './assets/source-27375787.mp4'
+    # target_video_path = './assets/target-27387901.mp4'
+    # edit_text = "Make the billboard blank"
+    source_video_path = "./assets/5369546.mp4"
+    target_video_path = "./assets/1006630957.mp4"
+    edit_text = "make the tree lit up"
+    source_video_tensor = read_frames_decord(source_video_path, num_frames=4)
+    target_video_tensor = read_frames_decord(target_video_path, num_frames=16)
+    with torch.no_grad():
+        source_video_emb = model.encode_vision(source_video_tensor.unsqueeze(0), edit_text).cpu().squeeze(0).float()
+        source_video_emb = torch.nn.functional.normalize(source_video_emb, dim=-1)
+        target_video_emb = model.encode_vision(target_video_tensor.unsqueeze(0)).cpu().squeeze(0).float()
+        target_video_emb = torch.nn.functional.normalize(target_video_emb, dim=-1)
+    sim_with_edit = source_video_emb @ target_video_emb.t()
+    print(f"Source-Target similarity with edit: {sim_with_edit}")
+
+
     print(colored("\n" + "="*60, 'yellow'))
     print(colored("Demo completed successfully! 🎉", 'green', attrs=['bold']))
     print(colored("="*60, 'yellow'))
|