Spaces: Running on Zero
| import torch | |
| import numpy as np | |
def process_attn(attention, rng, attn_func):
    """Build a per-layer, per-head heatmap of last-token attention to a span.

    For each layer, aggregates the attention that the final query token pays
    to the "instruction" span ``rng[0]``, optionally normalized by the total
    attention paid to the instruction span plus the "data" span ``rng[1]``.

    Args:
        attention: sequence of per-layer attention tensors; each is assumed
            to be shaped (batch=1, num_heads, query_len, key_len) — TODO
            confirm against the caller.
        rng: two (start, end) index pairs over the key axis: rng[0] is the
            instruction span, rng[1] the data span.
        attn_func: string selecting the aggregation — contains "sum" or
            "max" (note: "sum" is checked first, so a string containing both
            takes the sum path), and optionally "normalize".

    Returns:
        np.ndarray of shape (len(attention), num_heads), with NaNs replaced
        by 0.0.

    Raises:
        NotImplementedError: if attn_func names no known aggregation.
    """
    heatmap = np.zeros((len(attention), attention[0].shape[1]))
    (inst_start, inst_end), (data_start, data_end) = rng[0], rng[1]
    for i, attn_layer in enumerate(attention):
        # detach() and cpu() are no-ops for plain CPU tensors but make the
        # conversion safe for tensors that require grad or live on a GPU.
        attn_layer = attn_layer.detach().to(torch.float32).cpu().numpy()
        # Attention from the last query token to the instruction span,
        # per head: shape (num_heads, inst_end - inst_start).
        inst_slice = attn_layer[0, :, -1, inst_start:inst_end]
        if "sum" in attn_func:
            attn = np.sum(inst_slice, axis=1)
        elif "max" in attn_func:
            attn = np.max(inst_slice, axis=1)
        else:
            raise NotImplementedError(f"unsupported attn_func: {attn_func!r}")
        if "normalize" in attn_func:
            # Normalize by total mass over instruction + data spans; the
            # epsilon guards against division by zero.
            inst_sum = np.sum(inst_slice, axis=1)
            data_sum = np.sum(attn_layer[0, :, -1, data_start:data_end], axis=1)
            epsilon = 1e-8
            heatmap[i, :] = attn / (inst_sum + data_sum + epsilon)
        else:
            heatmap[i, :] = attn
    # Defensive: zero out any NaNs that slipped through.
    return np.nan_to_num(heatmap, nan=0.0)
def calc_attn_score(heatmap, heads):
    """Average the heatmap values at the given (layer, head) positions.

    Args:
        heatmap: 2-D array indexed as heatmap[layer, head].
        heads: iterable of (layer, head) index pairs to average over.

    Returns:
        The mean of the selected heatmap entries.
    """
    layers = [layer for layer, _ in heads]
    head_ids = [head for _, head in heads]
    return np.mean(heatmap[layers, head_ids], axis=0)