# zwx8981's picture
# Upload 493 files
# 1c80527 verified
import glob
import torch
from torchvision import models
import torch.nn as nn
from PIL import Image
import os
import numpy as np
import random
from argparse import ArgumentParser
from spatial_dataloader import VideoDataset_mp4
import scipy.io as scio
def exit_folder(folder_name):
    """Create *folder_name* (including parents) if it does not already exist."""
    # exist_ok=True avoids the TOCTOU race between an exists() check and
    # makedirs(), and is a no-op when the directory is already there.
    os.makedirs(folder_name, exist_ok=True)
class ResNet18_LP(torch.nn.Module):
    """ImageNet-pretrained ResNet-18 truncated after a configurable stage,
    used as a frozen feature extractor.

    forward() returns per-channel global mean and std pooled feature maps.
    """

    # layer -> how many trailing children of torchvision's resnet18 to drop.
    # (1 keeps up to layer1 [-5], 2 up to layer2 [-4], 3 up to layer3 [-3];
    # any other value keeps up to layer4 [-2].)
    _TRUNCATE = {1: -5, 2: -4, 3: -3}

    def __init__(self, layer=2):
        super(ResNet18_LP, self).__init__()
        cut = self._TRUNCATE.get(layer, -2)
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.features = nn.Sequential(*list(backbone.children())[:cut])
        # Frozen backbone: features are extracted, never fine-tuned.
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return (mean, std) global-pooled feature maps, each (N, C, 1, 1)."""
        x = self.features(x)
        features_mean = nn.functional.adaptive_avg_pool2d(x, 1)
        features_std = global_std_pool2d(x)
        return features_mean, features_std
def global_std_pool2d(x):
    """2D global standard-deviation pooling.

    Flattens each (H, W) feature map and returns the per-channel std,
    shaped (N, C, 1, 1) to mirror adaptive_avg_pool2d's output.
    """
    batch, channels = x.size(0), x.size(1)
    flat = x.reshape(batch, channels, -1, 1)
    return torch.std(flat, dim=2, keepdim=True)
if __name__ == "__main__":
parser = ArgumentParser(
description='"Extracting Laplacian Pyramids Features using Pre-Trained ResNet-18')
parser.add_argument("--seed", type=int, default=20241017)
parser.add_argument('--database', default='Qeval', type=str,
help='database name')
parser.add_argument('--frame_batch_size', type=int, default=64,
help='frame batch size for feature extraction (default: 64)')
parser.add_argument('--layer', type=int, default=2,
help='RN18 layer for feature extraction (default: 2)')
parser.add_argument('--num_levels', type=int, default=6,
help='number of gaussian pyramids')
# parser.add_argument('--prompt_num',type=int,default=619)
args = parser.parse_args()
torch.manual_seed(args.seed) #
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)
random.seed(args.seed)
torch.utils.backcompat.broadcast_warning.enabled = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
extractor = ResNet18_LP(layer=2).to(device) #
extractor.eval()
# breakpoint()
if args.database == 'Qeval':
#vids_dir = f'/home/ps/codebase/XGC-track1/test'
vids_dir = f'/home/ps/codebase/IQA_Database/iccv25_challenges/T2V/test'
save_folder = f'data/iccv_spa_test'
if not os.path.exists(save_folder):
os.makedirs(save_folder)
dataset = VideoDataset_mp4(args.database, vids_dir, args.num_levels)
for i in range(0, len(dataset)):
dt, f_len, nm = dataset[i]
video_length = dt.shape[0]
print(f'process {i}th vid, name: {nm}')
frame_start = 0
frame_end = frame_start + args.frame_batch_size
output1 = torch.Tensor().to(device)
output2 = torch.Tensor().to(device)
with torch.no_grad():
while frame_end < video_length:
batch = dt[frame_start:frame_end].to(device)
features_mean, features_std = extractor(batch)
#features_mean= extractor(batch)
output1 = torch.cat((output1, features_mean), 0)
output2 = torch.cat((output2, features_std), 0)
frame_end += args.frame_batch_size
frame_start += args.frame_batch_size
last_batch = dt[frame_start:video_length].to(device)
features_mean, features_std = extractor(last_batch)
#features_mean = extractor(last_batch)
output1 = torch.cat((output1, features_mean), 0)
output2 = torch.cat((output2, features_std), 0)
#features = torch.cat((output1, output2), 1).squeeze()
features = torch.cat((output1, output2), 1).squeeze()
exit_folder(os.path.join(save_folder, nm))
for j in range(f_len):
img_features = features[j*(args.num_levels-1): (j+1)*(args.num_levels-1)]
np.save(os.path.join(save_folder, nm, str(j)),
img_features.to('cpu').numpy())
# del extractor
torch.cuda.empty_cache()