import glob
import torch
from torchvision import models
import torch.nn as nn
from PIL import Image
import os
import numpy as np
import random
from argparse import ArgumentParser
from spatial_dataloader import VideoDataset_mp4
import scipy.io as scio


def exit_folder(folder_name):
    """Create *folder_name* (including parents) if it does not already exist."""
    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists(...)` guard and is a no-op when the folder exists.
    os.makedirs(folder_name, exist_ok=True)


class ResNet18_LP(torch.nn.Module):
    """ImageNet-pretrained ResNet-18, truncated for feature extraction.

    The backbone is cut after a configurable stage and frozen; ``forward``
    returns spatially pooled (mean, std) feature maps.

    Parameters
    ----------
    layer : int
        Stage to truncate after. 1, 2 or 3 keep progressively deeper
        stages; any other value keeps everything up to (and including)
        the last conv stage — matching the original ``else`` branch.
    """

    # Requested stage -> slice end applied to resnet18's children list
    # (a negative index drops that many trailing modules).
    _LAYER_SLICE = {1: -5, 2: -4, 3: -3}

    def __init__(self, layer=2):
        super(ResNet18_LP, self).__init__()
        # One construction path instead of four near-identical branches.
        cut = self._LAYER_SLICE.get(layer, -2)
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.features = nn.Sequential(*list(backbone.children())[:cut])
        # Frozen feature extractor: no fine-tuning.
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return (mean-pooled, std-pooled) features, each of shape (N, C, 1, 1)."""
        x = self.features(x)
        features_mean = nn.functional.adaptive_avg_pool2d(x, 1)
        features_std = global_std_pool2d(x)
        return features_mean, features_std


def global_std_pool2d(x):
    """2D global standard deviation pooling over the spatial dimensions."""
    return torch.std(x.view(x.size()[0], x.size()[1], -1, 1),
                     dim=2, keepdim=True)


if __name__ == "__main__":
    parser = ArgumentParser(
        # Fixed: the original description began with a stray '"' character.
        description='Extracting Laplacian Pyramids Features using Pre-Trained ResNet-18')
    parser.add_argument("--seed", type=int, default=20241017)
    parser.add_argument('--database', default='Qeval', type=str,
                        help='database name')
    parser.add_argument('--frame_batch_size', type=int, default=64,
                        help='frame batch size for feature extraction (default: 64)')
    parser.add_argument('--layer', type=int, default=2,
                        help='RN18 layer for feature extraction (default: 2)')
    parser.add_argument('--num_levels', type=int, default=6,
                        help='number of gaussian pyramids')
parser.add_argument('--prompt_num',type=int,default=619) args = parser.parse_args() torch.manual_seed(args.seed) # torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False np.random.seed(args.seed) random.seed(args.seed) torch.utils.backcompat.broadcast_warning.enabled = True device = torch.device("cuda" if torch.cuda.is_available() else "cpu") extractor = ResNet18_LP(layer=2).to(device) # extractor.eval() # breakpoint() if args.database == 'Qeval': #vids_dir = f'/home/ps/codebase/XGC-track1/test' vids_dir = f'/home/ps/codebase/IQA_Database/iccv25_challenges/T2V/test' save_folder = f'data/iccv_spa_test' if not os.path.exists(save_folder): os.makedirs(save_folder) dataset = VideoDataset_mp4(args.database, vids_dir, args.num_levels) for i in range(0, len(dataset)): dt, f_len, nm = dataset[i] video_length = dt.shape[0] print(f'process {i}th vid, name: {nm}') frame_start = 0 frame_end = frame_start + args.frame_batch_size output1 = torch.Tensor().to(device) output2 = torch.Tensor().to(device) with torch.no_grad(): while frame_end < video_length: batch = dt[frame_start:frame_end].to(device) features_mean, features_std = extractor(batch) #features_mean= extractor(batch) output1 = torch.cat((output1, features_mean), 0) output2 = torch.cat((output2, features_std), 0) frame_end += args.frame_batch_size frame_start += args.frame_batch_size last_batch = dt[frame_start:video_length].to(device) features_mean, features_std = extractor(last_batch) #features_mean = extractor(last_batch) output1 = torch.cat((output1, features_mean), 0) output2 = torch.cat((output2, features_std), 0) #features = torch.cat((output1, output2), 1).squeeze() features = torch.cat((output1, output2), 1).squeeze() exit_folder(os.path.join(save_folder, nm)) for j in range(f_len): img_features = features[j*(args.num_levels-1): (j+1)*(args.num_levels-1)] np.save(os.path.join(save_folder, nm, str(j)), img_features.to('cpu').numpy()) # del extractor torch.cuda.empty_cache()