mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 14:24:27 +01:00
385 lines
17 KiB
Python
385 lines
17 KiB
Python
import cv2
|
|
import math
|
|
import numpy as np
|
|
import os
|
|
import os.path as osp
|
|
import random
|
|
import time
|
|
import torch
|
|
from pathlib import Path
|
|
|
|
import albumentations
|
|
|
|
import torch.nn.functional as F
|
|
from torch.utils import data as data
|
|
|
|
from basicsr.utils import DiffJPEG
|
|
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
|
from basicsr.data.transforms import augment
|
|
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
|
from basicsr.utils.registry import DATASET_REGISTRY
|
|
from basicsr.utils.img_process_util import filter2D
|
|
from basicsr.data.transforms import paired_random_crop, random_crop
|
|
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
|
|
|
from utils import util_image
|
|
|
|
def readline_txt(txt_file):
|
|
txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file
|
|
out = []
|
|
for txt_file_current in txt_file:
|
|
with open(txt_file_current, 'r') as ff:
|
|
out.extend([x[:-1] for x in ff.readlines()])
|
|
|
|
return out
|
|
|
|
@DATASET_REGISTRY.register(suffix='basicsr')
|
|
class RealESRGANDataset(data.Dataset):
|
|
"""Dataset used for Real-ESRGAN model:
|
|
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
|
|
|
It loads gt (Ground-Truth) images, and augments them.
|
|
It also generates blur kernels and sinc kernels for generating low-quality images.
|
|
Note that the low-quality images are processed in tensors on GPUS for faster processing.
|
|
|
|
Args:
|
|
opt (dict): Config for train datasets. It contains the following keys:
|
|
dataroot_gt (str): Data root path for gt.
|
|
meta_info (str): Path for meta information file.
|
|
io_backend (dict): IO backend type and other kwarg.
|
|
use_hflip (bool): Use horizontal flips.
|
|
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
|
|
Please see more options in the codes.
|
|
"""
|
|
|
|
def __init__(self, opt, mode='training'):
|
|
super(RealESRGANDataset, self).__init__()
|
|
self.opt = opt
|
|
self.file_client = None
|
|
self.io_backend_opt = opt['io_backend']
|
|
|
|
# file client (lmdb io backend)
|
|
self.image_paths = []
|
|
self.text_paths = []
|
|
self.moment_paths = []
|
|
if opt.get('data_source', None) is not None:
|
|
for ii in range(len(opt['data_source'])):
|
|
configs = opt['data_source'].get(f'source{ii+1}')
|
|
root_path = Path(configs.root_path)
|
|
im_folder = root_path / configs.image_path
|
|
im_ext = configs.im_ext
|
|
image_stems = sorted([x.stem for x in im_folder.glob(f"*.{im_ext}")])
|
|
if configs.get('length', None) is not None:
|
|
assert configs.length < len(image_stems)
|
|
image_stems = image_stems[:configs.length]
|
|
|
|
if configs.get("text_path", None) is not None:
|
|
text_folder = root_path / configs.text_path
|
|
text_stems = [x.stem for x in text_folder.glob("*.txt")]
|
|
image_stems = sorted(list(set(image_stems).intersection(set(text_stems))))
|
|
self.text_paths.extend([str(text_folder / f"{x}.txt") for x in image_stems])
|
|
else:
|
|
self.text_paths.extend([None, ] * len(image_stems))
|
|
|
|
self.image_paths.extend([str(im_folder / f"{x}.{im_ext}") for x in image_stems])
|
|
|
|
if configs.get("moment_path", None) is not None:
|
|
moment_folder = root_path / configs.moment_path
|
|
self.moment_paths.extend([str(moment_folder / f"{x}.npy") for x in image_stems])
|
|
else:
|
|
self.moment_paths.extend([None, ] * len(image_stems))
|
|
|
|
# blur settings for the first degradation
|
|
self.blur_kernel_size = opt['blur_kernel_size']
|
|
self.kernel_list = opt['kernel_list']
|
|
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
|
|
self.blur_sigma = opt['blur_sigma']
|
|
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
|
|
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
|
|
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
|
|
|
|
# blur settings for the second degradation
|
|
self.blur_kernel_size2 = opt['blur_kernel_size2']
|
|
self.kernel_list2 = opt['kernel_list2']
|
|
self.kernel_prob2 = opt['kernel_prob2']
|
|
self.blur_sigma2 = opt['blur_sigma2']
|
|
self.betag_range2 = opt['betag_range2']
|
|
self.betap_range2 = opt['betap_range2']
|
|
self.sinc_prob2 = opt['sinc_prob2']
|
|
|
|
# a final sinc filter
|
|
self.final_sinc_prob = opt['final_sinc_prob']
|
|
|
|
self.kernel_range1 = [x for x in range(3, opt['blur_kernel_size'], 2)] # kernel size ranges from 7 to 21
|
|
self.kernel_range2 = [x for x in range(3, opt['blur_kernel_size2'], 2)] # kernel size ranges from 7 to 21
|
|
# TODO: kernel range is now hard-coded, should be in the configure file
|
|
# convolving with pulse tensor brings no blurry effect
|
|
self.pulse_tensor = torch.zeros(opt['blur_kernel_size2'], opt['blur_kernel_size2']).float()
|
|
self.pulse_tensor[opt['blur_kernel_size2']//2, opt['blur_kernel_size2']//2] = 1
|
|
|
|
self.mode = mode
|
|
|
|
def __getitem__(self, index):
|
|
if self.file_client is None:
|
|
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
|
|
|
# -------------------------------- Load gt images -------------------------------- #
|
|
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
|
|
gt_path = self.image_paths[index]
|
|
# avoid errors caused by high latency in reading files
|
|
retry = 3
|
|
while retry > 0:
|
|
try:
|
|
img_bytes = self.file_client.get(gt_path, 'gt')
|
|
img_gt = imfrombytes(img_bytes, float32=True)
|
|
except:
|
|
index = random.randint(0, self.__len__())
|
|
gt_path = self.image_paths[index]
|
|
time.sleep(1) # sleep 1s for occasional server congestion
|
|
finally:
|
|
retry -= 1
|
|
if self.mode == 'testing':
|
|
if not hasattr(self, 'test_aug'):
|
|
self.test_aug = albumentations.Compose([
|
|
albumentations.SmallestMaxSize(
|
|
max_size=self.opt['gt_size'],
|
|
interpolation=cv2.INTER_AREA,
|
|
),
|
|
albumentations.CenterCrop(self.opt['gt_size'], self.opt['gt_size']),
|
|
])
|
|
img_gt = self.test_aug(image=img_gt)['image']
|
|
elif self.mode == 'training':
|
|
# -------------------- Do augmentation for training: flip, rotation -------------------- #
|
|
if self.opt['use_hflip'] or self.opt['use_rot']:
|
|
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
|
|
|
|
h, w = img_gt.shape[0:2]
|
|
gt_size = self.opt['gt_size']
|
|
|
|
# resize or pad
|
|
if not self.opt['random_crop']:
|
|
if not min(h, w) == gt_size:
|
|
if not hasattr(self, 'smallest_resizer'):
|
|
self.smallest_resizer = util_image.SmallestMaxSize(
|
|
max_size=gt_size, pass_resize=False,
|
|
)
|
|
img_gt = self.smallest_resizer(img_gt)
|
|
|
|
# center crop
|
|
if not hasattr(self, 'center_cropper'):
|
|
self.center_cropper = albumentations.CenterCrop(gt_size, gt_size)
|
|
img_gt = self.center_cropper(image=img_gt)['image']
|
|
else:
|
|
img_gt = random_crop(img_gt, self.opt['gt_size'])
|
|
else:
|
|
raise ValueError(f'Unexpected value {self.mode} for mode parameter')
|
|
|
|
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
|
|
kernel_size = random.choice(self.kernel_range1)
|
|
if np.random.uniform() < self.opt['sinc_prob']:
|
|
# this sinc filter setting is for kernels ranging from [7, 21]
|
|
if kernel_size < 13:
|
|
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
|
else:
|
|
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
|
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
|
else:
|
|
kernel = random_mixed_kernels(
|
|
self.kernel_list,
|
|
self.kernel_prob,
|
|
kernel_size,
|
|
self.blur_sigma,
|
|
self.blur_sigma, [-math.pi, math.pi],
|
|
self.betag_range,
|
|
self.betap_range,
|
|
noise_range=None)
|
|
# pad kernel
|
|
pad_size = (self.blur_kernel_size - kernel_size) // 2
|
|
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
|
|
|
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
|
|
kernel_size = random.choice(self.kernel_range2)
|
|
if np.random.uniform() < self.opt['sinc_prob2']:
|
|
if kernel_size < 13:
|
|
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
|
else:
|
|
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
|
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
|
else:
|
|
kernel2 = random_mixed_kernels(
|
|
self.kernel_list2,
|
|
self.kernel_prob2,
|
|
kernel_size,
|
|
self.blur_sigma2,
|
|
self.blur_sigma2, [-math.pi, math.pi],
|
|
self.betag_range2,
|
|
self.betap_range2,
|
|
noise_range=None)
|
|
|
|
# pad kernel
|
|
pad_size = (self.blur_kernel_size2 - kernel_size) // 2
|
|
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
|
|
|
# ------------------------------------- the final sinc kernel ------------------------------------- #
|
|
if np.random.uniform() < self.opt['final_sinc_prob']:
|
|
kernel_size = random.choice(self.kernel_range2)
|
|
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
|
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=self.blur_kernel_size2)
|
|
sinc_kernel = torch.FloatTensor(sinc_kernel)
|
|
else:
|
|
sinc_kernel = self.pulse_tensor
|
|
|
|
# BGR to RGB, HWC to CHW, numpy to tensor
|
|
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
|
|
kernel = torch.FloatTensor(kernel)
|
|
kernel2 = torch.FloatTensor(kernel2)
|
|
|
|
if self.text_paths[index] is None or self.opt['random_crop']:
|
|
prompt = ""
|
|
else:
|
|
with open(self.text_paths[index], 'r') as ff:
|
|
prompt = ff.read()
|
|
if self.opt.max_token_length is not None:
|
|
prompt = prompt[:self.opt.max_token_length]
|
|
|
|
return_d = {
|
|
'gt': img_gt,
|
|
'gt_path': gt_path,
|
|
'txt': prompt,
|
|
'kernel1': kernel,
|
|
'kernel2': kernel2,
|
|
'sinc_kernel': sinc_kernel,
|
|
}
|
|
if self.moment_paths[index] is not None and (not self.opt['random_crop']):
|
|
return_d['gt_moment'] = np.load(self.moment_paths[index])
|
|
|
|
return return_d
|
|
|
|
def __len__(self):
|
|
return len(self.image_paths)
|
|
|
|
def degrade_fun(self, conf_degradation, im_gt, kernel1, kernel2, sinc_kernel):
|
|
if not hasattr(self, 'jpeger'):
|
|
self.jpeger = DiffJPEG(differentiable=False) # simulate JPEG compression artifacts
|
|
|
|
ori_h, ori_w = im_gt.size()[2:4]
|
|
sf = conf_degradation.sf
|
|
|
|
# ----------------------- The first degradation process ----------------------- #
|
|
# blur
|
|
out = filter2D(im_gt, kernel1)
|
|
# random resize
|
|
updown_type = random.choices(
|
|
['up', 'down', 'keep'],
|
|
conf_degradation['resize_prob'],
|
|
)[0]
|
|
if updown_type == 'up':
|
|
scale = random.uniform(1, conf_degradation['resize_range'][1])
|
|
elif updown_type == 'down':
|
|
scale = random.uniform(conf_degradation['resize_range'][0], 1)
|
|
else:
|
|
scale = 1
|
|
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
|
# add noise
|
|
gray_noise_prob = conf_degradation['gray_noise_prob']
|
|
if random.random() < conf_degradation['gaussian_noise_prob']:
|
|
out = random_add_gaussian_noise_pt(
|
|
out,
|
|
sigma_range=conf_degradation['noise_range'],
|
|
clip=True,
|
|
rounds=False,
|
|
gray_prob=gray_noise_prob,
|
|
)
|
|
else:
|
|
out = random_add_poisson_noise_pt(
|
|
out,
|
|
scale_range=conf_degradation['poisson_scale_range'],
|
|
gray_prob=gray_noise_prob,
|
|
clip=True,
|
|
rounds=False)
|
|
# JPEG compression
|
|
jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range'])
|
|
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
|
|
out = self.jpeger(out, quality=jpeg_p)
|
|
|
|
# ----------------------- The second degradation process ----------------------- #
|
|
# blur
|
|
if random.random() < conf_degradation['second_order_prob']:
|
|
if random.random() < conf_degradation['second_blur_prob']:
|
|
out = filter2D(out, kernel2)
|
|
# random resize
|
|
updown_type = random.choices(
|
|
['up', 'down', 'keep'],
|
|
conf_degradation['resize_prob2'],
|
|
)[0]
|
|
if updown_type == 'up':
|
|
scale = random.uniform(1, conf_degradation['resize_range2'][1])
|
|
elif updown_type == 'down':
|
|
scale = random.uniform(conf_degradation['resize_range2'][0], 1)
|
|
else:
|
|
scale = 1
|
|
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
out = F.interpolate(
|
|
out,
|
|
size=(int(ori_h / sf * scale), int(ori_w / sf * scale)),
|
|
mode=mode,
|
|
)
|
|
# add noise
|
|
gray_noise_prob = conf_degradation['gray_noise_prob2']
|
|
if random.random() < conf_degradation['gaussian_noise_prob2']:
|
|
out = random_add_gaussian_noise_pt(
|
|
out,
|
|
sigma_range=conf_degradation['noise_range2'],
|
|
clip=True,
|
|
rounds=False,
|
|
gray_prob=gray_noise_prob,
|
|
)
|
|
else:
|
|
out = random_add_poisson_noise_pt(
|
|
out,
|
|
scale_range=conf_degradation['poisson_scale_range2'],
|
|
gray_prob=gray_noise_prob,
|
|
clip=True,
|
|
rounds=False,
|
|
)
|
|
|
|
# JPEG compression + the final sinc filter
|
|
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
|
# as one operation.
|
|
# We consider two orders:
|
|
# 1. [resize back + sinc filter] + JPEG compression
|
|
# 2. JPEG compression + [resize back + sinc filter]
|
|
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
|
if random.random() < 0.5:
|
|
# resize back + the final sinc filter
|
|
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
out = F.interpolate(
|
|
out,
|
|
size=(ori_h // sf, ori_w // sf),
|
|
mode=mode,
|
|
)
|
|
out = filter2D(out, sinc_kernel)
|
|
# JPEG compression
|
|
jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range2'])
|
|
out = torch.clamp(out, 0, 1)
|
|
out = self.jpeger(out, quality=jpeg_p)
|
|
else:
|
|
# JPEG compression
|
|
jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range2'])
|
|
out = torch.clamp(out, 0, 1)
|
|
out = self.jpeger(out, quality=jpeg_p)
|
|
# resize back + the final sinc filter
|
|
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
out = F.interpolate(
|
|
out,
|
|
size=(ori_h // sf, ori_w // sf),
|
|
mode=mode,
|
|
)
|
|
out = filter2D(out, sinc_kernel)
|
|
|
|
# clamp and round
|
|
im_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
|
|
|
return {'lq':im_lq.contiguous(), 'gt':im_gt}
|