first commit

This commit is contained in:
zsyOAOA
2024-12-11 18:46:36 +08:00
parent 9e65255d34
commit 27f2eb7dc3
847 changed files with 377076 additions and 2 deletions

View File

@@ -0,0 +1,149 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-08-13 21:37:58
'''
Calculate PSNR, SSIM, LPIPS, and NIQE.
'''
import sys
from pathlib import Path
import os
import math
import pyiqa
import torch
import argparse
from einops import rearrange
from loguru import logger as base_logger
sys.path.append(str(Path(__file__).resolve().parents[1]))
from utils import util_image
from utils.util_opts import str2bool
from datapipe.datasets import BaseData
parser = argparse.ArgumentParser()
parser.add_argument("--bs", type=int, default=1, help="Batch size")
parser.add_argument("-i", "--indir", type=str, default="", help="Path to save the testing images")
parser.add_argument("-r", "--refdir", type=str, default="", help="Reference images for fid")
parser.add_argument("-t", "--tocpu", type=str2bool, default='false')
parser.add_argument("--pi", type=str2bool, default='false', help="PI metric")
parser.add_argument("--niqe", type=str2bool, default='false', help="NIQE metric")
parser.add_argument("--maniqa", type=str2bool, default='false', help="MANIQA metric")
parser.add_argument("--tres", type=str2bool, default='false', help="TReS metric")
parser.add_argument("--dbcnn", type=str2bool, default='false', help="DBCNN metric")
args = parser.parse_args()
# setting logger
log_path = str(Path(args.indir).parent / 'metrics.log')
logger = base_logger
logger.remove()
logger.add(log_path, format="{time:YYYY-MM-DD(HH:mm:ss)}: {message}", mode='w', level='INFO')
logger.add(sys.stderr, format="{message}", level='INFO')
logger.info(f"Image Floder: {args.indir}")
if args.pi:
pi_metric = pyiqa.create_metric('pi')
if args.niqe:
niqe_metric = pyiqa.create_metric('niqe')
if args.maniqa:
maniqa_metric = pyiqa.create_metric('maniqa')
if args.tres:
tres_metric = pyiqa.create_metric('tres')
if args.dbcnn:
dbcnn_metric = pyiqa.create_metric('dbcnn')
if args.refdir:
fid_metric = pyiqa.create_metric('fid')
if args.tocpu:
clipiqa_metric = pyiqa.create_metric('clipiqa').to('cpu')
musiq_metric = pyiqa.create_metric('musiq').to('cpu')
else:
clipiqa_metric = pyiqa.create_metric('clipiqa')
musiq_metric = pyiqa.create_metric('musiq')
dataset = BaseData(
dir_path=args.indir,
transform_type='default',
transform_kwargs={'mean': 0.0, 'std': 1.0},
need_path=True,
im_exts=['png', 'jpeg', 'jpg', ],
recursive=False,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.bs,
shuffle=False,
drop_last=False,
num_workers=0
)
logger.info(f'Number of images: {len(dataset)}')
metrics = {
'PI': 0,
'CLIPIQA': 0,
'MUSIQ': 0,
'MANIQA': 0,
'TRES': 0,
'DBCNN': 0,
}
if args.niqe:
metrics['NIQE'] = 0
for ii, data in enumerate(dataloader):
im = data['image'].cuda() # N x h x w x 3, [0,1]
current_bs = im.shape[0]
if args.pi:
current_pi = pi_metric(im).sum().item()
if args.niqe:
current_niqe = niqe_metric(im).sum().item()
if args.maniqa:
current_maniqa = maniqa_metric(im).sum().item()
if args.tres:
current_tres = tres_metric(im).sum().item()
if args.dbcnn:
current_dbcnn = dbcnn_metric(im).sum().item()
if args.tocpu:
current_clipiqa = clipiqa_metric(im.cpu()).sum().item()
current_musiq = musiq_metric(im.cpu()).sum().item()
else:
current_clipiqa = clipiqa_metric(im).sum().item()
current_musiq = musiq_metric(im).sum().item()
if (ii+1) % 10 == 0:
log_str = ('Processing: {:03d}/{:03d}'.format(ii+1, math.ceil(len(dataset) / args.bs)))
logger.info(log_str)
metrics['CLIPIQA'] += current_clipiqa
metrics['MUSIQ'] += current_musiq
if args.pi:
metrics['PI'] += current_pi
if args.niqe:
metrics['NIQE'] += current_niqe
if args.maniqa:
metrics['MANIQA'] += current_maniqa
if args.tres:
metrics['TRES'] += current_tres
if args.dbcnn:
metrics['DBCNN'] += current_dbcnn
for key in metrics.keys():
metrics[key] /= len(dataset)
if args.refdir:
metrics['FID'] = fid_metric(args.indir, args.refdir, mode='legacy_pytorch')
logger.info(f"MEAN CLIPIQA: {metrics['CLIPIQA']:6.4f}")
logger.info(f"MEAN MUSIQ: {metrics['MUSIQ']:6.4f}")
if args.pi:
logger.info(f"MEAN PI: {metrics['PI']:6.4f}")
if args.niqe:
logger.info(f"MEAN NIQE: {metrics['NIQE']:6.4f}")
if args.maniqa:
logger.info(f"MEAN MANIQA: {metrics['MANIQA']:6.4f}")
if args.tres:
logger.info(f"MEAN TRES: {metrics['TRES']:6.4f}")
if args.dbcnn:
logger.info(f"MEAN DBCNN: {metrics['DBCNN']:6.4f}")
if args.refdir:
logger.info(f"MEAN FID: {metrics['FID']:6.4f}")

195
scripts/cal_metrics_ref.py Normal file
View File

@@ -0,0 +1,195 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-08-13 21:37:58
'''
Calculate PSNR, SSIM, LPIPS, and NIQE.
'''
import sys
from pathlib import Path
import os
import math
import lpips
import pyiqa
import torch
import argparse
from einops import rearrange
from loguru import logger as base_logger
sys.path.append(str(Path(__file__).resolve().parents[1]))
from utils import util_image
from utils.util_opts import str2bool
from datapipe.datasets import BaseData
parser = argparse.ArgumentParser()
parser.add_argument("--bs", type=int, default=16, help="Batch size")
parser.add_argument("--gt_dir", type=str, default="", help="Path to save the HQ images")
parser.add_argument("--sr_dir", type=str, default="", help="Path to save the SR images")
parser.add_argument("--log_name", type=str, default='metrics.log', help="Logging path")
parser.add_argument("--test_y_channel", type=str2bool, default='true', help="Y channel for PSNR and SSIM")
parser.add_argument("--fid", type=str2bool, default='false', help="Calculating FID")
parser.add_argument("--niqe", type=str2bool, default='false', help="Calculating NIQE")
parser.add_argument("--dists", type=str2bool, default='false', help="Calculating DISTS")
parser.add_argument("--maniqa", type=str2bool, default='false', help="Calculating MANIQA")
parser.add_argument("--pi", type=str2bool, default='false', help="Calculating PI")
parser.add_argument("--tocpu", type=str2bool, default='false', help="Moving model to CPU")
args = parser.parse_args()
# setting logger
log_path = str(Path(args.sr_dir).parent / f'{args.log_name}')
logger = base_logger
logger.remove()
logger.add(log_path, format="{time:YYYY-MM-DD(HH:mm:ss)}: {message}", mode='w', level='INFO')
logger.add(sys.stderr, format="{message}", level='INFO')
logger.info(f"Ground truth: {args.gt_dir}")
logger.info(f"SR result: {args.sr_dir}")
if args.test_y_channel:
psnr_metric = pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr')
ssim_metric = pyiqa.create_metric('ssim', test_y_channel=True, color_space='ycbcr')
else:
psnr_metric = pyiqa.create_metric('psnr', test_y_channel=False, color_space='rgb')
ssim_metric = pyiqa.create_metric('ssim', test_y_channel=False, color_space='rgb')
if args.fid:
fid_metric = pyiqa.create_metric('fid')
if args.niqe:
niqe_metric = pyiqa.create_metric('niqe')
if args.dists:
dists_metric = pyiqa.create_metric('dists')
if args.maniqa:
maniqa_metric = pyiqa.create_metric('maniqa')
if args.pi:
pi_metric = pyiqa.create_metric('pi')
loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()
loss_fn_alex = lpips.LPIPS(net='alex').cuda()
if args.tocpu:
clipiqa_metric = pyiqa.create_metric('clipiqa').to('cpu')
musiq_metric = pyiqa.create_metric('musiq').to('cpu')
else:
clipiqa_metric = pyiqa.create_metric('clipiqa')
musiq_metric = pyiqa.create_metric('musiq')
dataset = BaseData(
dir_path=args.sr_dir,
transform_type='default',
transform_kwargs={'mean': 0.0, 'std': 1.0},
extra_dir_path=args.gt_dir,
extra_transform_type='default',
extra_transform_kwargs={'mean': 0.0, 'std': 1.0},
need_path=True,
im_exts=['png', 'jpg'],
recursive=False,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.bs,
shuffle=False,
drop_last=False,
num_workers=0
)
logger.info(f'Number of images: {len(dataset)}')
metrics = {
'PSNR': 0,
'SSIM': 0,
'LPIPS_VGG': 0,
'LPIPS_ALEX': 0,
'CLIPIQA': 0,
'MUSIQ': 0,
}
if args.niqe:
metrics['NIQE'] = 0
if args.dists:
metrics['DISTS'] = 0
if args.maniqa:
metrics['MANIQA'] = 0
if args.pi:
metrics['PI'] = 0
for ii, data in enumerate(dataloader):
im_sr = data['image'].cuda() # N x h x w x 3, [0,1]
im_gt = data['gt'].cuda() # N x h x w x 3, [0,1]
current_bs = im_sr.shape[0]
if not (im_sr.shape == im_gt.shape):
height = min(im_sr.shape[-2], im_gt.shape[-2])
width = min(im_sr.shape[-1], im_gt.shape[-1])
im_sr = im_sr[:, :, :height, :width]
im_gt = im_gt[:, :, :height, :width]
current_psnr = psnr_metric(im_sr, im_gt).mean().item()
current_ssim = ssim_metric(im_sr, im_gt).mean().item()
current_lpips_vgg = loss_fn_vgg(
(im_gt - 0.5) / 0.5,
(im_sr - 0.5) / 0.5,
).mean().item()
current_lpips_alex = loss_fn_alex(
(im_gt - 0.5) / 0.5,
(im_sr - 0.5) / 0.5,
).mean().item()
if args.tocpu:
current_clipiqa = clipiqa_metric(im_sr.cpu()).mean().item()
current_musiq = musiq_metric(im_sr.cpu()).mean().item()
else:
current_clipiqa = clipiqa_metric(im_sr).mean().item()
current_musiq = musiq_metric(im_sr).mean().item()
if args.niqe:
current_niqe = niqe_metric(im_sr).mean().item()
if args.dists:
current_dists = dists_metric(im_sr, im_gt).mean().item()
if args.maniqa:
current_maniqa = maniqa_metric(im_sr).mean().item()
if args.pi:
current_pi = pi_metric(im_sr).mean().item()
if (ii+1) % 30 == 0:
log_str = ('Processing: {:03d}/{:03d}, PSNR={:5.2f}, LPIPS={:6.4f}/{:6.4f}, CLIPIQA={:6.4f}, MUSIQ={:6.4f}'.format(
ii+1,
math.ceil(len(dataset) /args.bs),
current_psnr,
current_lpips_vgg,
current_lpips_alex,
current_clipiqa,
current_musiq,
))
logger.info(log_str)
metrics['PSNR'] += current_psnr * current_bs
metrics['SSIM'] += current_ssim * current_bs
metrics['LPIPS_VGG'] += current_lpips_vgg * current_bs
metrics['LPIPS_ALEX'] += current_lpips_alex * current_bs
metrics['CLIPIQA'] += current_clipiqa * current_bs
metrics['MUSIQ'] += current_musiq * current_bs
if args.niqe:
metrics['NIQE'] += current_niqe * current_bs
if args.dists:
metrics['DISTS'] += current_dists * current_bs
if args.maniqa:
metrics['MANIQA'] += current_maniqa * current_bs
if args.pi:
metrics['PI'] += current_pi * current_bs
for key in metrics.keys():
metrics[key] /= len(dataset)
if args.fid:
metrics['FID'] = fid_metric(args.sr_dir, args.gt_dir)
logger.info(f"MEAN PSNR: {metrics['PSNR']:5.2f}")
logger.info(f"MEAN SSIM: {metrics['SSIM']:6.4f}")
logger.info(f"MEAN LPIPS(VGG): {metrics['LPIPS_VGG']:6.4f}")
logger.info(f"MEAN LPIPS(ALEX): {metrics['LPIPS_ALEX']:6.4f}")
logger.info(f"MEAN CLIPIQA: {metrics['CLIPIQA']:6.4f}")
logger.info(f"MEAN MUSIQ: {metrics['MUSIQ']:6.4f}")
if args.fid:
logger.info(f"MEAN FID: {metrics['FID']:6.2f}")
if args.niqe:
logger.info(f"MEAN NIQE: {metrics['NIQE']:7.4f}")
if args.dists:
logger.info(f"MEAN DISTS: {metrics['DISTS']:6.4f}")
if args.maniqa:
logger.info(f"MEAN MANIQA: {metrics['MANIQA']:6.4f}")
if args.pi:
logger.info(f"MEAN PI: {metrics['PI']:7.4f}")

View File

@@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2024-04-07 20:57:36
import os
import torch
import random
import argparse
import numpy as np
from omegaconf import OmegaConf
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parents[1]))
from basicsr.data.realesrgan_dataset import RealESRGANDataset
from utils import util_image
from utils import util_common
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--indir",
type=str,
default="/mnt/lustre/share/zhangwenwei/data/imagenet/val",
help="Folder to save the checkpoints and training log",
)
parser.add_argument(
"-o",
"--outdir",
type=str,
default="./ImageNet-Test",
help="Folder to save the checkpoints and training log",
)
parser.add_argument(
"-r",
"--resolution",
type=int,
default=1024,
help="Image resolution of the ground truth",
)
parser.add_argument(
"--num_imgs",
type=int,
default=-1,
help="Number of images.",
)
args = parser.parse_args()
if Path(args.indir).is_dir():
img_list = sorted([x for x in Path(args.indir).glob('*.[JjPp][PpNn]*[Gg]')])
elif args.indir.endswith('txt'):
img_list = util_common.readline_txt(args.indir)
else:
raise ValueError('Please input valid args.indir!')
print(f'Number of images in the input folder: {len(img_list)}')
random.seed(10000)
random.shuffle(img_list)
num_imgs = args.num_imgs
if num_imgs > 0:
assert num_imgs <= len(img_list)
img_list = random.sample(img_list, k=num_imgs)
gt_dir = Path(args.outdir) / 'gt'
if not gt_dir.exists():
gt_dir.mkdir(parents=True)
lq_dir = Path(args.outdir) / 'lq'
if not lq_dir.exists():
lq_dir.mkdir(parents=True)
# Loading configuration
configs = OmegaConf.load('./configs/degradation_testing_realesrgan.yaml')
opts, opts_degradation = configs.opts, configs.degradation
opts['gt_size'] = args.resolution
opts_degradation['gt_size'] = args.resolution
dataset = RealESRGANDataset(opts, mode='testing')
dataset.image_paths = img_list
dataset.text_paths = [None, ] * len(img_list)
dataset.moment_paths = [None, ] * len(img_list)
for ii in range(len(img_list)):
data_dict1 = dataset.__getitem__(ii)
if (ii + 1) % 100 == 0:
print(f'Processing: {ii+1}/{len(img_list)}')
prefix = 'realesrgan'
data_dict2 = dataset.degrade_fun(
opts_degradation,
im_gt=data_dict1['gt'].unsqueeze(0),
kernel1=data_dict1['kernel1'],
kernel2=data_dict1['kernel2'],
sinc_kernel=data_dict1['sinc_kernel'],
)
im_lq, im_gt = data_dict2['lq'], data_dict2['gt']
im_lq, im_gt = util_image.tensor2img([im_lq, im_gt], rgb2bgr=True, min_max=(0,1) ) # uint8
im_name = Path(data_dict1['gt_path']).stem
im_path_gt = gt_dir / f'{im_name}.png'
util_image.imwrite(im_gt, im_path_gt, chn='bgr', dtype_in='uint8')
im_path_lq = lq_dir / f'{im_name}.png'
util_image.imwrite(im_lq, im_path_lq, chn='bgr', dtype_in='uint8')
if __name__ == "__main__":
main()