first commit

This commit is contained in:
zsyOAOA
2024-12-11 18:46:36 +08:00
parent 9e65255d34
commit 27f2eb7dc3
847 changed files with 377076 additions and 2 deletions

195
scripts/cal_metrics_ref.py Normal file
View File

@@ -0,0 +1,195 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-08-13 21:37:58
'''
Calculate PSNR, SSIM, LPIPS, and NIQE.
'''
import sys
from pathlib import Path
import os
import math
import lpips
import pyiqa
import torch
import argparse
from einops import rearrange
from loguru import logger as base_logger
sys.path.append(str(Path(__file__).resolve().parents[1]))
from utils import util_image
from utils.util_opts import str2bool
from datapipe.datasets import BaseData
parser = argparse.ArgumentParser()
parser.add_argument("--bs", type=int, default=16, help="Batch size")
parser.add_argument("--gt_dir", type=str, default="", help="Path to save the HQ images")
parser.add_argument("--sr_dir", type=str, default="", help="Path to save the SR images")
parser.add_argument("--log_name", type=str, default='metrics.log', help="Logging path")
parser.add_argument("--test_y_channel", type=str2bool, default='true', help="Y channel for PSNR and SSIM")
parser.add_argument("--fid", type=str2bool, default='false', help="Calculating FID")
parser.add_argument("--niqe", type=str2bool, default='false', help="Calculating NIQE")
parser.add_argument("--dists", type=str2bool, default='false', help="Calculating DISTS")
parser.add_argument("--maniqa", type=str2bool, default='false', help="Calculating MANIQA")
parser.add_argument("--pi", type=str2bool, default='false', help="Calculating PI")
parser.add_argument("--tocpu", type=str2bool, default='false', help="Moving model to CPU")
args = parser.parse_args()
# setting logger
log_path = str(Path(args.sr_dir).parent / f'{args.log_name}')
logger = base_logger
logger.remove()
logger.add(log_path, format="{time:YYYY-MM-DD(HH:mm:ss)}: {message}", mode='w', level='INFO')
logger.add(sys.stderr, format="{message}", level='INFO')
logger.info(f"Ground truth: {args.gt_dir}")
logger.info(f"SR result: {args.sr_dir}")
if args.test_y_channel:
psnr_metric = pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr')
ssim_metric = pyiqa.create_metric('ssim', test_y_channel=True, color_space='ycbcr')
else:
psnr_metric = pyiqa.create_metric('psnr', test_y_channel=False, color_space='rgb')
ssim_metric = pyiqa.create_metric('ssim', test_y_channel=False, color_space='rgb')
if args.fid:
fid_metric = pyiqa.create_metric('fid')
if args.niqe:
niqe_metric = pyiqa.create_metric('niqe')
if args.dists:
dists_metric = pyiqa.create_metric('dists')
if args.maniqa:
maniqa_metric = pyiqa.create_metric('maniqa')
if args.pi:
pi_metric = pyiqa.create_metric('pi')
loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()
loss_fn_alex = lpips.LPIPS(net='alex').cuda()
if args.tocpu:
clipiqa_metric = pyiqa.create_metric('clipiqa').to('cpu')
musiq_metric = pyiqa.create_metric('musiq').to('cpu')
else:
clipiqa_metric = pyiqa.create_metric('clipiqa')
musiq_metric = pyiqa.create_metric('musiq')
dataset = BaseData(
dir_path=args.sr_dir,
transform_type='default',
transform_kwargs={'mean': 0.0, 'std': 1.0},
extra_dir_path=args.gt_dir,
extra_transform_type='default',
extra_transform_kwargs={'mean': 0.0, 'std': 1.0},
need_path=True,
im_exts=['png', 'jpg'],
recursive=False,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.bs,
shuffle=False,
drop_last=False,
num_workers=0
)
logger.info(f'Number of images: {len(dataset)}')
metrics = {
'PSNR': 0,
'SSIM': 0,
'LPIPS_VGG': 0,
'LPIPS_ALEX': 0,
'CLIPIQA': 0,
'MUSIQ': 0,
}
if args.niqe:
metrics['NIQE'] = 0
if args.dists:
metrics['DISTS'] = 0
if args.maniqa:
metrics['MANIQA'] = 0
if args.pi:
metrics['PI'] = 0
for ii, data in enumerate(dataloader):
im_sr = data['image'].cuda() # N x h x w x 3, [0,1]
im_gt = data['gt'].cuda() # N x h x w x 3, [0,1]
current_bs = im_sr.shape[0]
if not (im_sr.shape == im_gt.shape):
height = min(im_sr.shape[-2], im_gt.shape[-2])
width = min(im_sr.shape[-1], im_gt.shape[-1])
im_sr = im_sr[:, :, :height, :width]
im_gt = im_gt[:, :, :height, :width]
current_psnr = psnr_metric(im_sr, im_gt).mean().item()
current_ssim = ssim_metric(im_sr, im_gt).mean().item()
current_lpips_vgg = loss_fn_vgg(
(im_gt - 0.5) / 0.5,
(im_sr - 0.5) / 0.5,
).mean().item()
current_lpips_alex = loss_fn_alex(
(im_gt - 0.5) / 0.5,
(im_sr - 0.5) / 0.5,
).mean().item()
if args.tocpu:
current_clipiqa = clipiqa_metric(im_sr.cpu()).mean().item()
current_musiq = musiq_metric(im_sr.cpu()).mean().item()
else:
current_clipiqa = clipiqa_metric(im_sr).mean().item()
current_musiq = musiq_metric(im_sr).mean().item()
if args.niqe:
current_niqe = niqe_metric(im_sr).mean().item()
if args.dists:
current_dists = dists_metric(im_sr, im_gt).mean().item()
if args.maniqa:
current_maniqa = maniqa_metric(im_sr).mean().item()
if args.pi:
current_pi = pi_metric(im_sr).mean().item()
if (ii+1) % 30 == 0:
log_str = ('Processing: {:03d}/{:03d}, PSNR={:5.2f}, LPIPS={:6.4f}/{:6.4f}, CLIPIQA={:6.4f}, MUSIQ={:6.4f}'.format(
ii+1,
math.ceil(len(dataset) /args.bs),
current_psnr,
current_lpips_vgg,
current_lpips_alex,
current_clipiqa,
current_musiq,
))
logger.info(log_str)
metrics['PSNR'] += current_psnr * current_bs
metrics['SSIM'] += current_ssim * current_bs
metrics['LPIPS_VGG'] += current_lpips_vgg * current_bs
metrics['LPIPS_ALEX'] += current_lpips_alex * current_bs
metrics['CLIPIQA'] += current_clipiqa * current_bs
metrics['MUSIQ'] += current_musiq * current_bs
if args.niqe:
metrics['NIQE'] += current_niqe * current_bs
if args.dists:
metrics['DISTS'] += current_dists * current_bs
if args.maniqa:
metrics['MANIQA'] += current_maniqa * current_bs
if args.pi:
metrics['PI'] += current_pi * current_bs
for key in metrics.keys():
metrics[key] /= len(dataset)
if args.fid:
metrics['FID'] = fid_metric(args.sr_dir, args.gt_dir)
logger.info(f"MEAN PSNR: {metrics['PSNR']:5.2f}")
logger.info(f"MEAN SSIM: {metrics['SSIM']:6.4f}")
logger.info(f"MEAN LPIPS(VGG): {metrics['LPIPS_VGG']:6.4f}")
logger.info(f"MEAN LPIPS(ALEX): {metrics['LPIPS_ALEX']:6.4f}")
logger.info(f"MEAN CLIPIQA: {metrics['CLIPIQA']:6.4f}")
logger.info(f"MEAN MUSIQ: {metrics['MUSIQ']:6.4f}")
if args.fid:
logger.info(f"MEAN FID: {metrics['FID']:6.2f}")
if args.niqe:
logger.info(f"MEAN NIQE: {metrics['NIQE']:7.4f}")
if args.dists:
logger.info(f"MEAN DISTS: {metrics['DISTS']:6.4f}")
if args.maniqa:
logger.info(f"MEAN MANIQA: {metrics['MANIQA']:6.4f}")
if args.pi:
logger.info(f"MEAN PI: {metrics['PI']:7.4f}")