mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 06:14:22 +01:00
first commit
This commit is contained in:
149
scripts/cal_metrics_nonref.py
Normal file
149
scripts/cal_metrics_nonref.py
Normal file
@@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# Power by Zongsheng Yue 2022-08-13 21:37:58
|
||||
|
||||
'''
|
||||
Calculate PSNR, SSIM, LPIPS, and NIQE.
|
||||
'''
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
import math
|
||||
import pyiqa
|
||||
import torch
|
||||
import argparse
|
||||
from einops import rearrange
|
||||
from loguru import logger as base_logger
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
||||
from utils import util_image
|
||||
from utils.util_opts import str2bool
|
||||
from datapipe.datasets import BaseData
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--bs", type=int, default=1, help="Batch size")
|
||||
parser.add_argument("-i", "--indir", type=str, default="", help="Path to save the testing images")
|
||||
parser.add_argument("-r", "--refdir", type=str, default="", help="Reference images for fid")
|
||||
parser.add_argument("-t", "--tocpu", type=str2bool, default='false')
|
||||
parser.add_argument("--pi", type=str2bool, default='false', help="PI metric")
|
||||
parser.add_argument("--niqe", type=str2bool, default='false', help="NIQE metric")
|
||||
parser.add_argument("--maniqa", type=str2bool, default='false', help="MANIQA metric")
|
||||
parser.add_argument("--tres", type=str2bool, default='false', help="TReS metric")
|
||||
parser.add_argument("--dbcnn", type=str2bool, default='false', help="DBCNN metric")
|
||||
args = parser.parse_args()
|
||||
|
||||
# setting logger
|
||||
log_path = str(Path(args.indir).parent / 'metrics.log')
|
||||
logger = base_logger
|
||||
logger.remove()
|
||||
logger.add(log_path, format="{time:YYYY-MM-DD(HH:mm:ss)}: {message}", mode='w', level='INFO')
|
||||
logger.add(sys.stderr, format="{message}", level='INFO')
|
||||
logger.info(f"Image Floder: {args.indir}")
|
||||
|
||||
if args.pi:
|
||||
pi_metric = pyiqa.create_metric('pi')
|
||||
if args.niqe:
|
||||
niqe_metric = pyiqa.create_metric('niqe')
|
||||
if args.maniqa:
|
||||
maniqa_metric = pyiqa.create_metric('maniqa')
|
||||
if args.tres:
|
||||
tres_metric = pyiqa.create_metric('tres')
|
||||
if args.dbcnn:
|
||||
dbcnn_metric = pyiqa.create_metric('dbcnn')
|
||||
if args.refdir:
|
||||
fid_metric = pyiqa.create_metric('fid')
|
||||
if args.tocpu:
|
||||
clipiqa_metric = pyiqa.create_metric('clipiqa').to('cpu')
|
||||
musiq_metric = pyiqa.create_metric('musiq').to('cpu')
|
||||
else:
|
||||
clipiqa_metric = pyiqa.create_metric('clipiqa')
|
||||
musiq_metric = pyiqa.create_metric('musiq')
|
||||
|
||||
dataset = BaseData(
|
||||
dir_path=args.indir,
|
||||
transform_type='default',
|
||||
transform_kwargs={'mean': 0.0, 'std': 1.0},
|
||||
need_path=True,
|
||||
im_exts=['png', 'jpeg', 'jpg', ],
|
||||
recursive=False,
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.bs,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
num_workers=0
|
||||
)
|
||||
logger.info(f'Number of images: {len(dataset)}')
|
||||
|
||||
metrics = {
|
||||
'PI': 0,
|
||||
'CLIPIQA': 0,
|
||||
'MUSIQ': 0,
|
||||
'MANIQA': 0,
|
||||
'TRES': 0,
|
||||
'DBCNN': 0,
|
||||
}
|
||||
if args.niqe:
|
||||
metrics['NIQE'] = 0
|
||||
for ii, data in enumerate(dataloader):
|
||||
im = data['image'].cuda() # N x h x w x 3, [0,1]
|
||||
current_bs = im.shape[0]
|
||||
|
||||
if args.pi:
|
||||
current_pi = pi_metric(im).sum().item()
|
||||
if args.niqe:
|
||||
current_niqe = niqe_metric(im).sum().item()
|
||||
if args.maniqa:
|
||||
current_maniqa = maniqa_metric(im).sum().item()
|
||||
if args.tres:
|
||||
current_tres = tres_metric(im).sum().item()
|
||||
if args.dbcnn:
|
||||
current_dbcnn = dbcnn_metric(im).sum().item()
|
||||
if args.tocpu:
|
||||
current_clipiqa = clipiqa_metric(im.cpu()).sum().item()
|
||||
current_musiq = musiq_metric(im.cpu()).sum().item()
|
||||
else:
|
||||
current_clipiqa = clipiqa_metric(im).sum().item()
|
||||
current_musiq = musiq_metric(im).sum().item()
|
||||
|
||||
if (ii+1) % 10 == 0:
|
||||
log_str = ('Processing: {:03d}/{:03d}'.format(ii+1, math.ceil(len(dataset) / args.bs)))
|
||||
logger.info(log_str)
|
||||
|
||||
metrics['CLIPIQA'] += current_clipiqa
|
||||
metrics['MUSIQ'] += current_musiq
|
||||
if args.pi:
|
||||
metrics['PI'] += current_pi
|
||||
if args.niqe:
|
||||
metrics['NIQE'] += current_niqe
|
||||
if args.maniqa:
|
||||
metrics['MANIQA'] += current_maniqa
|
||||
if args.tres:
|
||||
metrics['TRES'] += current_tres
|
||||
if args.dbcnn:
|
||||
metrics['DBCNN'] += current_dbcnn
|
||||
|
||||
for key in metrics.keys():
|
||||
metrics[key] /= len(dataset)
|
||||
|
||||
if args.refdir:
|
||||
metrics['FID'] = fid_metric(args.indir, args.refdir, mode='legacy_pytorch')
|
||||
|
||||
logger.info(f"MEAN CLIPIQA: {metrics['CLIPIQA']:6.4f}")
|
||||
logger.info(f"MEAN MUSIQ: {metrics['MUSIQ']:6.4f}")
|
||||
if args.pi:
|
||||
logger.info(f"MEAN PI: {metrics['PI']:6.4f}")
|
||||
if args.niqe:
|
||||
logger.info(f"MEAN NIQE: {metrics['NIQE']:6.4f}")
|
||||
if args.maniqa:
|
||||
logger.info(f"MEAN MANIQA: {metrics['MANIQA']:6.4f}")
|
||||
if args.tres:
|
||||
logger.info(f"MEAN TRES: {metrics['TRES']:6.4f}")
|
||||
if args.dbcnn:
|
||||
logger.info(f"MEAN DBCNN: {metrics['DBCNN']:6.4f}")
|
||||
if args.refdir:
|
||||
logger.info(f"MEAN FID: {metrics['FID']:6.4f}")
|
||||
|
||||
195
scripts/cal_metrics_ref.py
Normal file
195
scripts/cal_metrics_ref.py
Normal file
@@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# Power by Zongsheng Yue 2022-08-13 21:37:58
|
||||
|
||||
'''
|
||||
Calculate PSNR, SSIM, LPIPS, and NIQE.
|
||||
'''
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
import math
|
||||
import lpips
|
||||
import pyiqa
|
||||
import torch
|
||||
import argparse
|
||||
from einops import rearrange
|
||||
from loguru import logger as base_logger
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
||||
from utils import util_image
|
||||
from utils.util_opts import str2bool
|
||||
from datapipe.datasets import BaseData
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--bs", type=int, default=16, help="Batch size")
|
||||
parser.add_argument("--gt_dir", type=str, default="", help="Path to save the HQ images")
|
||||
parser.add_argument("--sr_dir", type=str, default="", help="Path to save the SR images")
|
||||
parser.add_argument("--log_name", type=str, default='metrics.log', help="Logging path")
|
||||
parser.add_argument("--test_y_channel", type=str2bool, default='true', help="Y channel for PSNR and SSIM")
|
||||
parser.add_argument("--fid", type=str2bool, default='false', help="Calculating FID")
|
||||
parser.add_argument("--niqe", type=str2bool, default='false', help="Calculating NIQE")
|
||||
parser.add_argument("--dists", type=str2bool, default='false', help="Calculating DISTS")
|
||||
parser.add_argument("--maniqa", type=str2bool, default='false', help="Calculating MANIQA")
|
||||
parser.add_argument("--pi", type=str2bool, default='false', help="Calculating PI")
|
||||
parser.add_argument("--tocpu", type=str2bool, default='false', help="Moving model to CPU")
|
||||
args = parser.parse_args()
|
||||
|
||||
# setting logger
|
||||
log_path = str(Path(args.sr_dir).parent / f'{args.log_name}')
|
||||
logger = base_logger
|
||||
logger.remove()
|
||||
logger.add(log_path, format="{time:YYYY-MM-DD(HH:mm:ss)}: {message}", mode='w', level='INFO')
|
||||
logger.add(sys.stderr, format="{message}", level='INFO')
|
||||
logger.info(f"Ground truth: {args.gt_dir}")
|
||||
logger.info(f"SR result: {args.sr_dir}")
|
||||
|
||||
if args.test_y_channel:
|
||||
psnr_metric = pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr')
|
||||
ssim_metric = pyiqa.create_metric('ssim', test_y_channel=True, color_space='ycbcr')
|
||||
else:
|
||||
psnr_metric = pyiqa.create_metric('psnr', test_y_channel=False, color_space='rgb')
|
||||
ssim_metric = pyiqa.create_metric('ssim', test_y_channel=False, color_space='rgb')
|
||||
if args.fid:
|
||||
fid_metric = pyiqa.create_metric('fid')
|
||||
if args.niqe:
|
||||
niqe_metric = pyiqa.create_metric('niqe')
|
||||
if args.dists:
|
||||
dists_metric = pyiqa.create_metric('dists')
|
||||
if args.maniqa:
|
||||
maniqa_metric = pyiqa.create_metric('maniqa')
|
||||
if args.pi:
|
||||
pi_metric = pyiqa.create_metric('pi')
|
||||
loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()
|
||||
loss_fn_alex = lpips.LPIPS(net='alex').cuda()
|
||||
if args.tocpu:
|
||||
clipiqa_metric = pyiqa.create_metric('clipiqa').to('cpu')
|
||||
musiq_metric = pyiqa.create_metric('musiq').to('cpu')
|
||||
else:
|
||||
clipiqa_metric = pyiqa.create_metric('clipiqa')
|
||||
musiq_metric = pyiqa.create_metric('musiq')
|
||||
|
||||
dataset = BaseData(
|
||||
dir_path=args.sr_dir,
|
||||
transform_type='default',
|
||||
transform_kwargs={'mean': 0.0, 'std': 1.0},
|
||||
extra_dir_path=args.gt_dir,
|
||||
extra_transform_type='default',
|
||||
extra_transform_kwargs={'mean': 0.0, 'std': 1.0},
|
||||
need_path=True,
|
||||
im_exts=['png', 'jpg'],
|
||||
recursive=False,
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.bs,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
num_workers=0
|
||||
)
|
||||
logger.info(f'Number of images: {len(dataset)}')
|
||||
|
||||
metrics = {
|
||||
'PSNR': 0,
|
||||
'SSIM': 0,
|
||||
'LPIPS_VGG': 0,
|
||||
'LPIPS_ALEX': 0,
|
||||
'CLIPIQA': 0,
|
||||
'MUSIQ': 0,
|
||||
}
|
||||
if args.niqe:
|
||||
metrics['NIQE'] = 0
|
||||
if args.dists:
|
||||
metrics['DISTS'] = 0
|
||||
if args.maniqa:
|
||||
metrics['MANIQA'] = 0
|
||||
if args.pi:
|
||||
metrics['PI'] = 0
|
||||
for ii, data in enumerate(dataloader):
|
||||
im_sr = data['image'].cuda() # N x h x w x 3, [0,1]
|
||||
im_gt = data['gt'].cuda() # N x h x w x 3, [0,1]
|
||||
current_bs = im_sr.shape[0]
|
||||
|
||||
if not (im_sr.shape == im_gt.shape):
|
||||
height = min(im_sr.shape[-2], im_gt.shape[-2])
|
||||
width = min(im_sr.shape[-1], im_gt.shape[-1])
|
||||
im_sr = im_sr[:, :, :height, :width]
|
||||
im_gt = im_gt[:, :, :height, :width]
|
||||
|
||||
current_psnr = psnr_metric(im_sr, im_gt).mean().item()
|
||||
current_ssim = ssim_metric(im_sr, im_gt).mean().item()
|
||||
current_lpips_vgg = loss_fn_vgg(
|
||||
(im_gt - 0.5) / 0.5,
|
||||
(im_sr - 0.5) / 0.5,
|
||||
).mean().item()
|
||||
current_lpips_alex = loss_fn_alex(
|
||||
(im_gt - 0.5) / 0.5,
|
||||
(im_sr - 0.5) / 0.5,
|
||||
).mean().item()
|
||||
if args.tocpu:
|
||||
current_clipiqa = clipiqa_metric(im_sr.cpu()).mean().item()
|
||||
current_musiq = musiq_metric(im_sr.cpu()).mean().item()
|
||||
else:
|
||||
current_clipiqa = clipiqa_metric(im_sr).mean().item()
|
||||
current_musiq = musiq_metric(im_sr).mean().item()
|
||||
if args.niqe:
|
||||
current_niqe = niqe_metric(im_sr).mean().item()
|
||||
if args.dists:
|
||||
current_dists = dists_metric(im_sr, im_gt).mean().item()
|
||||
if args.maniqa:
|
||||
current_maniqa = maniqa_metric(im_sr).mean().item()
|
||||
if args.pi:
|
||||
current_pi = pi_metric(im_sr).mean().item()
|
||||
|
||||
if (ii+1) % 30 == 0:
|
||||
log_str = ('Processing: {:03d}/{:03d}, PSNR={:5.2f}, LPIPS={:6.4f}/{:6.4f}, CLIPIQA={:6.4f}, MUSIQ={:6.4f}'.format(
|
||||
ii+1,
|
||||
math.ceil(len(dataset) /args.bs),
|
||||
current_psnr,
|
||||
current_lpips_vgg,
|
||||
current_lpips_alex,
|
||||
current_clipiqa,
|
||||
current_musiq,
|
||||
))
|
||||
logger.info(log_str)
|
||||
|
||||
metrics['PSNR'] += current_psnr * current_bs
|
||||
metrics['SSIM'] += current_ssim * current_bs
|
||||
metrics['LPIPS_VGG'] += current_lpips_vgg * current_bs
|
||||
metrics['LPIPS_ALEX'] += current_lpips_alex * current_bs
|
||||
metrics['CLIPIQA'] += current_clipiqa * current_bs
|
||||
metrics['MUSIQ'] += current_musiq * current_bs
|
||||
if args.niqe:
|
||||
metrics['NIQE'] += current_niqe * current_bs
|
||||
if args.dists:
|
||||
metrics['DISTS'] += current_dists * current_bs
|
||||
if args.maniqa:
|
||||
metrics['MANIQA'] += current_maniqa * current_bs
|
||||
if args.pi:
|
||||
metrics['PI'] += current_pi * current_bs
|
||||
|
||||
for key in metrics.keys():
|
||||
metrics[key] /= len(dataset)
|
||||
|
||||
if args.fid:
|
||||
metrics['FID'] = fid_metric(args.sr_dir, args.gt_dir)
|
||||
|
||||
logger.info(f"MEAN PSNR: {metrics['PSNR']:5.2f}")
|
||||
logger.info(f"MEAN SSIM: {metrics['SSIM']:6.4f}")
|
||||
logger.info(f"MEAN LPIPS(VGG): {metrics['LPIPS_VGG']:6.4f}")
|
||||
logger.info(f"MEAN LPIPS(ALEX): {metrics['LPIPS_ALEX']:6.4f}")
|
||||
logger.info(f"MEAN CLIPIQA: {metrics['CLIPIQA']:6.4f}")
|
||||
logger.info(f"MEAN MUSIQ: {metrics['MUSIQ']:6.4f}")
|
||||
if args.fid:
|
||||
logger.info(f"MEAN FID: {metrics['FID']:6.2f}")
|
||||
if args.niqe:
|
||||
logger.info(f"MEAN NIQE: {metrics['NIQE']:7.4f}")
|
||||
if args.dists:
|
||||
logger.info(f"MEAN DISTS: {metrics['DISTS']:6.4f}")
|
||||
if args.maniqa:
|
||||
logger.info(f"MEAN MANIQA: {metrics['MANIQA']:6.4f}")
|
||||
if args.pi:
|
||||
logger.info(f"MEAN PI: {metrics['PI']:7.4f}")
|
||||
|
||||
107
scripts/prepare_sr_testing_syn.py
Normal file
107
scripts/prepare_sr_testing_syn.py
Normal file
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# Power by Zongsheng Yue 2024-04-07 20:57:36
|
||||
|
||||
import os
|
||||
import torch
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.append(str(Path(__file__).parents[1]))
|
||||
|
||||
from basicsr.data.realesrgan_dataset import RealESRGANDataset
|
||||
from utils import util_image
|
||||
from utils import util_common
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--indir",
|
||||
type=str,
|
||||
default="/mnt/lustre/share/zhangwenwei/data/imagenet/val",
|
||||
help="Folder to save the checkpoints and training log",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--outdir",
|
||||
type=str,
|
||||
default="./ImageNet-Test",
|
||||
help="Folder to save the checkpoints and training log",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Image resolution of the ground truth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_imgs",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Number of images.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if Path(args.indir).is_dir():
|
||||
img_list = sorted([x for x in Path(args.indir).glob('*.[JjPp][PpNn]*[Gg]')])
|
||||
elif args.indir.endswith('txt'):
|
||||
img_list = util_common.readline_txt(args.indir)
|
||||
else:
|
||||
raise ValueError('Please input valid args.indir!')
|
||||
print(f'Number of images in the input folder: {len(img_list)}')
|
||||
|
||||
random.seed(10000)
|
||||
random.shuffle(img_list)
|
||||
|
||||
num_imgs = args.num_imgs
|
||||
if num_imgs > 0:
|
||||
assert num_imgs <= len(img_list)
|
||||
img_list = random.sample(img_list, k=num_imgs)
|
||||
|
||||
gt_dir = Path(args.outdir) / 'gt'
|
||||
if not gt_dir.exists():
|
||||
gt_dir.mkdir(parents=True)
|
||||
lq_dir = Path(args.outdir) / 'lq'
|
||||
if not lq_dir.exists():
|
||||
lq_dir.mkdir(parents=True)
|
||||
|
||||
# Loading configuration
|
||||
configs = OmegaConf.load('./configs/degradation_testing_realesrgan.yaml')
|
||||
opts, opts_degradation = configs.opts, configs.degradation
|
||||
opts['gt_size'] = args.resolution
|
||||
opts_degradation['gt_size'] = args.resolution
|
||||
|
||||
dataset = RealESRGANDataset(opts, mode='testing')
|
||||
dataset.image_paths = img_list
|
||||
dataset.text_paths = [None, ] * len(img_list)
|
||||
dataset.moment_paths = [None, ] * len(img_list)
|
||||
for ii in range(len(img_list)):
|
||||
data_dict1 = dataset.__getitem__(ii)
|
||||
if (ii + 1) % 100 == 0:
|
||||
print(f'Processing: {ii+1}/{len(img_list)}')
|
||||
prefix = 'realesrgan'
|
||||
data_dict2 = dataset.degrade_fun(
|
||||
opts_degradation,
|
||||
im_gt=data_dict1['gt'].unsqueeze(0),
|
||||
kernel1=data_dict1['kernel1'],
|
||||
kernel2=data_dict1['kernel2'],
|
||||
sinc_kernel=data_dict1['sinc_kernel'],
|
||||
)
|
||||
im_lq, im_gt = data_dict2['lq'], data_dict2['gt']
|
||||
im_lq, im_gt = util_image.tensor2img([im_lq, im_gt], rgb2bgr=True, min_max=(0,1) ) # uint8
|
||||
|
||||
im_name = Path(data_dict1['gt_path']).stem
|
||||
im_path_gt = gt_dir / f'{im_name}.png'
|
||||
util_image.imwrite(im_gt, im_path_gt, chn='bgr', dtype_in='uint8')
|
||||
|
||||
im_path_lq = lq_dir / f'{im_name}.png'
|
||||
util_image.imwrite(im_lq, im_path_lq, chn='bgr', dtype_in='uint8')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user