# Source uploaded by zwx8981 ("Upload 493 files", commit 1c80527 verified)
import glob
import torch
from torchvision import models
import torch.nn as nn
from PIL import Image
import os
import numpy as np
import random
from argparse import ArgumentParser
from spatial_dataloader import VideoDataset_mp42
import scipy.io as scio
import pyiqa
from pyiqa.data.multiscale_trans_util import get_multiscale_patches
from torchvision.models import convnext_base, ConvNeXt_Base_Weights
from torchvision import transforms
from PIL import Image
class ConvNeXtFeatureExtractor(nn.Module):
    """Frame-level feature extractor built on a pretrained ConvNeXt-Base.

    :param output_stage: which stage [1-4] to take features from; stage 4
        (the default) is the final feature map and matches the original
        behavior of this class.
    :param pooled: if True, apply adaptive average pooling and flatten the
        output to a [B, C] vector.
    :param device: target device, e.g. "cuda" or "cpu".
    """
    def __init__(self, output_stage: int = 4, pooled: bool = True, device='cuda'):
        super().__init__()
        assert output_stage in [1, 2, 3, 4], "output_stage must be 1 to 4"
        # Load the ImageNet-pretrained model.
        weights = ConvNeXt_Base_Weights.DEFAULT
        self.model = convnext_base(weights=weights).to(device)
        self.model.eval()
        self.device = device
        # Preprocessing transform intentionally not stored; callers feed
        # already-preprocessed tensors.
        #self.transform = weights.transforms()
        # torchvision's ConvNeXt `features` holds 8 children laid out as
        # [stem, stage1, down, stage2, down, stage3, down, stage4], so the
        # first 2*output_stage children end exactly after the chosen stage.
        # BUGFIX: previously `output_stage` was validated but ignored and the
        # full backbone was always used; output_stage=4 keeps that behavior.
        stages = list(self.model.features.children())[:2 * output_stage]
        self.backbone = nn.Sequential(*stages)
        self.pooled = pooled
        if self.pooled:
            self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone on a preprocessed image batch.

        Returns a [B, C] vector when pooled, otherwise the [B, C, H, W]
        feature map of the selected stage.
        """
        x = x.to(self.device)
        x = self.backbone(x)
        if self.pooled:
            x = self.pool(x)
            x = x.view(x.size(0), -1)  # flatten pooled map to [B, C]
        return x
def exit_folder(folder_name):
    """Ensure *folder_name* (and any missing parents) exists.

    Uses ``exist_ok=True`` instead of the original check-then-create
    pattern, which was racy when multiple processes extracted features
    into the same tree concurrently.
    """
    os.makedirs(folder_name, exist_ok=True)
def global_std_pool2d(x):
    """Global standard-deviation pooling over the spatial dimensions.

    Collapses an [B, C, H, W] input to a [B, C, 1, 1] tensor holding the
    per-channel (unbiased) standard deviation of the spatial values.
    """
    flattened = x.flatten(start_dim=2).unsqueeze(-1)  # [B, C, H*W, 1]
    return flattened.std(dim=2, keepdim=True)
if __name__ == "__main__":
    # Extract per-frame spatial features with a pretrained ConvNeXt backbone
    # and save them as one .npy file per frame under save_folder/<video>/.
    parser = ArgumentParser(
        description='Extracting spatial features using a pre-trained ConvNeXt')
    parser.add_argument("--seed", type=int, default=20241017)
    parser.add_argument('--database', default='Qeval', type=str,
                        help='database name')
    parser.add_argument('--frame_batch_size', type=int, default=64,
                        help='frame batch size for feature extraction (default: 64)')
    parser.add_argument('--layer', type=int, default=2,
                        help='RN18 layer for feature extraction (default: 2)')
    parser.add_argument('--num_levels', type=int, default=6,
                        help='number of gaussian pyramids')
    args = parser.parse_args()

    # Deterministic setup so feature extraction is reproducible across runs.
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.utils.backcompat.broadcast_warning.enabled = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    extractor = ConvNeXtFeatureExtractor().to(device)
    extractor.eval()

    if args.database == 'Qeval':
        # NOTE(review): hard-coded machine-specific paths (ICCV'25 T2V
        # training set) — adjust for other environments.
        vids_dir = '/home/ps/codebase/IQA_Database/iccv25_challenges/T2V/train'
        save_folder = 'data/iccv_spa_train2'
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
    dataset = VideoDataset_mp42(args.database, vids_dir, args.num_levels)

    for i in range(len(dataset)):
        dt, f_len, nm = dataset[i]  # frames tensor, frame count, video name
        video_length = dt.shape[0]
        print(f'process {i}th vid, name: {nm}')
        frame_start = 0
        frame_end = frame_start + args.frame_batch_size
        output1 = torch.Tensor().to(device)
        with torch.no_grad():
            # Process full-size batches first; the remainder (always at least
            # one frame) goes through as a final, possibly smaller, batch.
            while frame_end < video_length:
                batch = dt[frame_start:frame_end].to(device)
                features_mean = extractor(batch)
                output1 = torch.cat((output1, features_mean), 0)
                frame_end += args.frame_batch_size
                frame_start += args.frame_batch_size
            last_batch = dt[frame_start:video_length].to(device)
            features_mean = extractor(last_batch)
            output1 = torch.cat((output1, features_mean), 0)
        # BUGFIX: the original `output1.squeeze()` collapsed the batch
        # dimension for single-frame videos, so `features[j]` then indexed
        # channels instead of frames. Keep an explicit [N, D] layout instead.
        features = output1.reshape(output1.size(0), -1)
        exit_folder(os.path.join(save_folder, nm))
        for j in range(f_len):
            img_features = features[j].reshape(1, -1)
            np.save(os.path.join(save_folder, nm, str(j)),
                    img_features.to('cpu').numpy())
        # Free cached GPU memory between videos (no-op on CPU).
        torch.cuda.empty_cache()