mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 14:24:27 +01:00
196 lines
6.9 KiB
Python
196 lines
6.9 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding:utf-8 -*-
|
|
# Powered by Zongsheng Yue 2022-08-13 21:37:58
|
|
|
|
'''
|
|
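Evaluate SR results against ground truth: PSNR, SSIM, LPIPS (VGG/AlexNet), CLIPIQA, MUSIQ, and optionally FID, NIQE, DISTS, MANIQA and PI.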
|
|
'''
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import os
|
|
import math
|
|
import lpips
|
|
import pyiqa
|
|
import torch
|
|
import argparse
|
|
from einops import rearrange
|
|
from loguru import logger as base_logger
|
|
|
|
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
|
from utils import util_image
|
|
from utils.util_opts import str2bool
|
|
from datapipe.datasets import BaseData
|
|
|
|
# Command-line options. Boolean options are converted with `str2bool`, so they take
# an explicit value on the command line (e.g. `--fid true`); argparse also runs the
# string defaults ('true' / 'false') through the same converter.
parser = argparse.ArgumentParser()
parser.add_argument("--bs", type=int, default=16, help="Batch size")
parser.add_argument("--gt_dir", type=str, default="",
                    help="Directory containing the ground-truth (HQ) images")
parser.add_argument("--sr_dir", type=str, default="",
                    help="Directory containing the super-resolved (SR) images to evaluate")
parser.add_argument("--log_name", type=str, default='metrics.log',
                    help="Name of the log file, written to the parent directory of --sr_dir")
parser.add_argument("--test_y_channel", type=str2bool, default='true', help="Y channel for PSNR and SSIM")
parser.add_argument("--fid", type=str2bool, default='false', help="Calculating FID")
parser.add_argument("--niqe", type=str2bool, default='false', help="Calculating NIQE")
parser.add_argument("--dists", type=str2bool, default='false', help="Calculating DISTS")
parser.add_argument("--maniqa", type=str2bool, default='false', help="Calculating MANIQA")
parser.add_argument("--pi", type=str2bool, default='false', help="Calculating PI")
parser.add_argument("--tocpu", type=str2bool, default='false',
                    help="Run the CLIPIQA and MUSIQ models on the CPU")
args = parser.parse_args()
|
|
|
|
# Logging: every message is written both to a log file (with a timestamp prefix)
# and to stderr. The log file is placed in the parent folder of --sr_dir and is
# truncated on each run (mode='w').
log_path = str(Path(args.sr_dir).parent / args.log_name)
logger = base_logger
logger.remove()  # drop all previously registered handlers, including loguru's default sink
logger.add(log_path, format="{time:YYYY-MM-DD(HH:mm:ss)}: {message}", mode='w', level='INFO')
logger.add(sys.stderr, format="{message}", level='INFO')
for _label, _folder in (("Ground truth", args.gt_dir), ("SR result", args.sr_dir)):
    logger.info(f"{_label}: {_folder}")
|
|
|
|
# Build every metric model up front.
# PSNR / SSIM are computed either on the Y channel of YCbCr or directly on RGB.
_use_y_channel = bool(args.test_y_channel)
_psnr_ssim_space = 'ycbcr' if _use_y_channel else 'rgb'
psnr_metric = pyiqa.create_metric('psnr', test_y_channel=_use_y_channel, color_space=_psnr_ssim_space)
ssim_metric = pyiqa.create_metric('ssim', test_y_channel=_use_y_channel, color_space=_psnr_ssim_space)

# Optional metrics: only instantiated when their command-line flag is set.
if args.fid:
    fid_metric = pyiqa.create_metric('fid')
if args.niqe:
    niqe_metric = pyiqa.create_metric('niqe')
if args.dists:
    dists_metric = pyiqa.create_metric('dists')
if args.maniqa:
    maniqa_metric = pyiqa.create_metric('maniqa')
if args.pi:
    pi_metric = pyiqa.create_metric('pi')

# LPIPS with VGG and AlexNet backbones, both placed on the GPU.
loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()
loss_fn_alex = lpips.LPIPS(net='alex').cuda()


def _build_nr_metric(name):
    """Create the pyiqa metric ``name``; move it to the CPU when --tocpu is set.

    Each model is moved right after it is created, before the next one is built.
    """
    model = pyiqa.create_metric(name)
    return model.to('cpu') if args.tocpu else model


clipiqa_metric = _build_nr_metric('clipiqa')
musiq_metric = _build_nr_metric('musiq')
|
|
|
|
# Paired SR / GT data. The evaluation loop reads the SR image as data['image'] and
# the matching ground truth as data['gt']. mean=0 / std=1 presumably leaves pixel
# values un-normalized (TODO confirm against BaseData's 'default' transform).
dataset = BaseData(
    dir_path=args.sr_dir,
    transform_type='default',
    transform_kwargs={'mean': 0.0, 'std': 1.0},
    extra_dir_path=args.gt_dir,
    extra_transform_type='default',
    extra_transform_kwargs={'mean': 0.0, 'std': 1.0},
    need_path=True,
    im_exts=['png', 'jpg'],
    recursive=False,
)
# Deterministic order, keep the last partial batch, load in the main process.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.bs,
    shuffle=False,
    drop_last=False,
    num_workers=0,
)
logger.info(f'Number of images: {len(dataset)}')

# Fail fast with a clear message: with an empty dataset the final per-image
# averaging (division by len(dataset)) would otherwise raise ZeroDivisionError.
if len(dataset) == 0:
    logger.error(f"No images found in {args.sr_dir}; nothing to evaluate.")
    sys.exit(1)
|
|
|
|
# Running totals for every metric that will be reported. The evaluation loop adds
# batch-size-weighted batch means; they are divided by the number of images afterwards.
_always_reported = ('PSNR', 'SSIM', 'LPIPS_VGG', 'LPIPS_ALEX', 'CLIPIQA', 'MUSIQ')
_optional_reported = (
    ('NIQE', args.niqe),
    ('DISTS', args.dists),
    ('MANIQA', args.maniqa),
    ('PI', args.pi),
)
metrics = dict.fromkeys(_always_reported, 0)
metrics.update((_name, 0) for _name, _enabled in _optional_reported if _enabled)
|
|
# Main evaluation loop. Each batch mean is weighted by the batch size and added to
# `metrics`; the totals are divided by the number of images after the loop.
# Nothing here needs gradients: running under torch.no_grad() stops the two LPIPS
# networks (ordinary nn.Modules) from recording an autograd graph for every batch.
with torch.no_grad():
    for ii, data in enumerate(dataloader):
        # Pixel values in [0, 1]. The crop below treats the last two dims as height
        # and width, so the layout is presumably N x 3 x h x w (TODO confirm with BaseData).
        im_sr = data['image'].cuda()
        im_gt = data['gt'].cuda()
        current_bs = im_sr.shape[0]

        # If SR and GT sizes disagree, compare only their common top-left region.
        if im_sr.shape != im_gt.shape:
            height = min(im_sr.shape[-2], im_gt.shape[-2])
            width = min(im_sr.shape[-1], im_gt.shape[-1])
            im_sr = im_sr[:, :, :height, :width]
            im_gt = im_gt[:, :, :height, :width]

        # Full-reference metrics.
        current_psnr = psnr_metric(im_sr, im_gt).mean().item()
        current_ssim = ssim_metric(im_sr, im_gt).mean().item()
        # LPIPS inputs are rescaled from [0, 1] to [-1, 1].
        current_lpips_vgg = loss_fn_vgg(
            (im_gt - 0.5) / 0.5,
            (im_sr - 0.5) / 0.5,
        ).mean().item()
        current_lpips_alex = loss_fn_alex(
            (im_gt - 0.5) / 0.5,
            (im_sr - 0.5) / 0.5,
        ).mean().item()

        # CLIPIQA / MUSIQ score the SR images alone. With --tocpu those two models
        # live on the CPU, so their input is moved there as well.
        if args.tocpu:
            current_clipiqa = clipiqa_metric(im_sr.cpu()).mean().item()
            current_musiq = musiq_metric(im_sr.cpu()).mean().item()
        else:
            current_clipiqa = clipiqa_metric(im_sr).mean().item()
            current_musiq = musiq_metric(im_sr).mean().item()

        # Optional metrics (only defined when their flag is set).
        if args.niqe:
            current_niqe = niqe_metric(im_sr).mean().item()
        if args.dists:
            current_dists = dists_metric(im_sr, im_gt).mean().item()
        if args.maniqa:
            current_maniqa = maniqa_metric(im_sr).mean().item()
        if args.pi:
            current_pi = pi_metric(im_sr).mean().item()

        # Every 30 batches, report the current batch's scores.
        if (ii+1) % 30 == 0:
            log_str = ('Processing: {:03d}/{:03d}, PSNR={:5.2f}, LPIPS={:6.4f}/{:6.4f}, CLIPIQA={:6.4f}, MUSIQ={:6.4f}'.format(
                ii+1,
                math.ceil(len(dataset) /args.bs),
                current_psnr,
                current_lpips_vgg,
                current_lpips_alex,
                current_clipiqa,
                current_musiq,
            ))
            logger.info(log_str)

        # Accumulate batch-size-weighted sums (the last batch may be smaller).
        metrics['PSNR'] += current_psnr * current_bs
        metrics['SSIM'] += current_ssim * current_bs
        metrics['LPIPS_VGG'] += current_lpips_vgg * current_bs
        metrics['LPIPS_ALEX'] += current_lpips_alex * current_bs
        metrics['CLIPIQA'] += current_clipiqa * current_bs
        metrics['MUSIQ'] += current_musiq * current_bs
        if args.niqe:
            metrics['NIQE'] += current_niqe * current_bs
        if args.dists:
            metrics['DISTS'] += current_dists * current_bs
        if args.maniqa:
            metrics['MANIQA'] += current_maniqa * current_bs
        if args.pi:
            metrics['PI'] += current_pi * current_bs
|
|
|
|
# Turn the batch-size-weighted totals into per-image means.
num_images = len(dataset)
metrics = {name: total / num_images for name, total in metrics.items()}

# FID compares the two image folders as a whole, so it is added after averaging.
if args.fid:
    metrics['FID'] = fid_metric(args.sr_dir, args.gt_dir)

# (label printed, key in `metrics`, format spec, whether to report it)
_report_rows = (
    ("PSNR", 'PSNR', '5.2f', True),
    ("SSIM", 'SSIM', '6.4f', True),
    ("LPIPS(VGG)", 'LPIPS_VGG', '6.4f', True),
    ("LPIPS(ALEX)", 'LPIPS_ALEX', '6.4f', True),
    ("CLIPIQA", 'CLIPIQA', '6.4f', True),
    ("MUSIQ", 'MUSIQ', '6.4f', True),
    ("FID", 'FID', '6.2f', args.fid),
    ("NIQE", 'NIQE', '7.4f', args.niqe),
    ("DISTS", 'DISTS', '6.4f', args.dists),
    ("MANIQA", 'MANIQA', '6.4f', args.maniqa),
    ("PI", 'PI', '7.4f', args.pi),
)
for _label, _key, _spec, _enabled in _report_rows:
    if _enabled:
        logger.info(f"MEAN {_label}: {metrics[_key]:{_spec}}")
|
|
|