first commit

This commit is contained in:
zsyOAOA
2024-12-11 18:46:36 +08:00
parent 9e65255d34
commit 27f2eb7dc3
847 changed files with 377076 additions and 2 deletions

BIN
basicsr/.DS_Store vendored Normal file

Binary file not shown.

4
basicsr/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .data import *
from .utils import *

BIN
basicsr/data/.DS_Store vendored Normal file

Binary file not shown.

101
basicsr/data/__init__.py Normal file
View File

@@ -0,0 +1,101 @@
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY
__all__ = ['build_dataset', 'build_dataloader']
# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
def build_dataset(dataset_opt):
"""Build dataset from options.
Args:
dataset_opt (dict): Configuration for dataset. It must contain:
name (str): Dataset name.
type (str): Dataset type.
"""
dataset_opt = deepcopy(dataset_opt)
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
logger = get_root_logger()
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
return dataset
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
"""Build dataloader.
Args:
dataset (torch.utils.data.Dataset): Dataset.
dataset_opt (dict): Dataset options. It contains the following keys:
phase (str): 'train' or 'val'.
num_worker_per_gpu (int): Number of workers for each GPU.
batch_size_per_gpu (int): Training batch size for each GPU.
num_gpu (int): Number of GPUs. Used only in the train phase.
Default: 1.
dist (bool): Whether in distributed training. Used only in the train
phase. Default: False.
sampler (torch.utils.data.sampler): Data sampler. Default: None.
seed (int | None): Seed. Default: None
"""
phase = dataset_opt['phase']
rank, _ = get_dist_info()
if phase == 'train':
if dist: # distributed training
batch_size = dataset_opt['batch_size_per_gpu']
num_workers = dataset_opt['num_worker_per_gpu']
else: # non-distributed training
multiplier = 1 if num_gpu == 0 else num_gpu
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
dataloader_args = dict(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
drop_last=True)
if sampler is None:
dataloader_args['shuffle'] = True
dataloader_args['worker_init_fn'] = partial(
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
elif phase in ['val', 'test']: # validation
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
else:
raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
prefetch_mode = dataset_opt.get('prefetch_mode')
if prefetch_mode == 'cpu': # CPUPrefetcher
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
logger = get_root_logger()
logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
else:
# prefetch_mode=None: Normal dataloader
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
return torch.utils.data.DataLoader(**dataloader_args)
def worker_init_fn(worker_id, num_workers, rank, seed):
# Set the worker seed to num_workers * rank + worker_id + seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,48 @@
import math
import torch
from torch.utils.data.sampler import Sampler
class EnlargedSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
Modified from torch.utils.data.distributed.DistributedSampler
Support enlarging the dataset for iteration-based training, for saving
time when restart the dataloader after each epoch
Args:
dataset (torch.utils.data.Dataset): Dataset used for sampling.
num_replicas (int | None): Number of processes participating in
the training. It is usually the world_size.
rank (int | None): Rank of the current process within num_replicas.
ratio (int): Enlarging ratio. Default: 1.
"""
def __init__(self, dataset, num_replicas, rank, ratio=1):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(self.total_size, generator=g).tolist()
dataset_size = len(self.dataset)
indices = [v % dataset_size for v in indices]
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch

315
basicsr/data/data_util.py Normal file
View File

@@ -0,0 +1,315 @@
import cv2
import numpy as np
import torch
from os import path as osp
from torch.nn import functional as F
from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
"""Read a sequence of images from a given folder path.
Args:
path (list[str] | str): List of image paths or image folder path.
require_mod_crop (bool): Require mod crop for each image.
Default: False.
scale (int): Scale factor for mod_crop. Default: 1.
return_imgname(bool): Whether return image names. Default False.
Returns:
Tensor: size (t, c, h, w), RGB, [0, 1].
list[str]: Returned image name list.
"""
if isinstance(path, list):
img_paths = path
else:
img_paths = sorted(list(scandir(path, full_path=True)))
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
if require_mod_crop:
imgs = [mod_crop(img, scale) for img in imgs]
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
imgs = torch.stack(imgs, dim=0)
if return_imgname:
imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
return imgs, imgnames
else:
return imgs
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
"""Generate an index list for reading `num_frames` frames from a sequence
of images.
Args:
crt_idx (int): Current center index.
max_frame_num (int): Max number of the sequence of images (from 1).
num_frames (int): Reading num_frames frames.
padding (str): Padding mode, one of
'replicate' | 'reflection' | 'reflection_circle' | 'circle'
Examples: current_idx = 0, num_frames = 5
The generated frame indices under different padding mode:
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
reflection_circle: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
list[int]: A list of indices.
"""
assert num_frames % 2 == 1, 'num_frames should be an odd number.'
assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
max_frame_num = max_frame_num - 1 # start from 0
num_pad = num_frames // 2
indices = []
for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
if i < 0:
if padding == 'replicate':
pad_idx = 0
elif padding == 'reflection':
pad_idx = -i
elif padding == 'reflection_circle':
pad_idx = crt_idx + num_pad - i
else:
pad_idx = num_frames + i
elif i > max_frame_num:
if padding == 'replicate':
pad_idx = max_frame_num
elif padding == 'reflection':
pad_idx = max_frame_num * 2 - i
elif padding == 'reflection_circle':
pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
else:
pad_idx = i - num_frames
else:
pad_idx = i
indices.append(pad_idx)
return indices
def paired_paths_from_lmdb(folders, keys):
"""Generate paired paths from lmdb files.
Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
::
lq.lmdb
├── data.mdb
├── lock.mdb
├── meta_info.txt
The data.mdb and lock.mdb are standard lmdb files and you can refer to
https://lmdb.readthedocs.io/en/release/ for more details.
The meta_info.txt is a specified txt file to record the meta information
of our datasets. It will be automatically created when preparing
datasets by our provided dataset tools.
Each line in the txt file records
1)image name (with extension),
2)image shape,
3)compression level, separated by a white space.
Example: `baboon.png (120,125,3) 1`
We use the image name without extension as the lmdb key.
Note that we use the same key for the corresponding lq and gt images.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
Note that this key is different from lmdb keys.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
input_folder, gt_folder = folders
input_key, gt_key = keys
if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
f'formats. But received {input_key}: {input_folder}; '
f'{gt_key}: {gt_folder}')
# ensure that the two meta_info files are the same
with open(osp.join(input_folder, 'meta_info.txt')) as fin:
input_lmdb_keys = [line.split('.')[0] for line in fin]
with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
gt_lmdb_keys = [line.split('.')[0] for line in fin]
if set(input_lmdb_keys) != set(gt_lmdb_keys):
raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
else:
paths = []
for lmdb_key in sorted(input_lmdb_keys):
paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
return paths
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
"""Generate paired paths from an meta information file.
Each line in the meta information file contains the image names and
image shape (usually for gt), separated by a white space.
Example of an meta information file:
```
0001_s001.png (480,480,3)
0001_s002.png (480,480,3)
```
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
meta_info_file (str): Path to the meta information file.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
input_folder, gt_folder = folders
input_key, gt_key = keys
with open(meta_info_file, 'r') as fin:
gt_names = [line.strip().split(' ')[0] for line in fin]
paths = []
for gt_name in gt_names:
basename, ext = osp.splitext(osp.basename(gt_name))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = osp.join(input_folder, input_name)
gt_path = osp.join(gt_folder, gt_name)
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
return paths
def paired_paths_from_folder(folders, keys, filename_tmpl):
"""Generate paired paths from folders.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
input_folder, gt_folder = folders
input_key, gt_key = keys
input_paths = list(scandir(input_folder))
gt_paths = list(scandir(gt_folder))
assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
f'{len(input_paths)}, {len(gt_paths)}.')
paths = []
for gt_path in gt_paths:
basename, ext = osp.splitext(osp.basename(gt_path))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = osp.join(input_folder, input_name)
assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
gt_path = osp.join(gt_folder, gt_path)
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
return paths
def paths_from_folder(folder):
"""Generate paths from folder.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""
paths = list(scandir(folder))
paths = [osp.join(folder, path) for path in paths]
return paths
def paths_from_lmdb(folder):
"""Generate paths from lmdb.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""
if not folder.endswith('.lmdb'):
raise ValueError(f'Folder {folder}folder should in lmdb format.')
with open(osp.join(folder, 'meta_info.txt')) as fin:
paths = [line.split('.')[0] for line in fin]
return paths
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
"""Generate Gaussian kernel used in `duf_downsample`.
Args:
kernel_size (int): Kernel size. Default: 13.
sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
Returns:
np.array: The Gaussian kernel.
"""
from scipy.ndimage import filters as filters
kernel = np.zeros((kernel_size, kernel_size))
# set element at the middle to one, a dirac delta
kernel[kernel_size // 2, kernel_size // 2] = 1
# gaussian-smooth the dirac, resulting in a gaussian filter
return filters.gaussian_filter(kernel, sigma)
def duf_downsample(x, kernel_size=13, scale=4):
"""Downsamping with Gaussian kernel used in the DUF official code.
Args:
x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
kernel_size (int): Kernel size. Default: 13.
scale (int): Downsampling factor. Supported scale: (2, 3, 4).
Default: 4.
Returns:
Tensor: DUF downsampled frames.
"""
assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
squeeze_flag = False
if x.ndim == 4:
squeeze_flag = True
x = x.unsqueeze(0)
b, t, c, h, w = x.size()
x = x.view(-1, 1, h, w)
pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
x = F.conv2d(x, gaussian_filter, stride=scale)
x = x[:, :, 2:-2, 2:-2]
x = x.view(b, t, c, x.size(2), x.size(3))
if squeeze_flag:
x = x.squeeze(0)
return x

View File

@@ -0,0 +1,765 @@
import cv2
import math
import numpy as np
import random
import torch
from scipy import special
from scipy.stats import multivariate_normal
# from torchvision.transforms.functional_tensor import rgb_to_grayscale
from torchvision.transforms.functional import rgb_to_grayscale
# -------------------------------------------------------------------- #
# --------------------------- blur kernels --------------------------- #
# -------------------------------------------------------------------- #
# --------------------------- util functions --------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
"""Calculate the rotated sigma matrix (two dimensional matrix).
Args:
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
Returns:
ndarray: Rotated sigma matrix.
"""
d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
def mesh_grid(kernel_size):
"""Generate the mesh grid, centering at zero.
Args:
kernel_size (int):
Returns:
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
xx (ndarray): with the shape (kernel_size, kernel_size)
yy (ndarray): with the shape (kernel_size, kernel_size)
"""
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
xx, yy = np.meshgrid(ax, ax)
xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
1))).reshape(kernel_size, kernel_size, 2)
return xy, xx, yy
def pdf2(sigma_matrix, grid):
"""Calculate PDF of the bivariate Gaussian distribution.
Args:
sigma_matrix (ndarray): with the shape (2, 2)
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
kernel (ndarrray): un-normalized kernel.
"""
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
return kernel
def cdf2(d_matrix, grid):
"""Calculate the CDF of the standard bivariate Gaussian distribution.
Used in skewed Gaussian distribution.
Args:
d_matrix (ndarrasy): skew matrix.
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
cdf (ndarray): skewed cdf.
"""
rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
grid = np.dot(grid, d_matrix)
cdf = rv.cdf(grid)
return cdf
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
isotropic (bool):
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
kernel = pdf2(sigma_matrix, grid)
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
"""Generate a bivariate generalized Gaussian kernel.
``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
"""Generate a plateau-like anisotropic kernel.
1 / (1+x^(beta))
Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_Gaussian(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi, math.pi]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_generalized_Gaussian(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate generalized Gaussian kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi, math.pi]
beta_range (tuple): [0.5, 8]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
# assume beta_range[0] < 1 < beta_range[1]
if np.random.uniform() < 0.5:
beta = np.random.uniform(beta_range[0], 1)
else:
beta = np.random.uniform(1, beta_range[1])
kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_plateau(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate plateau kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi/2, math.pi/2]
beta_range (tuple): [1, 4]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
# TODO: this may be not proper
if np.random.uniform() < 0.5:
beta = np.random.uniform(beta_range[0], 1)
else:
beta = np.random.uniform(1, beta_range[1])
kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_mixed_kernels(kernel_list,
kernel_prob,
kernel_size=21,
sigma_x_range=(0.6, 5),
sigma_y_range=(0.6, 5),
rotation_range=(-math.pi, math.pi),
betag_range=(0.5, 8),
betap_range=(0.5, 8),
noise_range=None):
"""Randomly generate mixed kernels.
Args:
kernel_list (tuple): a list name of kernel types,
support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
'plateau_aniso']
kernel_prob (tuple): corresponding kernel probability for each
kernel type
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi, math.pi]
beta_range (tuple): [0.5, 8]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
kernel_type = random.choices(kernel_list, kernel_prob)[0]
if kernel_type == 'iso':
kernel = random_bivariate_Gaussian(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
elif kernel_type == 'aniso':
kernel = random_bivariate_Gaussian(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
elif kernel_type == 'generalized_iso':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
betag_range,
noise_range=noise_range,
isotropic=True)
elif kernel_type == 'generalized_aniso':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
betag_range,
noise_range=noise_range,
isotropic=False)
elif kernel_type == 'plateau_iso':
kernel = random_bivariate_plateau(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
elif kernel_type == 'plateau_aniso':
kernel = random_bivariate_plateau(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
return kernel
np.seterr(divide='ignore', invalid='ignore')
def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
"""2D sinc filter
Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
Args:
cutoff (float): cutoff frequency in radians (pi is max)
kernel_size (int): horizontal and vertical size, must be odd.
pad_to (int): pad kernel size to desired size, must be odd or zero.
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
kernel = np.fromfunction(
lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
kernel = kernel / np.sum(kernel)
if pad_to > kernel_size:
pad_size = (pad_to - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
return kernel
# ------------------------------------------------------------- #
# --------------------------- noise --------------------------- #
# ------------------------------------------------------------- #
# ----------------------- Gaussian Noise ----------------------- #
def generate_gaussian_noise(img, sigma=10, gray_noise=False):
"""Generate Gaussian noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
sigma (float): Noise scale (measured in range 255). Default: 10.
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
if gray_noise:
noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
else:
noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
return noise
def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
"""Add Gaussian noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
sigma (float): Noise scale (measured in range 255). Default: 10.
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
noise = generate_gaussian_noise(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
"""Add Gaussian noise (PyTorch version).
Args:
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
scale (float | Tensor): Noise scale. Default: 1.0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
b, _, h, w = img.size()
if not isinstance(sigma, (float, int)):
sigma = sigma.view(img.size(0), 1, 1, 1)
if isinstance(gray_noise, (float, int)):
cal_gray_noise = gray_noise > 0
else:
gray_noise = gray_noise.view(b, 1, 1, 1)
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
noise_gray = noise_gray.view(b, 1, h, w)
# always calculate color noise
noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
return noise
def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
"""Add Gaussian noise (PyTorch version).
Args:
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
scale (float | Tensor): Noise scale. Default: 1.0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ----------------------- Random Gaussian Noise ----------------------- #
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
if np.random.uniform() < gray_prob:
gray_noise = True
else:
gray_noise = False
return generate_gaussian_noise(img, sigma, gray_noise)
def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
sigma = torch.rand(
img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_gaussian_noise_pt(img, sigma, gray_noise)
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ----------------------- Poisson (Shot) Noise ----------------------- #
def generate_poisson_noise(img, scale=1.0, gray_noise=False):
"""Generate poisson noise.
Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
scale (float): Noise scale. Default: 1.0.
gray_noise (bool): Whether generate gray noise. Default: False.
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
if gray_noise:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# round and clip image for counting vals correctly
img = np.clip((img * 255.0).round(), 0, 255) / 255.
vals = len(np.unique(img))
vals = 2**np.ceil(np.log2(vals))
out = np.float32(np.random.poisson(img * vals) / float(vals))
noise = out - img
if gray_noise:
noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
return noise * scale
def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
"""Add poisson noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
scale (float): Noise scale. Default: 1.0.
gray_noise (bool): Whether generate gray noise. Default: False.
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
noise = generate_poisson_noise(img, scale, gray_noise)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
"""Generate a batch of poisson noise (PyTorch version)
Args:
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
Default: 1.0.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
b, _, h, w = img.size()
if isinstance(gray_noise, (float, int)):
cal_gray_noise = gray_noise > 0
else:
gray_noise = gray_noise.view(b, 1, 1, 1)
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
img_gray = rgb_to_grayscale(img, num_output_channels=1)
# round and clip image for counting vals correctly
img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img_gray * vals) / vals
noise_gray = out - img_gray
noise_gray = noise_gray.expand(b, 3, h, w)
# always calculate color noise
# round and clip image for counting vals correctly
img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img * vals) / vals
noise = out - img
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
if not isinstance(scale, (float, int)):
scale = scale.view(b, 1, 1, 1)
return noise * scale
def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
"""Add poisson noise to a batch of images (PyTorch version).
Args:
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
Default: 1.0.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
noise = generate_poisson_noise_pt(img, scale, gray_noise)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ----------------------- Random Poisson (Shot) Noise ----------------------- #
def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
scale = np.random.uniform(scale_range[0], scale_range[1])
if np.random.uniform() < gray_prob:
gray_noise = True
else:
gray_noise = False
return generate_poisson_noise(img, scale, gray_noise)
def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_poisson_noise(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
scale = torch.rand(
img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_poisson_noise_pt(img, scale, gray_noise)
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ------------------------------------------------------------------------ #
# --------------------------- JPEG compression --------------------------- #
# ------------------------------------------------------------------------ #
def add_jpg_compression(img, quality=90):
"""Add JPG compression artifacts.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
quality (float): JPG compression quality. 0 for lowest quality, 100 for
best quality. Default: 90.
Returns:
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
float32.
"""
img = np.clip(img, 0, 1)
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
_, encimg = cv2.imencode('.jpg', img * 255., encode_param)
img = np.float32(cv2.imdecode(encimg, 1)) / 255.
return img
def random_add_jpg_compression(img, quality_range=(90, 100)):
"""Randomly add JPG compression artifacts.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
quality_range (tuple[float] | list[float]): JPG compression quality
range. 0 for lowest quality, 100 for best quality.
Default: (90, 100).
Returns:
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
float32.
"""
quality = np.random.uniform(quality_range[0], quality_range[1])
return add_jpg_compression(img, quality)

View File

@@ -0,0 +1,80 @@
import random
import time
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):
"""FFHQ dataset for StyleGAN.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
io_backend (dict): IO backend type and other kwarg.
mean (list | tuple): Image mean.
std (list | tuple): Image std.
use_hflip (bool): Whether to horizontally flip.
"""
def __init__(self, opt):
super(FFHQDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt']
self.mean = opt['mean']
self.std = opt['std']
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = self.gt_folder
if not self.gt_folder.endswith('.lmdb'):
raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
# FFHQ has 70000 images in total
self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load gt image
gt_path = self.paths[index]
# avoid errors caused by high latency in reading files
retry = 3
while retry > 0:
try:
img_bytes = self.file_client.get(gt_path)
except Exception as e:
logger = get_root_logger()
logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
# change another file to read
index = random.randint(0, self.__len__())
gt_path = self.paths[index]
time.sleep(1) # sleep 1s for occasional server congestion
else:
break
finally:
retry -= 1
img_gt = imfrombytes(img_bytes, float32=True)
# random horizontal flip
img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
# normalize
normalize(img_gt, self.mean, self.std, inplace=True)
return {'gt': img_gt, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,4 @@
000 100 (720,1280,3)
011 100 (720,1280,3)
015 100 (720,1280,3)
020 100 (720,1280,3)

View File

@@ -0,0 +1,270 @@
000 100 (720,1280,3)
001 100 (720,1280,3)
002 100 (720,1280,3)
003 100 (720,1280,3)
004 100 (720,1280,3)
005 100 (720,1280,3)
006 100 (720,1280,3)
007 100 (720,1280,3)
008 100 (720,1280,3)
009 100 (720,1280,3)
010 100 (720,1280,3)
011 100 (720,1280,3)
012 100 (720,1280,3)
013 100 (720,1280,3)
014 100 (720,1280,3)
015 100 (720,1280,3)
016 100 (720,1280,3)
017 100 (720,1280,3)
018 100 (720,1280,3)
019 100 (720,1280,3)
020 100 (720,1280,3)
021 100 (720,1280,3)
022 100 (720,1280,3)
023 100 (720,1280,3)
024 100 (720,1280,3)
025 100 (720,1280,3)
026 100 (720,1280,3)
027 100 (720,1280,3)
028 100 (720,1280,3)
029 100 (720,1280,3)
030 100 (720,1280,3)
031 100 (720,1280,3)
032 100 (720,1280,3)
033 100 (720,1280,3)
034 100 (720,1280,3)
035 100 (720,1280,3)
036 100 (720,1280,3)
037 100 (720,1280,3)
038 100 (720,1280,3)
039 100 (720,1280,3)
040 100 (720,1280,3)
041 100 (720,1280,3)
042 100 (720,1280,3)
043 100 (720,1280,3)
044 100 (720,1280,3)
045 100 (720,1280,3)
046 100 (720,1280,3)
047 100 (720,1280,3)
048 100 (720,1280,3)
049 100 (720,1280,3)
050 100 (720,1280,3)
051 100 (720,1280,3)
052 100 (720,1280,3)
053 100 (720,1280,3)
054 100 (720,1280,3)
055 100 (720,1280,3)
056 100 (720,1280,3)
057 100 (720,1280,3)
058 100 (720,1280,3)
059 100 (720,1280,3)
060 100 (720,1280,3)
061 100 (720,1280,3)
062 100 (720,1280,3)
063 100 (720,1280,3)
064 100 (720,1280,3)
065 100 (720,1280,3)
066 100 (720,1280,3)
067 100 (720,1280,3)
068 100 (720,1280,3)
069 100 (720,1280,3)
070 100 (720,1280,3)
071 100 (720,1280,3)
072 100 (720,1280,3)
073 100 (720,1280,3)
074 100 (720,1280,3)
075 100 (720,1280,3)
076 100 (720,1280,3)
077 100 (720,1280,3)
078 100 (720,1280,3)
079 100 (720,1280,3)
080 100 (720,1280,3)
081 100 (720,1280,3)
082 100 (720,1280,3)
083 100 (720,1280,3)
084 100 (720,1280,3)
085 100 (720,1280,3)
086 100 (720,1280,3)
087 100 (720,1280,3)
088 100 (720,1280,3)
089 100 (720,1280,3)
090 100 (720,1280,3)
091 100 (720,1280,3)
092 100 (720,1280,3)
093 100 (720,1280,3)
094 100 (720,1280,3)
095 100 (720,1280,3)
096 100 (720,1280,3)
097 100 (720,1280,3)
098 100 (720,1280,3)
099 100 (720,1280,3)
100 100 (720,1280,3)
101 100 (720,1280,3)
102 100 (720,1280,3)
103 100 (720,1280,3)
104 100 (720,1280,3)
105 100 (720,1280,3)
106 100 (720,1280,3)
107 100 (720,1280,3)
108 100 (720,1280,3)
109 100 (720,1280,3)
110 100 (720,1280,3)
111 100 (720,1280,3)
112 100 (720,1280,3)
113 100 (720,1280,3)
114 100 (720,1280,3)
115 100 (720,1280,3)
116 100 (720,1280,3)
117 100 (720,1280,3)
118 100 (720,1280,3)
119 100 (720,1280,3)
120 100 (720,1280,3)
121 100 (720,1280,3)
122 100 (720,1280,3)
123 100 (720,1280,3)
124 100 (720,1280,3)
125 100 (720,1280,3)
126 100 (720,1280,3)
127 100 (720,1280,3)
128 100 (720,1280,3)
129 100 (720,1280,3)
130 100 (720,1280,3)
131 100 (720,1280,3)
132 100 (720,1280,3)
133 100 (720,1280,3)
134 100 (720,1280,3)
135 100 (720,1280,3)
136 100 (720,1280,3)
137 100 (720,1280,3)
138 100 (720,1280,3)
139 100 (720,1280,3)
140 100 (720,1280,3)
141 100 (720,1280,3)
142 100 (720,1280,3)
143 100 (720,1280,3)
144 100 (720,1280,3)
145 100 (720,1280,3)
146 100 (720,1280,3)
147 100 (720,1280,3)
148 100 (720,1280,3)
149 100 (720,1280,3)
150 100 (720,1280,3)
151 100 (720,1280,3)
152 100 (720,1280,3)
153 100 (720,1280,3)
154 100 (720,1280,3)
155 100 (720,1280,3)
156 100 (720,1280,3)
157 100 (720,1280,3)
158 100 (720,1280,3)
159 100 (720,1280,3)
160 100 (720,1280,3)
161 100 (720,1280,3)
162 100 (720,1280,3)
163 100 (720,1280,3)
164 100 (720,1280,3)
165 100 (720,1280,3)
166 100 (720,1280,3)
167 100 (720,1280,3)
168 100 (720,1280,3)
169 100 (720,1280,3)
170 100 (720,1280,3)
171 100 (720,1280,3)
172 100 (720,1280,3)
173 100 (720,1280,3)
174 100 (720,1280,3)
175 100 (720,1280,3)
176 100 (720,1280,3)
177 100 (720,1280,3)
178 100 (720,1280,3)
179 100 (720,1280,3)
180 100 (720,1280,3)
181 100 (720,1280,3)
182 100 (720,1280,3)
183 100 (720,1280,3)
184 100 (720,1280,3)
185 100 (720,1280,3)
186 100 (720,1280,3)
187 100 (720,1280,3)
188 100 (720,1280,3)
189 100 (720,1280,3)
190 100 (720,1280,3)
191 100 (720,1280,3)
192 100 (720,1280,3)
193 100 (720,1280,3)
194 100 (720,1280,3)
195 100 (720,1280,3)
196 100 (720,1280,3)
197 100 (720,1280,3)
198 100 (720,1280,3)
199 100 (720,1280,3)
200 100 (720,1280,3)
201 100 (720,1280,3)
202 100 (720,1280,3)
203 100 (720,1280,3)
204 100 (720,1280,3)
205 100 (720,1280,3)
206 100 (720,1280,3)
207 100 (720,1280,3)
208 100 (720,1280,3)
209 100 (720,1280,3)
210 100 (720,1280,3)
211 100 (720,1280,3)
212 100 (720,1280,3)
213 100 (720,1280,3)
214 100 (720,1280,3)
215 100 (720,1280,3)
216 100 (720,1280,3)
217 100 (720,1280,3)
218 100 (720,1280,3)
219 100 (720,1280,3)
220 100 (720,1280,3)
221 100 (720,1280,3)
222 100 (720,1280,3)
223 100 (720,1280,3)
224 100 (720,1280,3)
225 100 (720,1280,3)
226 100 (720,1280,3)
227 100 (720,1280,3)
228 100 (720,1280,3)
229 100 (720,1280,3)
230 100 (720,1280,3)
231 100 (720,1280,3)
232 100 (720,1280,3)
233 100 (720,1280,3)
234 100 (720,1280,3)
235 100 (720,1280,3)
236 100 (720,1280,3)
237 100 (720,1280,3)
238 100 (720,1280,3)
239 100 (720,1280,3)
240 100 (720,1280,3)
241 100 (720,1280,3)
242 100 (720,1280,3)
243 100 (720,1280,3)
244 100 (720,1280,3)
245 100 (720,1280,3)
246 100 (720,1280,3)
247 100 (720,1280,3)
248 100 (720,1280,3)
249 100 (720,1280,3)
250 100 (720,1280,3)
251 100 (720,1280,3)
252 100 (720,1280,3)
253 100 (720,1280,3)
254 100 (720,1280,3)
255 100 (720,1280,3)
256 100 (720,1280,3)
257 100 (720,1280,3)
258 100 (720,1280,3)
259 100 (720,1280,3)
260 100 (720,1280,3)
261 100 (720,1280,3)
262 100 (720,1280,3)
263 100 (720,1280,3)
264 100 (720,1280,3)
265 100 (720,1280,3)
266 100 (720,1280,3)
267 100 (720,1280,3)
268 100 (720,1280,3)
269 100 (720,1280,3)

View File

@@ -0,0 +1,4 @@
240 100 (720,1280,3)
241 100 (720,1280,3)
246 100 (720,1280,3)
257 100 (720,1280,3)

View File

@@ -0,0 +1,30 @@
240 100 (720,1280,3)
241 100 (720,1280,3)
242 100 (720,1280,3)
243 100 (720,1280,3)
244 100 (720,1280,3)
245 100 (720,1280,3)
246 100 (720,1280,3)
247 100 (720,1280,3)
248 100 (720,1280,3)
249 100 (720,1280,3)
250 100 (720,1280,3)
251 100 (720,1280,3)
252 100 (720,1280,3)
253 100 (720,1280,3)
254 100 (720,1280,3)
255 100 (720,1280,3)
256 100 (720,1280,3)
257 100 (720,1280,3)
258 100 (720,1280,3)
259 100 (720,1280,3)
260 100 (720,1280,3)
261 100 (720,1280,3)
262 100 (720,1280,3)
263 100 (720,1280,3)
264 100 (720,1280,3)
265 100 (720,1280,3)
266 100 (720,1280,3)
267 100 (720,1280,3)
268 100 (720,1280,3)
269 100 (720,1280,3)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,106 @@
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
"""Paired image dataset for image restoration.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
There are three modes:
1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
2. **meta_info_file**: Use meta information file to generate paths. \
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
3. **folder**: Scan folders to generate paths. The rest.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
meta_info_file (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
Default: '{}'.
gt_size (int): Cropped patched size for gt patches.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
scale (bool): Scale, which will be added automatically.
phase (str): 'train' or 'val'.
"""
def __init__(self, opt):
super(PairedImageDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
if 'filename_tmpl' in opt:
self.filename_tmpl = opt['filename_tmpl']
else:
self.filename_tmpl = '{}'
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
self.opt['meta_info_file'], self.filename_tmpl)
else:
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
lq_path = self.paths[index]['lq_path']
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
# color space transform
if 'color' in self.opt and self.opt['color'] == 'y':
img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]
# crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
# TODO: It is better to update the datasets, rather than force to crop
if self.opt['phase'] != 'train':
img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)

View File

@@ -0,0 +1,122 @@
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader
class PrefetchGenerator(threading.Thread):
"""A general prefetch generator.
Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
Args:
generator: Python generator.
num_prefetch_queue (int): Number of prefetch queue.
"""
def __init__(self, generator, num_prefetch_queue):
threading.Thread.__init__(self)
self.queue = Queue.Queue(num_prefetch_queue)
self.generator = generator
self.daemon = True
self.start()
def run(self):
for item in self.generator:
self.queue.put(item)
self.queue.put(None)
def __next__(self):
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
return self
class PrefetchDataLoader(DataLoader):
"""Prefetch version of dataloader.
Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
TODO:
Need to test on single gpu and ddp (multi-gpu). There is a known issue in
ddp.
Args:
num_prefetch_queue (int): Number of prefetch queue.
kwargs (dict): Other arguments for dataloader.
"""
def __init__(self, num_prefetch_queue, **kwargs):
self.num_prefetch_queue = num_prefetch_queue
super(PrefetchDataLoader, self).__init__(**kwargs)
def __iter__(self):
return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
class CPUPrefetcher():
"""CPU prefetcher.
Args:
loader: Dataloader.
"""
def __init__(self, loader):
self.ori_loader = loader
self.loader = iter(loader)
def next(self):
try:
return next(self.loader)
except StopIteration:
return None
def reset(self):
self.loader = iter(self.ori_loader)
class CUDAPrefetcher():
"""CUDA prefetcher.
Reference: https://github.com/NVIDIA/apex/issues/304#
It may consume more GPU memory.
Args:
loader: Dataloader.
opt (dict): Options.
"""
def __init__(self, loader, opt):
self.ori_loader = loader
self.loader = iter(loader)
self.opt = opt
self.stream = torch.cuda.Stream()
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.preload()
def preload(self):
try:
self.batch = next(self.loader) # self.batch is a dict
except StopIteration:
self.batch = None
return None
# put tensors to gpu
with torch.cuda.stream(self.stream):
for k, v in self.batch.items():
if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch
self.preload()
return batch
def reset(self):
self.loader = iter(self.ori_loader)
self.preload()

View File

@@ -0,0 +1,384 @@
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from pathlib import Path
import albumentations
import torch.nn.functional as F
from torch.utils import data as data
from basicsr.utils import DiffJPEG
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from basicsr.utils.img_process_util import filter2D
from basicsr.data.transforms import paired_random_crop, random_crop
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from utils import util_image
def readline_txt(txt_file):
txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file
out = []
for txt_file_current in txt_file:
with open(txt_file_current, 'r') as ff:
out.extend([x[:-1] for x in ff.readlines()])
return out
@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANDataset(data.Dataset):
"""Dataset used for Real-ESRGAN model:
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It loads gt (Ground-Truth) images, and augments them.
It also generates blur kernels and sinc kernels for generating low-quality images.
Note that the low-quality images are processed in tensors on GPUS for faster processing.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
Please see more options in the codes.
"""
def __init__(self, opt, mode='training'):
super(RealESRGANDataset, self).__init__()
self.opt = opt
self.file_client = None
self.io_backend_opt = opt['io_backend']
# file client (lmdb io backend)
self.image_paths = []
self.text_paths = []
self.moment_paths = []
if opt.get('data_source', None) is not None:
for ii in range(len(opt['data_source'])):
configs = opt['data_source'].get(f'source{ii+1}')
root_path = Path(configs.root_path)
im_folder = root_path / configs.image_path
im_ext = configs.im_ext
image_stems = sorted([x.stem for x in im_folder.glob(f"*.{im_ext}")])
if configs.get('length', None) is not None:
assert configs.length < len(image_stems)
image_stems = image_stems[:configs.length]
if configs.get("text_path", None) is not None:
text_folder = root_path / configs.text_path
text_stems = [x.stem for x in text_folder.glob("*.txt")]
image_stems = sorted(list(set(image_stems).intersection(set(text_stems))))
self.text_paths.extend([str(text_folder / f"{x}.txt") for x in image_stems])
else:
self.text_paths.extend([None, ] * len(image_stems))
self.image_paths.extend([str(im_folder / f"{x}.{im_ext}") for x in image_stems])
if configs.get("moment_path", None) is not None:
moment_folder = root_path / configs.moment_path
self.moment_paths.extend([str(moment_folder / f"{x}.npy") for x in image_stems])
else:
self.moment_paths.extend([None, ] * len(image_stems))
# blur settings for the first degradation
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
self.blur_sigma = opt['blur_sigma']
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
# blur settings for the second degradation
self.blur_kernel_size2 = opt['blur_kernel_size2']
self.kernel_list2 = opt['kernel_list2']
self.kernel_prob2 = opt['kernel_prob2']
self.blur_sigma2 = opt['blur_sigma2']
self.betag_range2 = opt['betag_range2']
self.betap_range2 = opt['betap_range2']
self.sinc_prob2 = opt['sinc_prob2']
# a final sinc filter
self.final_sinc_prob = opt['final_sinc_prob']
self.kernel_range1 = [x for x in range(3, opt['blur_kernel_size'], 2)] # kernel size ranges from 7 to 21
self.kernel_range2 = [x for x in range(3, opt['blur_kernel_size2'], 2)] # kernel size ranges from 7 to 21
# TODO: kernel range is now hard-coded, should be in the configure file
# convolving with pulse tensor brings no blurry effect
self.pulse_tensor = torch.zeros(opt['blur_kernel_size2'], opt['blur_kernel_size2']).float()
self.pulse_tensor[opt['blur_kernel_size2']//2, opt['blur_kernel_size2']//2] = 1
self.mode = mode
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# -------------------------------- Load gt images -------------------------------- #
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
gt_path = self.image_paths[index]
# avoid errors caused by high latency in reading files
retry = 3
while retry > 0:
try:
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
except:
index = random.randint(0, self.__len__())
gt_path = self.image_paths[index]
time.sleep(1) # sleep 1s for occasional server congestion
finally:
retry -= 1
if self.mode == 'testing':
if not hasattr(self, 'test_aug'):
self.test_aug = albumentations.Compose([
albumentations.SmallestMaxSize(
max_size=self.opt['gt_size'],
interpolation=cv2.INTER_AREA,
),
albumentations.CenterCrop(self.opt['gt_size'], self.opt['gt_size']),
])
img_gt = self.test_aug(image=img_gt)['image']
elif self.mode == 'training':
# -------------------- Do augmentation for training: flip, rotation -------------------- #
if self.opt['use_hflip'] or self.opt['use_rot']:
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
h, w = img_gt.shape[0:2]
gt_size = self.opt['gt_size']
# resize or pad
if not self.opt['random_crop']:
if not min(h, w) == gt_size:
if not hasattr(self, 'smallest_resizer'):
self.smallest_resizer = util_image.SmallestMaxSize(
max_size=gt_size, pass_resize=False,
)
img_gt = self.smallest_resizer(img_gt)
# center crop
if not hasattr(self, 'center_cropper'):
self.center_cropper = albumentations.CenterCrop(gt_size, gt_size)
img_gt = self.center_cropper(image=img_gt)['image']
else:
img_gt = random_crop(img_gt, self.opt['gt_size'])
else:
raise ValueError(f'Unexpected value {self.mode} for mode parameter')
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range1)
if np.random.uniform() < self.opt['sinc_prob']:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None)
# pad kernel
pad_size = (self.blur_kernel_size - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range2)
if np.random.uniform() < self.opt['sinc_prob2']:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None)
# pad kernel
pad_size = (self.blur_kernel_size2 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------------------- the final sinc kernel ------------------------------------- #
if np.random.uniform() < self.opt['final_sinc_prob']:
kernel_size = random.choice(self.kernel_range2)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=self.blur_kernel_size2)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
kernel = torch.FloatTensor(kernel)
kernel2 = torch.FloatTensor(kernel2)
if self.text_paths[index] is None or self.opt['random_crop']:
prompt = ""
else:
with open(self.text_paths[index], 'r') as ff:
prompt = ff.read()
if self.opt.max_token_length is not None:
prompt = prompt[:self.opt.max_token_length]
return_d = {
'gt': img_gt,
'gt_path': gt_path,
'txt': prompt,
'kernel1': kernel,
'kernel2': kernel2,
'sinc_kernel': sinc_kernel,
}
if self.moment_paths[index] is not None and (not self.opt['random_crop']):
return_d['gt_moment'] = np.load(self.moment_paths[index])
return return_d
def __len__(self):
return len(self.image_paths)
def degrade_fun(self, conf_degradation, im_gt, kernel1, kernel2, sinc_kernel):
if not hasattr(self, 'jpeger'):
self.jpeger = DiffJPEG(differentiable=False) # simulate JPEG compression artifacts
ori_h, ori_w = im_gt.size()[2:4]
sf = conf_degradation.sf
# ----------------------- The first degradation process ----------------------- #
# blur
out = filter2D(im_gt, kernel1)
# random resize
updown_type = random.choices(
['up', 'down', 'keep'],
conf_degradation['resize_prob'],
)[0]
if updown_type == 'up':
scale = random.uniform(1, conf_degradation['resize_range'][1])
elif updown_type == 'down':
scale = random.uniform(conf_degradation['resize_range'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, scale_factor=scale, mode=mode)
# add noise
gray_noise_prob = conf_degradation['gray_noise_prob']
if random.random() < conf_degradation['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt(
out,
sigma_range=conf_degradation['noise_range'],
clip=True,
rounds=False,
gray_prob=gray_noise_prob,
)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=conf_degradation['poisson_scale_range'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range'])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
# blur
if random.random() < conf_degradation['second_order_prob']:
if random.random() < conf_degradation['second_blur_prob']:
out = filter2D(out, kernel2)
# random resize
updown_type = random.choices(
['up', 'down', 'keep'],
conf_degradation['resize_prob2'],
)[0]
if updown_type == 'up':
scale = random.uniform(1, conf_degradation['resize_range2'][1])
elif updown_type == 'down':
scale = random.uniform(conf_degradation['resize_range2'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(int(ori_h / sf * scale), int(ori_w / sf * scale)),
mode=mode,
)
# add noise
gray_noise_prob = conf_degradation['gray_noise_prob2']
if random.random() < conf_degradation['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt(
out,
sigma_range=conf_degradation['noise_range2'],
clip=True,
rounds=False,
gray_prob=gray_noise_prob,
)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=conf_degradation['poisson_scale_range2'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False,
)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if random.random() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(ori_h // sf, ori_w // sf),
mode=mode,
)
out = filter2D(out, sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(ori_h // sf, ori_w // sf),
mode=mode,
)
out = filter2D(out, sinc_kernel)
# clamp and round
im_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
return {'lq':im_lq.contiguous(), 'gt':im_gt}

View File

@@ -0,0 +1,106 @@
import os
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANPairedDataset(data.Dataset):
"""Paired image dataset for image restoration.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
There are three modes:
1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
2. **meta_info_file**: Use meta information file to generate paths. \
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
3. **folder**: Scan folders to generate paths. The rest.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
Default: '{}'.
gt_size (int): Cropped patched size for gt patches.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
scale (bool): Scale, which will be added automatically.
phase (str): 'train' or 'val'.
"""
def __init__(self, opt):
super(RealESRGANPairedDataset, self).__init__()
self.opt = opt
self.file_client = None
self.io_backend_opt = opt['io_backend']
# mean and std for normalizing the input images
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
# file client (lmdb io backend)
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
# disk backend with meta_info
# Each line in the meta_info describes the relative path to an image
with open(self.opt['meta_info']) as fin:
paths = [line.strip() for line in fin]
self.paths = []
for path in paths:
gt_path, lq_path = path.split(', ')
gt_path = os.path.join(self.gt_folder, gt_path)
lq_path = os.path.join(self.lq_folder, lq_path)
self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
else:
# disk backend
# it will scan the whole folder to get meta info
# it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
lq_path = self.paths[index]['lq_path']
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)

View File

@@ -0,0 +1,352 @@
import numpy as np
import random
import torch
from pathlib import Path
from torch.utils import data as data
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.flow_util import dequantize_flow
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class REDSDataset(data.Dataset):
"""REDS dataset for training.
The keys are generated from a meta info txt file.
basicsr/data/meta_info/meta_info_REDS_GT.txt
Each line contains:
1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
a white space.
Examples:
000 100 (720,1280,3)
001 100 (720,1280,3)
...
Key examples: "000/00000000"
GT (gt): Ground-Truth;
LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
dataroot_flow (str, optional): Data root path for flow.
meta_info_file (str): Path for meta information file.
val_partition (str): Validation partition types. 'REDS4' or 'official'.
io_backend (dict): IO backend type and other kwarg.
num_frame (int): Window size for input frames.
gt_size (int): Cropped patched size for gt patches.
interval_list (list): Interval list for temporal augmentation.
random_reverse (bool): Random reverse input frames.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
scale (bool): Scale, which will be added automatically.
"""
def __init__(self, opt):
super(REDSDataset, self).__init__()
self.opt = opt
self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
self.num_frame = opt['num_frame']
self.num_half_frames = opt['num_frame'] // 2
self.keys = []
with open(opt['meta_info_file'], 'r') as fin:
for line in fin:
folder, frame_num, _ = line.split(' ')
self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
# remove the video clips used in validation
if opt['val_partition'] == 'REDS4':
val_partition = ['000', '011', '015', '020']
elif opt['val_partition'] == 'official':
val_partition = [f'{v:03d}' for v in range(240, 270)]
else:
raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
f"Supported ones are ['official', 'REDS4'].")
self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.is_lmdb = False
if self.io_backend_opt['type'] == 'lmdb':
self.is_lmdb = True
if self.flow_root is not None:
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
else:
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
# temporal augmentation configs
self.interval_list = opt['interval_list']
self.random_reverse = opt['random_reverse']
interval_str = ','.join(str(x) for x in opt['interval_list'])
logger = get_root_logger()
logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
f'random reverse is {self.random_reverse}.')
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
gt_size = self.opt['gt_size']
key = self.keys[index]
clip_name, frame_name = key.split('/') # key example: 000/00000000
center_frame_idx = int(frame_name)
# determine the neighboring frames
interval = random.choice(self.interval_list)
# ensure not exceeding the borders
start_frame_idx = center_frame_idx - self.num_half_frames * interval
end_frame_idx = center_frame_idx + self.num_half_frames * interval
# each clip has 100 frames starting from 0 to 99
while (start_frame_idx < 0) or (end_frame_idx > 99):
center_frame_idx = random.randint(0, 99)
start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
end_frame_idx = center_frame_idx + self.num_half_frames * interval
frame_name = f'{center_frame_idx:08d}'
neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
# random reverse
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')
# get the GT frame (as the center frame)
if self.is_lmdb:
img_gt_path = f'{clip_name}/{frame_name}'
else:
img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
img_bytes = self.file_client.get(img_gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
# get the neighboring LQ frames
img_lqs = []
for neighbor in neighbor_list:
if self.is_lmdb:
img_lq_path = f'{clip_name}/{neighbor:08d}'
else:
img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
img_bytes = self.file_client.get(img_lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
img_lqs.append(img_lq)
# get flows
if self.flow_root is not None:
img_flows = []
# read previous flows
for i in range(self.num_half_frames, 0, -1):
if self.is_lmdb:
flow_path = f'{clip_name}/{frame_name}_p{i}'
else:
flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
img_bytes = self.file_client.get(flow_path, 'flow')
cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
dx, dy = np.split(cat_flow, 2, axis=0)
flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
img_flows.append(flow)
# read next flows
for i in range(1, self.num_half_frames + 1):
if self.is_lmdb:
flow_path = f'{clip_name}/{frame_name}_n{i}'
else:
flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
img_bytes = self.file_client.get(flow_path, 'flow')
cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
dx, dy = np.split(cat_flow, 2, axis=0)
flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
img_flows.append(flow)
# for random crop, here, img_flows and img_lqs have the same
# spatial size
img_lqs.extend(img_flows)
# randomly crop
img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
if self.flow_root is not None:
img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]
# augmentation - flip, rotate
img_lqs.append(img_gt)
if self.flow_root is not None:
img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
else:
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
img_results = img2tensor(img_results)
img_lqs = torch.stack(img_results[0:-1], dim=0)
img_gt = img_results[-1]
if self.flow_root is not None:
img_flows = img2tensor(img_flows)
# add the zero center flow
img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
img_flows = torch.stack(img_flows, dim=0)
# img_lqs: (t, c, h, w)
# img_flows: (t, 2, h, w)
# img_gt: (c, h, w)
# key: str
if self.flow_root is not None:
return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
else:
return {'lq': img_lqs, 'gt': img_gt, 'key': key}
def __len__(self):
return len(self.keys)
@DATASET_REGISTRY.register()
class REDSRecurrentDataset(data.Dataset):
"""REDS dataset for training recurrent networks.
The keys are generated from a meta info txt file.
basicsr/data/meta_info/meta_info_REDS_GT.txt
Each line contains:
1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
a white space.
Examples:
000 100 (720,1280,3)
001 100 (720,1280,3)
...
Key examples: "000/00000000"
GT (gt): Ground-Truth;
LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
dataroot_flow (str, optional): Data root path for flow.
meta_info_file (str): Path for meta information file.
val_partition (str): Validation partition types. 'REDS4' or 'official'.
io_backend (dict): IO backend type and other kwarg.
num_frame (int): Window size for input frames.
gt_size (int): Cropped patched size for gt patches.
interval_list (list): Interval list for temporal augmentation.
random_reverse (bool): Random reverse input frames.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
scale (bool): Scale, which will be added automatically.
"""
def __init__(self, opt):
super(REDSRecurrentDataset, self).__init__()
self.opt = opt
self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
self.num_frame = opt['num_frame']
self.keys = []
with open(opt['meta_info_file'], 'r') as fin:
for line in fin:
folder, frame_num, _ = line.split(' ')
self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
# remove the video clips used in validation
if opt['val_partition'] == 'REDS4':
val_partition = ['000', '011', '015', '020']
elif opt['val_partition'] == 'official':
val_partition = [f'{v:03d}' for v in range(240, 270)]
else:
raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
f"Supported ones are ['official', 'REDS4'].")
if opt['test_mode']:
self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
else:
self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.is_lmdb = False
if self.io_backend_opt['type'] == 'lmdb':
self.is_lmdb = True
if hasattr(self, 'flow_root') and self.flow_root is not None:
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
else:
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
# temporal augmentation configs
self.interval_list = opt.get('interval_list', [1])
self.random_reverse = opt.get('random_reverse', False)
interval_str = ','.join(str(x) for x in self.interval_list)
logger = get_root_logger()
logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
f'random reverse is {self.random_reverse}.')
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
gt_size = self.opt['gt_size']
key = self.keys[index]
clip_name, frame_name = key.split('/') # key example: 000/00000000
# determine the neighboring frames
interval = random.choice(self.interval_list)
# ensure not exceeding the borders
start_frame_idx = int(frame_name)
if start_frame_idx > 100 - self.num_frame * interval:
start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
end_frame_idx = start_frame_idx + self.num_frame * interval
neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
# random reverse
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
# get the neighboring LQ and GT frames
img_lqs = []
img_gts = []
for neighbor in neighbor_list:
if self.is_lmdb:
img_lq_path = f'{clip_name}/{neighbor:08d}'
img_gt_path = f'{clip_name}/{neighbor:08d}'
else:
img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'
# get LQ
img_bytes = self.file_client.get(img_lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
img_lqs.append(img_lq)
# get GT
img_bytes = self.file_client.get(img_gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
img_gts.append(img_gt)
# randomly crop
img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
# augmentation - flip, rotate
img_lqs.extend(img_gts)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
img_results = img2tensor(img_results)
img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)
# img_lqs: (t, c, h, w)
# img_gts: (t, c, h, w)
# key: str
return {'lq': img_lqs, 'gt': img_gts, 'key': key}
def __len__(self):
return len(self.keys)

View File

@@ -0,0 +1,68 @@
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class SingleImageDataset(data.Dataset):
"""Read only lq images in the test phase.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
There are two modes:
1. 'meta_info_file': Use meta information file to generate paths.
2. 'folder': Scan folders to generate paths.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_lq (str): Data root path for lq.
meta_info_file (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
"""
def __init__(self, opt):
super(SingleImageDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.lq_folder = opt['dataroot_lq']
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder]
self.io_backend_opt['client_keys'] = ['lq']
self.paths = paths_from_lmdb(self.lq_folder)
elif 'meta_info_file' in self.opt:
with open(self.opt['meta_info_file'], 'r') as fin:
self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
else:
self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load lq image
lq_path = self.paths[index]
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# color space transform
if 'color' in self.opt and self.opt['color'] == 'y':
img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
# BGR to RGB, HWC to CHW, numpy to tensor
img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'lq_path': lq_path}
def __len__(self):
return len(self.paths)

207
basicsr/data/transforms.py Normal file
View File

@@ -0,0 +1,207 @@
import cv2
import random
import torch
def mod_crop(img, scale):
"""Mod crop images, used during testing.
Args:
img (ndarray): Input image.
scale (int): Scale factor.
Returns:
ndarray: Result image.
"""
img = img.copy()
if img.ndim in (2, 3):
h, w = img.shape[0], img.shape[1]
h_remainder, w_remainder = h % scale, w % scale
img = img[:h - h_remainder, :w - w_remainder, ...]
else:
raise ValueError(f'Wrong img ndim: {img.ndim}.')
return img
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
"""Paired random crop. Support Numpy array and Tensor inputs.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth. Default: None.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
# determine input type: Numpy array or Tensor
input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
if input_type == 'Tensor':
h_lq, w_lq = img_lqs[0].size()[-2:]
h_gt, w_gt = img_gts[0].size()[-2:]
else:
h_lq, w_lq = img_lqs[0].shape[0:2]
h_gt, w_gt = img_gts[0].shape[0:2]
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
f'multiplication of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
if input_type == 'Tensor':
img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
else:
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
if input_type == 'Tensor':
img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
else:
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
We use vertical flip and transpose for rotation implementation.
All the images in the list use the same augmentation.
Args:
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
is an ndarray, it will be transformed to a list.
hflip (bool): Horizontal flip. Default: True.
rotation (bool): Ratotation. Default: True.
flows (list[ndarray]: Flows to be augmented. If the input is an
ndarray, it will be transformed to a list.
Dimension is (h, w, 2). Default: None.
return_status (bool): Return the status of flip and rotation.
Default: False.
Returns:
list[ndarray] | ndarray: Augmented images and flows. If returned
results only have one element, just return ndarray.
"""
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip: # horizontal
cv2.flip(flow, 1, flow)
flow[:, :, 0] *= -1
if vflip: # vertical
cv2.flip(flow, 0, flow)
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
if flows is not None:
if not isinstance(flows, list):
flows = [flows]
flows = [_augment_flow(flow) for flow in flows]
if len(flows) == 1:
flows = flows[0]
return imgs, flows
else:
if return_status:
return imgs, (hflip, vflip, rot90)
else:
return imgs
def img_rotate(img, angle, center=None, scale=1.0):
"""Rotate image.
Args:
img (ndarray): Image to be rotated.
angle (float): Rotation angle in degrees. Positive values mean
counter-clockwise rotation.
center (tuple[int]): Rotation center. If the center is None,
initialize it as the center of the image. Default: None.
scale (float): Isotropic scale factor. Default: 1.0.
"""
(h, w) = img.shape[:2]
if center is None:
center = (w // 2, h // 2)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
rotated_img = cv2.warpAffine(img, matrix, (w, h))
return rotated_img
def random_crop(im, pch_size):
'''
Randomly crop a patch from the give image.
'''
h, w = im.shape[:2]
# padding if necessary
if h < pch_size or w < pch_size:
pad_h = min(max(0, pch_size - h), h)
pad_w = min(max(0, pch_size - w), w)
im = cv2.copyMakeBorder(im, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
h, w = im.shape[:2]
if h == pch_size:
ind_h = 0
elif h > pch_size:
ind_h = random.randint(0, h-pch_size)
else:
raise ValueError('Image height is smaller than the patch size')
if w == pch_size:
ind_w = 0
elif w > pch_size:
ind_w = random.randint(0, w-pch_size)
else:
raise ValueError('Image width is smaller than the patch size')
im_pch = im[ind_h:ind_h+pch_size, ind_w:ind_w+pch_size,]
return im_pch

View File

@@ -0,0 +1,283 @@
import glob
import torch
from os import path as osp
from torch.utils import data as data
from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
"""Video test dataset.
Supported datasets: Vid4, REDS4, REDSofficial.
More generally, it supports testing dataset with following structures:
::
dataroot
├── subfolder1
├── frame000
├── frame001
├── ...
├── subfolder2
├── frame000
├── frame001
├── ...
├── ...
For testing datasets, there is no need to prepare LMDB files.
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
io_backend (dict): IO backend type and other kwarg.
cache_data (bool): Whether to cache testing datasets.
name (str): Dataset name.
meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
in the dataroot will be used.
num_frame (int): Window size for input frames.
padding (str): Padding mode.
"""
def __init__(self, opt):
super(VideoTestDataset, self).__init__()
self.opt = opt
self.cache_data = opt['cache_data']
self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
logger = get_root_logger()
logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
self.imgs_lq, self.imgs_gt = {}, {}
if 'meta_info_file' in opt:
with open(opt['meta_info_file'], 'r') as fin:
subfolders = [line.split(' ')[0] for line in fin]
subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
else:
subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))
if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
# get frame list for lq and gt
subfolder_name = osp.basename(subfolder_lq)
img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))
max_idx = len(img_paths_lq)
assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
f' and gt folders ({len(img_paths_gt)})')
self.data_info['lq_path'].extend(img_paths_lq)
self.data_info['gt_path'].extend(img_paths_gt)
self.data_info['folder'].extend([subfolder_name] * max_idx)
for i in range(max_idx):
self.data_info['idx'].append(f'{i}/{max_idx}')
border_l = [0] * max_idx
for i in range(self.opt['num_frame'] // 2):
border_l[i] = 1
border_l[max_idx - i - 1] = 1
self.data_info['border'].extend(border_l)
# cache data or save the frame list
if self.cache_data:
logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
else:
self.imgs_lq[subfolder_name] = img_paths_lq
self.imgs_gt[subfolder_name] = img_paths_gt
else:
raise ValueError(f'Non-supported video test dataset: {type(opt["name"])}')
def __getitem__(self, index):
folder = self.data_info['folder'][index]
idx, max_idx = self.data_info['idx'][index].split('/')
idx, max_idx = int(idx), int(max_idx)
border = self.data_info['border'][index]
lq_path = self.data_info['lq_path'][index]
select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
if self.cache_data:
imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
img_gt = self.imgs_gt[folder][idx]
else:
img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
imgs_lq = read_img_seq(img_paths_lq)
img_gt = read_img_seq([self.imgs_gt[folder][idx]])
img_gt.squeeze_(0)
return {
'lq': imgs_lq, # (t, c, h, w)
'gt': img_gt, # (c, h, w)
'folder': folder, # folder name
'idx': self.data_info['idx'][index], # e.g., 0/99
'border': border, # 1 for border, 0 for non-border
'lq_path': lq_path # center frame
}
def __len__(self):
return len(self.data_info['gt_path'])
@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
"""Video test dataset for Vimeo90k-Test dataset.
It only keeps the center frame for testing.
For testing datasets, there is no need to prepare LMDB files.
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
io_backend (dict): IO backend type and other kwarg.
cache_data (bool): Whether to cache testing datasets.
name (str): Dataset name.
meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
in the dataroot will be used.
num_frame (int): Window size for input frames.
padding (str): Padding mode.
"""
def __init__(self, opt):
super(VideoTestVimeo90KDataset, self).__init__()
self.opt = opt
self.cache_data = opt['cache_data']
if self.cache_data:
raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
logger = get_root_logger()
logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
with open(opt['meta_info_file'], 'r') as fin:
subfolders = [line.split(' ')[0] for line in fin]
for idx, subfolder in enumerate(subfolders):
gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
self.data_info['gt_path'].append(gt_path)
lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
self.data_info['lq_path'].append(lq_paths)
self.data_info['folder'].append('vimeo90k')
self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
self.data_info['border'].append(0)
def __getitem__(self, index):
lq_path = self.data_info['lq_path'][index]
gt_path = self.data_info['gt_path'][index]
imgs_lq = read_img_seq(lq_path)
img_gt = read_img_seq([gt_path])
img_gt.squeeze_(0)
return {
'lq': imgs_lq, # (t, c, h, w)
'gt': img_gt, # (c, h, w)
'folder': self.data_info['folder'][index], # folder name
'idx': self.data_info['idx'][index], # e.g., 0/843
'border': self.data_info['border'][index], # 0 for non-border
'lq_path': lq_path[self.opt['num_frame'] // 2] # center frame
}
def __len__(self):
return len(self.data_info['gt_path'])
@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
""" Video test dataset for DUF dataset.
Args:
opt (dict): Config for train dataset. Most of keys are the same as VideoTestDataset.
It has the following extra keys:
use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames.
scale (bool): Scale, which will be added automatically.
"""
def __getitem__(self, index):
folder = self.data_info['folder'][index]
idx, max_idx = self.data_info['idx'][index].split('/')
idx, max_idx = int(idx), int(max_idx)
border = self.data_info['border'][index]
lq_path = self.data_info['lq_path'][index]
select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
if self.cache_data:
if self.opt['use_duf_downsampling']:
# read imgs_gt to generate low-resolution frames
imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
else:
imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
img_gt = self.imgs_gt[folder][idx]
else:
if self.opt['use_duf_downsampling']:
img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
# read imgs_gt to generate low-resolution frames
imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
else:
img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
imgs_lq = read_img_seq(img_paths_lq)
img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
img_gt.squeeze_(0)
return {
'lq': imgs_lq, # (t, c, h, w)
'gt': img_gt, # (c, h, w)
'folder': folder, # folder name
'idx': self.data_info['idx'][index], # e.g., 0/99
'border': border, # 1 for border, 0 for non-border
'lq_path': lq_path # center frame
}
@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
"""Video test dataset for recurrent architectures, which takes LR video
frames as input and output corresponding HR video frames.
Args:
opt (dict): Same as VideoTestDataset. Unused opt:
padding (str): Padding mode.
"""
def __init__(self, opt):
super(VideoRecurrentTestDataset, self).__init__(opt)
# Find unique folder strings
self.folders = sorted(list(set(self.data_info['folder'])))
def __getitem__(self, index):
folder = self.folders[index]
if self.cache_data:
imgs_lq = self.imgs_lq[folder]
imgs_gt = self.imgs_gt[folder]
else:
raise NotImplementedError('Without cache_data is not implemented.')
return {
'lq': imgs_lq,
'gt': imgs_gt,
'folder': folder,
}
def __len__(self):
return len(self.folders)

View File

@@ -0,0 +1,199 @@
import random
import torch
from pathlib import Path
from torch.utils import data as data
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class Vimeo90KDataset(data.Dataset):
"""Vimeo90K dataset for training.
The keys are generated from a meta info txt file.
basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
Each line contains the following items, separated by a white space.
1. clip name;
2. frame number;
3. image shape
Examples:
::
00001/0001 7 (256,448,3)
00001/0002 7 (256,448,3)
- Key examples: "00001/0001"
- GT (gt): Ground-Truth;
- LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
The neighboring frame list for different num_frame:
::
num_frame | frame list
1 | 4
3 | 3,4,5
5 | 2,3,4,5,6
7 | 1,2,3,4,5,6,7
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
meta_info_file (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
num_frame (int): Window size for input frames.
gt_size (int): Cropped patched size for gt patches.
random_reverse (bool): Random reverse input frames.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
scale (bool): Scale, which will be added automatically.
"""
def __init__(self, opt):
super(Vimeo90KDataset, self).__init__()
self.opt = opt
self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
with open(opt['meta_info_file'], 'r') as fin:
self.keys = [line.split(' ')[0] for line in fin]
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.is_lmdb = False
if self.io_backend_opt['type'] == 'lmdb':
self.is_lmdb = True
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
# indices of input images
self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
# temporal augmentation configs
self.random_reverse = opt['random_reverse']
logger = get_root_logger()
logger.info(f'Random reverse is {self.random_reverse}.')
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# random reverse
if self.random_reverse and random.random() < 0.5:
self.neighbor_list.reverse()
scale = self.opt['scale']
gt_size = self.opt['gt_size']
key = self.keys[index]
clip, seq = key.split('/') # key example: 00001/0001
# get the GT frame (im4.png)
if self.is_lmdb:
img_gt_path = f'{key}/im4'
else:
img_gt_path = self.gt_root / clip / seq / 'im4.png'
img_bytes = self.file_client.get(img_gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
# get the neighboring LQ frames
img_lqs = []
for neighbor in self.neighbor_list:
if self.is_lmdb:
img_lq_path = f'{clip}/{seq}/im{neighbor}'
else:
img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
img_bytes = self.file_client.get(img_lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
img_lqs.append(img_lq)
# randomly crop
img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
# augmentation - flip, rotate
img_lqs.append(img_gt)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
img_results = img2tensor(img_results)
img_lqs = torch.stack(img_results[0:-1], dim=0)
img_gt = img_results[-1]
# img_lqs: (t, c, h, w)
# img_gt: (c, h, w)
# key: str
return {'lq': img_lqs, 'gt': img_gt, 'key': key}
def __len__(self):
return len(self.keys)
@DATASET_REGISTRY.register()
class Vimeo90KRecurrentDataset(Vimeo90KDataset):
def __init__(self, opt):
super(Vimeo90KRecurrentDataset, self).__init__(opt)
self.flip_sequence = opt['flip_sequence']
self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# random reverse
if self.random_reverse and random.random() < 0.5:
self.neighbor_list.reverse()
scale = self.opt['scale']
gt_size = self.opt['gt_size']
key = self.keys[index]
clip, seq = key.split('/') # key example: 00001/0001
# get the neighboring LQ and GT frames
img_lqs = []
img_gts = []
for neighbor in self.neighbor_list:
if self.is_lmdb:
img_lq_path = f'{clip}/{seq}/im{neighbor}'
img_gt_path = f'{clip}/{seq}/im{neighbor}'
else:
img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
# LQ
img_bytes = self.file_client.get(img_lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# GT
img_bytes = self.file_client.get(img_gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
img_lqs.append(img_lq)
img_gts.append(img_gt)
# randomly crop
img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
# augmentation - flip, rotate
img_lqs.extend(img_gts)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
img_results = img2tensor(img_results)
img_lqs = torch.stack(img_results[:7], dim=0)
img_gts = torch.stack(img_results[7:], dim=0)
if self.flip_sequence: # flip the sequence: 7 frames to 14 frames
img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)
# img_lqs: (t, c, h, w)
# img_gt: (c, h, w)
# key: str
return {'lq': img_lqs, 'gt': img_gts, 'key': key}
def __len__(self):
return len(self.keys)

47
basicsr/utils/__init__.py Normal file
View File

@@ -0,0 +1,47 @@
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
from .options import yaml_load
__all__ = [
# color_util.py
'bgr2ycbcr',
'rgb2ycbcr',
'rgb2ycbcr_pt',
'ycbcr2bgr',
'ycbcr2rgb',
# file_client.py
'FileClient',
# img_util.py
'img2tensor',
'tensor2img',
'imfrombytes',
'imwrite',
'crop_border',
# logger.py
'MessageLogger',
'AvgTimer',
'init_tb_logger',
'init_wandb_logger',
'get_root_logger',
'get_env_info',
# misc.py
'set_random_seed',
'get_time_str',
'mkdir_and_rename',
'make_exp_dirs',
'scandir',
'check_resume',
'sizeof_fmt',
# diffjpeg
'DiffJPEG',
# img_process_util
'USMSharp',
'usm_sharp',
# options
'yaml_load'
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

208
basicsr/utils/color_util.py Normal file
View File

@@ -0,0 +1,208 @@
import numpy as np
import torch
def rgb2ycbcr(img, y_only=False):
"""Convert a RGB image to YCbCr image.
This function produces the same results as Matlab's `rgb2ycbcr` function.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
else:
out_img = np.matmul(
img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def ycbcr2rgb(img):
"""Convert a YCbCr image to RGB image.
This function produces the same results as Matlab's ycbcr2rgb function.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
ndarray: The converted RGB image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def ycbcr2bgr(img):
"""Convert a YCbCr image to BGR image.
The bgr version of ycbcr2rgb.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
ndarray: The converted BGR image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
[0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def _convert_input_type_range(img):
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
It is mainly used for pre-processing the input image in colorspace
conversion functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
(ndarray): The converted image with type of np.float32 and range of
[0, 1].
"""
img_type = img.dtype
img = img.astype(np.float32)
if img_type == np.float32:
pass
elif img_type == np.uint8:
img /= 255.
else:
raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
return img
def _convert_output_type_range(img, dst_type):
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
images will be converted to np.uint8 type with range [0, 255]. If
`dst_type` is np.float32, it converts the image to np.float32 type with
range [0, 1].
It is mainly used for post-processing images in colorspace conversion
functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The image to be converted with np.float32 type and
range [0, 255].
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
converts the image to np.uint8 type with range [0, 255]. If
dst_type is np.float32, it converts the image to np.float32 type
with range [0, 1].
Returns:
(ndarray): The converted image with desired type and range.
"""
if dst_type not in (np.uint8, np.float32):
raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
if dst_type == np.uint8:
img = img.round()
else:
img /= 255.
return img.astype(dst_type)
def rgb2ycbcr_pt(img, y_only=False):
"""Convert RGB images to YCbCr images (PyTorch version).
It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
Args:
img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
(Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
"""
if y_only:
weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
else:
weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
out_img = out_img / 255.
return out_img

515
basicsr/utils/diffjpeg.py Normal file
View File

@@ -0,0 +1,515 @@
"""
Modified from https://github.com/mlomnitz/DiffJPEG
For images not divisible by 8
https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
"""
import itertools
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
# ------------------------ utils ------------------------#
y_table = np.array(
[[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
[14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
[49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
dtype=np.float32).T
y_table = nn.Parameter(torch.from_numpy(y_table))
c_table = np.empty((8, 8), dtype=np.float32)
c_table.fill(99)
c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
c_table = nn.Parameter(torch.from_numpy(c_table))
def diff_round(x):
""" Differentiable rounding function
"""
return torch.round(x) + (x - torch.round(x))**3
def quality_to_factor(quality):
""" Calculate factor corresponding to quality
Args:
quality(float): Quality for jpeg compression.
Returns:
float: Compression factor.
"""
if quality < 50:
quality = 5000. / quality
else:
quality = 200. - quality * 2
return quality / 100.
# ------------------------ compression ------------------------#
class RGB2YCbCrJpeg(nn.Module):
""" Converts RGB image to YCbCr
"""
def __init__(self):
super(RGB2YCbCrJpeg, self).__init__()
matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
dtype=np.float32).T
self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
self.matrix = nn.Parameter(torch.from_numpy(matrix))
def forward(self, image):
"""
Args:
image(Tensor): batch x 3 x height x width
Returns:
Tensor: batch x height x width x 3
"""
image = image.permute(0, 2, 3, 1)
result = torch.tensordot(image, self.matrix, dims=1) + self.shift
return result.view(image.shape)
class ChromaSubsampling(nn.Module):
""" Chroma subsampling on CbCr channels
"""
def __init__(self):
super(ChromaSubsampling, self).__init__()
def forward(self, image):
"""
Args:
image(tensor): batch x height x width x 3
Returns:
y(tensor): batch x height x width
cb(tensor): batch x height/2 x width/2
cr(tensor): batch x height/2 x width/2
"""
image_2 = image.permute(0, 3, 1, 2).clone()
cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
cb = cb.permute(0, 2, 3, 1)
cr = cr.permute(0, 2, 3, 1)
return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
class BlockSplitting(nn.Module):
""" Splitting image into patches
"""
def __init__(self):
super(BlockSplitting, self).__init__()
self.k = 8
def forward(self, image):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x h*w/64 x h x w
"""
height, _ = image.shape[1:3]
batch_size = image.shape[0]
image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
class DCT8x8(nn.Module):
""" Discrete Cosine Transformation
"""
def __init__(self):
super(DCT8x8, self).__init__()
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
for x, y, u, v in itertools.product(range(8), repeat=4):
tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
def forward(self, image):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
image = image - 128
result = self.scale * torch.tensordot(image, self.tensor, dims=2)
result.view(image.shape)
return result
class YQuantize(nn.Module):
""" JPEG Quantization for Y channel
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding):
super(YQuantize, self).__init__()
self.rounding = rounding
self.y_table = y_table
def forward(self, image, factor=1):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
if isinstance(factor, (int, float)):
image = image.float() / (self.y_table * factor)
else:
b = factor.size(0)
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
image = image.float() / table
image = self.rounding(image)
return image
class CQuantize(nn.Module):
""" JPEG Quantization for CbCr channels
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding):
super(CQuantize, self).__init__()
self.rounding = rounding
self.c_table = c_table
def forward(self, image, factor=1):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
if isinstance(factor, (int, float)):
image = image.float() / (self.c_table * factor)
else:
b = factor.size(0)
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
image = image.float() / table
image = self.rounding(image)
return image
class CompressJpeg(nn.Module):
"""Full JPEG compression algorithm
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding=torch.round):
super(CompressJpeg, self).__init__()
self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
self.c_quantize = CQuantize(rounding=rounding)
self.y_quantize = YQuantize(rounding=rounding)
def forward(self, image, factor=1):
"""
Args:
image(tensor): batch x 3 x height x width
Returns:
dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8.
"""
y, cb, cr = self.l1(image * 255)
components = {'y': y, 'cb': cb, 'cr': cr}
for k in components.keys():
comp = self.l2(components[k])
if k in ('cb', 'cr'):
comp = self.c_quantize(comp, factor=factor)
else:
comp = self.y_quantize(comp, factor=factor)
components[k] = comp
return components['y'], components['cb'], components['cr']
# ------------------------ decompression ------------------------#
class YDequantize(nn.Module):
"""Dequantize Y channel
"""
def __init__(self):
super(YDequantize, self).__init__()
self.y_table = y_table
def forward(self, image, factor=1):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
if isinstance(factor, (int, float)):
out = image * (self.y_table * factor)
else:
b = factor.size(0)
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
out = image * table
return out
class CDequantize(nn.Module):
"""Dequantize CbCr channel
"""
def __init__(self):
super(CDequantize, self).__init__()
self.c_table = c_table
def forward(self, image, factor=1):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
if isinstance(factor, (int, float)):
out = image * (self.c_table * factor)
else:
b = factor.size(0)
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
out = image * table
return out
class iDCT8x8(nn.Module):
"""Inverse discrete Cosine Transformation
"""
def __init__(self):
super(iDCT8x8, self).__init__()
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
for x, y, u, v in itertools.product(range(8), repeat=4):
tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
def forward(self, image):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x height x width
"""
image = image * self.alpha
result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
result.view(image.shape)
return result
class BlockMerging(nn.Module):
"""Merge patches into image
"""
def __init__(self):
super(BlockMerging, self).__init__()
def forward(self, patches, height, width):
"""
Args:
patches(tensor) batch x height*width/64, height x width
height(int)
width(int)
Returns:
Tensor: batch x height x width
"""
k = 8
batch_size = patches.shape[0]
image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
return image_transposed.contiguous().view(batch_size, height, width)
class ChromaUpsampling(nn.Module):
"""Upsample chroma layers
"""
def __init__(self):
super(ChromaUpsampling, self).__init__()
def forward(self, y, cb, cr):
"""
Args:
y(tensor): y channel image
cb(tensor): cb channel
cr(tensor): cr channel
Returns:
Tensor: batch x height x width x 3
"""
def repeat(x, k=2):
height, width = x.shape[1:3]
x = x.unsqueeze(-1)
x = x.repeat(1, 1, k, k)
x = x.view(-1, height * k, width * k)
return x
cb = repeat(cb)
cr = repeat(cr)
return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
class YCbCr2RGBJpeg(nn.Module):
"""Converts YCbCr image to RGB JPEG
"""
def __init__(self):
super(YCbCr2RGBJpeg, self).__init__()
matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
self.matrix = nn.Parameter(torch.from_numpy(matrix))
def forward(self, image):
"""
Args:
image(tensor): batch x height x width x 3
Returns:
Tensor: batch x 3 x height x width
"""
result = torch.tensordot(image + self.shift, self.matrix, dims=1)
return result.view(image.shape).permute(0, 3, 1, 2)
class DeCompressJpeg(nn.Module):
"""Full JPEG decompression algorithm
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding=torch.round):
super(DeCompressJpeg, self).__init__()
self.c_dequantize = CDequantize()
self.y_dequantize = YDequantize()
self.idct = iDCT8x8()
self.merging = BlockMerging()
self.chroma = ChromaUpsampling()
self.colors = YCbCr2RGBJpeg()
def forward(self, y, cb, cr, imgh, imgw, factor=1):
"""
Args:
compressed(dict(tensor)): batch x h*w/64 x 8 x 8
imgh(int)
imgw(int)
factor(float)
Returns:
Tensor: batch x 3 x height x width
"""
components = {'y': y, 'cb': cb, 'cr': cr}
for k in components.keys():
if k in ('cb', 'cr'):
comp = self.c_dequantize(components[k], factor=factor)
height, width = int(imgh / 2), int(imgw / 2)
else:
comp = self.y_dequantize(components[k], factor=factor)
height, width = imgh, imgw
comp = self.idct(comp)
components[k] = self.merging(comp, height, width)
#
image = self.chroma(components['y'], components['cb'], components['cr'])
image = self.colors(image)
image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
return image / 255
# ------------------------ main DiffJPEG ------------------------ #
class DiffJPEG(nn.Module):
"""This JPEG algorithm result is slightly different from cv2.
DiffJPEG supports batch processing.
Args:
differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round
"""
def __init__(self, differentiable=True):
super(DiffJPEG, self).__init__()
if differentiable:
rounding = diff_round
else:
rounding = torch.round
self.compress = CompressJpeg(rounding=rounding)
self.decompress = DeCompressJpeg(rounding=rounding)
def forward(self, x, quality):
"""
Args:
x (Tensor): Input image, bchw, rgb, [0, 1]
quality(float): Quality factor for jpeg compression scheme.
"""
factor = quality
if isinstance(factor, (int, float)):
factor = quality_to_factor(factor)
else:
for i in range(factor.size(0)):
factor[i] = quality_to_factor(factor[i])
h, w = x.size()[-2:]
h_pad, w_pad = 0, 0
# why should use 16
if h % 16 != 0:
h_pad = 16 - h % 16
if w % 16 != 0:
w_pad = 16 - w % 16
x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
y, cb, cr = self.compress(x, factor=factor)
recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
recovered = recovered[:, :, 0:h, 0:w]
return recovered
if __name__ == '__main__':
import cv2
from basicsr.utils import img2tensor, tensor2img
img_gt = cv2.imread('test.png') / 255.
# -------------- cv2 -------------- #
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
_, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
img_lq = np.float32(cv2.imdecode(encimg, 1))
cv2.imwrite('cv2_JPEG_20.png', img_lq)
# -------------- DiffJPEG -------------- #
jpeger = DiffJPEG(differentiable=False).cuda()
img_gt = img2tensor(img_gt)
img_gt = torch.stack([img_gt, img_gt]).cuda()
quality = img_gt.new_tensor([20, 40])
out = jpeger(img_gt, quality=quality)
cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))

View File

@@ -0,0 +1,82 @@
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
_init_dist_pytorch(backend, **kwargs)
elif launcher == 'slurm':
_init_dist_slurm(backend, **kwargs)
else:
raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend, **kwargs):
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend, port=None):
"""Initialize slurm distributed training environment.
If argument ``port`` is not specified, then the master port will be system
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
environment variable, then a default port ``29500`` will be used.
Args:
backend (str): Backend of torch.distributed.
port (int, optional): Master port. Defaults to None.
"""
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
# specify master port
if port is not None:
os.environ['MASTER_PORT'] = str(port)
elif 'MASTER_PORT' in os.environ:
pass # use MASTER_PORT in the environment variable
else:
# 29500 is torch.distributed default port
os.environ['MASTER_PORT'] = '29500'
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=backend)
def get_dist_info():
if dist.is_available():
initialized = dist.is_initialized()
else:
initialized = False
if initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
def master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper

View File

@@ -0,0 +1,98 @@
import math
import os
import requests
from torch.hub import download_url_to_file, get_dir
from tqdm import tqdm
from urllib.parse import urlparse
from .misc import sizeof_fmt
def download_file_from_google_drive(file_id, save_path):
"""Download files from google drive.
Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive
Args:
file_id (str): File id.
save_path (str): Save path.
"""
session = requests.Session()
URL = 'https://docs.google.com/uc?export=download'
params = {'id': file_id}
response = session.get(URL, params=params, stream=True)
token = get_confirm_token(response)
if token:
params['confirm'] = token
response = session.get(URL, params=params, stream=True)
# get file size
response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
if 'Content-Range' in response_file_size.headers:
file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
else:
file_size = None
save_response_content(response, save_path, file_size)
def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
def save_response_content(response, destination, file_size=None, chunk_size=32768):
if file_size is not None:
pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
readable_file_size = sizeof_fmt(file_size)
else:
pbar = None
with open(destination, 'wb') as f:
downloaded_size = 0
for chunk in response.iter_content(chunk_size):
downloaded_size += chunk_size
if pbar is not None:
pbar.update(1)
pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if pbar is not None:
pbar.close()
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load file form http url, will download models if necessary.
Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
Args:
url (str): URL to be downloaded.
model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
Default: None.
progress (bool): Whether to show the download progress. Default: True.
file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
Returns:
str: The path to the downloaded file.
"""
if model_dir is None: # use the pytorch hub_dir
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file

View File

@@ -0,0 +1,167 @@
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
from abc import ABCMeta, abstractmethod
class BaseStorageBackend(metaclass=ABCMeta):
"""Abstract class of storage backends.
All backends need to implement two apis: ``get()`` and ``get_text()``.
``get()`` reads the file as a byte stream and ``get_text()`` reads the file
as texts.
"""
@abstractmethod
def get(self, filepath):
pass
@abstractmethod
def get_text(self, filepath):
pass
class MemcachedBackend(BaseStorageBackend):
"""Memcached storage backend.
Attributes:
server_list_cfg (str): Config file for memcached server list.
client_cfg (str): Config file for memcached client.
sys_path (str | None): Additional path to be appended to `sys.path`.
Default: None.
"""
def __init__(self, server_list_cfg, client_cfg, sys_path=None):
if sys_path is not None:
import sys
sys.path.append(sys_path)
try:
import mc
except ImportError:
raise ImportError('Please install memcached to enable MemcachedBackend.')
self.server_list_cfg = server_list_cfg
self.client_cfg = client_cfg
self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
# mc.pyvector servers as a point which points to a memory cache
self._mc_buffer = mc.pyvector()
def get(self, filepath):
filepath = str(filepath)
import mc
self._client.Get(filepath, self._mc_buffer)
value_buf = mc.ConvertBuffer(self._mc_buffer)
return value_buf
def get_text(self, filepath):
raise NotImplementedError
class HardDiskBackend(BaseStorageBackend):
"""Raw hard disks storage backend."""
def get(self, filepath):
filepath = str(filepath)
with open(filepath, 'rb') as f:
value_buf = f.read()
return value_buf
def get_text(self, filepath):
filepath = str(filepath)
with open(filepath, 'r') as f:
value_buf = f.read()
return value_buf
class LmdbBackend(BaseStorageBackend):
"""Lmdb storage backend.
Args:
db_paths (str | list[str]): Lmdb database paths.
client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
readonly (bool, optional): Lmdb environment parameter. If True,
disallow any write operations. Default: True.
lock (bool, optional): Lmdb environment parameter. If False, when
concurrent access occurs, do not lock the database. Default: False.
readahead (bool, optional): Lmdb environment parameter. If False,
disable the OS filesystem readahead mechanism, which may improve
random read performance when a database is larger than RAM.
Default: False.
Attributes:
db_paths (list): Lmdb database path.
_client (list): A list of several lmdb envs.
"""
def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
try:
import lmdb
except ImportError:
raise ImportError('Please install lmdb to enable LmdbBackend.')
if isinstance(client_keys, str):
client_keys = [client_keys]
if isinstance(db_paths, list):
self.db_paths = [str(v) for v in db_paths]
elif isinstance(db_paths, str):
self.db_paths = [str(db_paths)]
assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
f'but received {len(client_keys)} and {len(self.db_paths)}.')
self._client = {}
for client, path in zip(client_keys, self.db_paths):
self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
def get(self, filepath, client_key):
"""Get values according to the filepath from one lmdb named client_key.
Args:
filepath (str | obj:`Path`): Here, filepath is the lmdb key.
client_key (str): Used for distinguishing different lmdb envs.
"""
filepath = str(filepath)
assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
client = self._client[client_key]
with client.begin(write=False) as txn:
value_buf = txn.get(filepath.encode('ascii'))
return value_buf
def get_text(self, filepath):
raise NotImplementedError
class FileClient(object):
"""A general file client to access files in different backend.
The client loads a file or text in a specified backend from its path
and return it as a binary file. it can also register other backend
accessor with a given name and backend class.
Attributes:
backend (str): The storage backend type. Options are "disk",
"memcached" and "lmdb".
client (:obj:`BaseStorageBackend`): The backend object.
"""
_backends = {
'disk': HardDiskBackend,
'memcached': MemcachedBackend,
'lmdb': LmdbBackend,
}
def __init__(self, backend='disk', **kwargs):
if backend not in self._backends:
raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
f' are {list(self._backends.keys())}')
self.backend = backend
self.client = self._backends[backend](**kwargs)
def get(self, filepath, client_key='default'):
# client_key is used only for lmdb, where different fileclients have
# different lmdb environments.
if self.backend == 'lmdb':
return self.client.get(filepath, client_key)
else:
return self.client.get(filepath)
def get_text(self, filepath):
return self.client.get_text(filepath)

170
basicsr/utils/flow_util.py Normal file
View File

@@ -0,0 +1,170 @@
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
import cv2
import numpy as np
import os
def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
"""Read an optical flow map.
Args:
flow_path (ndarray or str): Flow path.
quantize (bool): whether to read quantized pair, if set to True,
remaining args will be passed to :func:`dequantize_flow`.
concat_axis (int): The axis that dx and dy are concatenated,
can be either 0 or 1. Ignored if quantize is False.
Returns:
ndarray: Optical flow represented as a (h, w, 2) numpy array
"""
if quantize:
assert concat_axis in [0, 1]
cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
if cat_flow.ndim != 2:
raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
assert cat_flow.shape[concat_axis] % 2 == 0
dx, dy = np.split(cat_flow, 2, axis=concat_axis)
flow = dequantize_flow(dx, dy, *args, **kwargs)
else:
with open(flow_path, 'rb') as f:
try:
header = f.read(4).decode('utf-8')
except Exception:
raise IOError(f'Invalid flow file: {flow_path}')
else:
if header != 'PIEH':
raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')
w = np.fromfile(f, np.int32, 1).squeeze()
h = np.fromfile(f, np.int32, 1).squeeze()
flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
return flow.astype(np.float32)
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
"""Write optical flow to file.
If the flow is not quantized, it will be saved as a .flo file losslessly,
otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
will be concatenated horizontally into a single image if quantize is True.)
Args:
flow (ndarray): (h, w, 2) array of optical flow.
filename (str): Output filepath.
quantize (bool): Whether to quantize the flow and save it to 2 jpeg
images. If set to True, remaining args will be passed to
:func:`quantize_flow`.
concat_axis (int): The axis that dx and dy are concatenated,
can be either 0 or 1. Ignored if quantize is False.
"""
if not quantize:
with open(filename, 'wb') as f:
f.write('PIEH'.encode('utf-8'))
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
flow = flow.astype(np.float32)
flow.tofile(f)
f.flush()
else:
assert concat_axis in [0, 1]
dx, dy = quantize_flow(flow, *args, **kwargs)
dxdy = np.concatenate((dx, dy), axis=concat_axis)
os.makedirs(os.path.dirname(filename), exist_ok=True)
cv2.imwrite(filename, dxdy)
def quantize_flow(flow, max_val=0.02, norm=True):
"""Quantize flow to [0, 255].
After this step, the size of flow will be much smaller, and can be
dumped as jpeg images.
Args:
flow (ndarray): (h, w, 2) array of optical flow.
max_val (float): Maximum value of flow, values beyond
[-max_val, max_val] will be truncated.
norm (bool): Whether to divide flow values by image width/height.
Returns:
tuple[ndarray]: Quantized dx and dy.
"""
h, w, _ = flow.shape
dx = flow[..., 0]
dy = flow[..., 1]
if norm:
dx = dx / w # avoid inplace operations
dy = dy / h
# use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
return tuple(flow_comps)
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
"""Recover from quantized flow.
Args:
dx (ndarray): Quantized dx.
dy (ndarray): Quantized dy.
max_val (float): Maximum value used when quantizing.
denorm (bool): Whether to multiply flow values with width/height.
Returns:
ndarray: Dequantized flow.
"""
assert dx.shape == dy.shape
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
if denorm:
dx *= dx.shape[1]
dy *= dx.shape[0]
flow = np.dstack((dx, dy))
return flow
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
"""Quantize an array of (-inf, inf) to [0, levels-1].
Args:
arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped.
levels (int): Quantization levels.
dtype (np.type): The type of the quantized array.
Returns:
tuple: Quantized array.
"""
if not (isinstance(levels, int) and levels > 1):
raise ValueError(f'levels must be a positive integer, but got {levels}')
if min_val >= max_val:
raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
arr = np.clip(arr, min_val, max_val) - min_val
quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
return quantized_arr
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
"""Dequantize an array.
Args:
arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped.
levels (int): Quantization levels.
dtype (np.type): The type of the dequantized array.
Returns:
tuple: Dequantized array.
"""
if not (isinstance(levels, int) and levels > 1):
raise ValueError(f'levels must be a positive integer, but got {levels}')
if min_val >= max_val:
raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
return dequantized_arr

View File

@@ -0,0 +1,83 @@
import cv2
import numpy as np
import torch
from torch.nn import functional as F
def filter2D(img, kernel):
"""PyTorch version of cv2.filter2D
Args:
img (Tensor): (b, c, h, w)
kernel (Tensor): (b, k, k)
"""
k = kernel.size(-1)
b, c, h, w = img.size()
if k % 2 == 1:
img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
else:
raise ValueError('Wrong kernel size')
ph, pw = img.size()[-2:]
if kernel.size(0) == 1:
# apply the same kernel to all batch images
img = img.view(b * c, 1, ph, pw)
kernel = kernel.view(1, 1, k, k)
return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
else:
img = img.view(1, b * c, ph, pw)
kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
"""USM sharpening.
Input image: I; Blurry image: B.
1. sharp = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * sharp + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 1.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int):
"""
if radius % 2 == 0:
radius += 1
blur = cv2.GaussianBlur(img, (radius, radius), 0)
residual = img - blur
mask = np.abs(residual) * 255 > threshold
mask = mask.astype('float32')
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
sharp = img + weight * residual
sharp = np.clip(sharp, 0, 1)
return soft_mask * sharp + (1 - soft_mask) * img
class USMSharp(torch.nn.Module):
def __init__(self, radius=50, sigma=0):
super(USMSharp, self).__init__()
if radius % 2 == 0:
radius += 1
self.radius = radius
kernel = cv2.getGaussianKernel(radius, sigma)
kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
self.register_buffer('kernel', kernel)
def forward(self, img, weight=0.5, threshold=10):
blur = filter2D(img, self.kernel)
residual = img - blur
mask = torch.abs(residual) * 255 > threshold
mask = mask.float()
soft_mask = filter2D(mask, self.kernel)
sharp = img + weight * residual
sharp = torch.clip(sharp, 0, 1)
return soft_mask * sharp + (1 - soft_mask) * img

172
basicsr/utils/img_util.py Normal file
View File

@@ -0,0 +1,172 @@
import cv2
import math
import numpy as np
import os
import torch
from torchvision.utils import make_grid
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
"""Convert torch Tensors into image numpy arrays.
After clamping to [min, max], values will be normalized to [0, 1].
Args:
tensor (Tensor or list[Tensor]): Accept shapes:
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
2) 3D Tensor of shape (3/1 x H x W);
3) 2D Tensor of shape (H x W).
Tensor channel should be in RGB order.
rgb2bgr (bool): Whether to change rgb to bgr.
out_type (numpy type): output types. If ``np.uint8``, transform outputs
to uint8 type with range [0, 255]; otherwise, float type with
range [0, 1]. Default: ``np.uint8``.
min_max (tuple[int]): min and max values for clamp.
Returns:
(Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
shape (H x W). The channel order is BGR.
"""
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
if torch.is_tensor(tensor):
tensor = [tensor]
result = []
for _tensor in tensor:
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
n_dim = _tensor.dim()
if n_dim == 4:
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
img_np = img_np.transpose(1, 2, 0)
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 3:
img_np = _tensor.numpy()
img_np = img_np.transpose(1, 2, 0)
if img_np.shape[2] == 1: # gray image
img_np = np.squeeze(img_np, axis=2)
else:
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 2:
img_np = _tensor.numpy()
else:
raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
if out_type == np.uint8:
# Unlike MATLAB, numpy.unit8() WILL NOT round by default.
img_np = (img_np * 255.0).round()
img_np = img_np.astype(out_type)
result.append(img_np)
if len(result) == 1 and torch.is_tensor(tensor):
result = result[0]
return result
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
"""This implementation is slightly faster than tensor2img.
It now only supports torch tensor with shape (1, c, h, w).
Args:
tensor (Tensor): Now only support torch tensor with (1, c, h, w).
rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
min_max (tuple[int]): min and max values for clamp.
"""
output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
output = output.type(torch.uint8).cpu().numpy()
if rgb2bgr:
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
def imfrombytes(content, flag='color', float32=False):
"""Read an image from bytes.
Args:
content (bytes): Image bytes got from files or other streams.
flag (str): Flags specifying the color type of a loaded image,
candidates are `color`, `grayscale` and `unchanged`.
float32 (bool): Whether to change to float32., If True, will also norm
to [0, 1]. Default: False.
Returns:
ndarray: Loaded image array.
"""
img_np = np.frombuffer(content, np.uint8)
imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
img = cv2.imdecode(img_np, imread_flags[flag])
if float32:
img = img.astype(np.float32) / 255.
return img
def imwrite(img, file_path, params=None, auto_mkdir=True):
"""Write image to file.
Args:
img (ndarray): Image array to be written.
file_path (str): Image file path.
params (None or list): Same as opencv's :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically.
Returns:
bool: Successful or not.
"""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(file_path))
os.makedirs(dir_name, exist_ok=True)
ok = cv2.imwrite(file_path, img, params)
if not ok:
raise IOError('Failed in writing images.')
def crop_border(imgs, crop_border):
"""Crop borders of images.
Args:
imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
crop_border (int): Crop border for each end of height and weight.
Returns:
list[ndarray]: Cropped images.
"""
if crop_border == 0:
return imgs
else:
if isinstance(imgs, list):
return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
else:
return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]

199
basicsr/utils/lmdb_util.py Normal file
View File

@@ -0,0 +1,199 @@
import cv2
import lmdb
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm
def make_lmdb_from_imgs(data_path,
lmdb_path,
img_path_list,
keys,
batch=5000,
compress_level=1,
multiprocessing_read=False,
n_thread=40,
map_size=None):
"""Make lmdb from images.
Contents of lmdb. The file structure is:
::
example.lmdb
├── data.mdb
├── lock.mdb
├── meta_info.txt
The data.mdb and lock.mdb are standard lmdb files and you can refer to
https://lmdb.readthedocs.io/en/release/ for more details.
The meta_info.txt is a specified txt file to record the meta information
of our datasets. It will be automatically created when preparing
datasets by our provided dataset tools.
Each line in the txt file records 1)image name (with extension),
2)image shape, and 3)compression level, separated by a white space.
For example, the meta information could be:
`000_00000000.png (720,1280,3) 1`, which means:
1) image name (with extension): 000_00000000.png;
2) image shape: (720,1280,3);
3) compression level: 1
We use the image name without extension as the lmdb key.
If `multiprocessing_read` is True, it will read all the images to memory
using multiprocessing. Thus, your server needs to have enough memory.
Args:
data_path (str): Data path for reading images.
lmdb_path (str): Lmdb save path.
img_path_list (str): Image path list.
keys (str): Used for lmdb keys.
batch (int): After processing batch images, lmdb commits.
Default: 5000.
compress_level (int): Compress level when encoding images. Default: 1.
multiprocessing_read (bool): Whether use multiprocessing to read all
the images to memory. Default: False.
n_thread (int): For multiprocessing.
map_size (int | None): Map size for lmdb env. If None, use the
estimated size from images. Default: None
"""
assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
f'but got {len(img_path_list)} and {len(keys)}')
print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
print(f'Totoal images: {len(img_path_list)}')
if not lmdb_path.endswith('.lmdb'):
raise ValueError("lmdb_path must end with '.lmdb'.")
if osp.exists(lmdb_path):
print(f'Folder {lmdb_path} already exists. Exit.')
sys.exit(1)
if multiprocessing_read:
# read all the images to memory (multiprocessing)
dataset = {} # use dict to keep the order for multiprocessing
shapes = {}
print(f'Read images with multiprocessing, #thread: {n_thread} ...')
pbar = tqdm(total=len(img_path_list), unit='image')
def callback(arg):
"""get the image data and update pbar."""
key, dataset[key], shapes[key] = arg
pbar.update(1)
pbar.set_description(f'Read {key}')
pool = Pool(n_thread)
for path, key in zip(img_path_list, keys):
pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
pool.close()
pool.join()
pbar.close()
print(f'Finish reading {len(img_path_list)} images.')
# create lmdb environment
if map_size is None:
# obtain data size for one image
img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
data_size_per_img = img_byte.nbytes
print('Data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(img_path_list)
map_size = data_size * 10
env = lmdb.open(lmdb_path, map_size=map_size)
# write data to lmdb
pbar = tqdm(total=len(img_path_list), unit='chunk')
txn = env.begin(write=True)
txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
for idx, (path, key) in enumerate(zip(img_path_list, keys)):
pbar.update(1)
pbar.set_description(f'Write {key}')
key_byte = key.encode('ascii')
if multiprocessing_read:
img_byte = dataset[key]
h, w, c = shapes[key]
else:
_, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
h, w, c = img_shape
txn.put(key_byte, img_byte)
# write meta information
txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
if idx % batch == 0:
txn.commit()
txn = env.begin(write=True)
pbar.close()
txn.commit()
env.close()
txt_file.close()
print('\nFinish writing lmdb.')
def read_img_worker(path, key, compress_level):
"""Read image worker.
Args:
path (str): Image path.
key (str): Image key.
compress_level (int): Compress level when encoding images.
Returns:
str: Image key.
byte: Image byte.
tuple[int]: Image shape.
"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if img.ndim == 2:
h, w = img.shape
c = 1
else:
h, w, c = img.shape
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
return (key, img_byte, (h, w, c))
class LmdbMaker():
"""LMDB Maker.
Args:
lmdb_path (str): Lmdb save path.
map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
batch (int): After processing batch images, lmdb commits.
Default: 5000.
compress_level (int): Compress level when encoding images. Default: 1.
"""
def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
if not lmdb_path.endswith('.lmdb'):
raise ValueError("lmdb_path must end with '.lmdb'.")
if osp.exists(lmdb_path):
print(f'Folder {lmdb_path} already exists. Exit.')
sys.exit(1)
self.lmdb_path = lmdb_path
self.batch = batch
self.compress_level = compress_level
self.env = lmdb.open(lmdb_path, map_size=map_size)
self.txn = self.env.begin(write=True)
self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
self.counter = 0
def put(self, img_byte, key, img_shape):
self.counter += 1
key_byte = key.encode('ascii')
self.txn.put(key_byte, img_byte)
# write meta information
h, w, c = img_shape
self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
if self.counter % self.batch == 0:
self.txn.commit()
self.txn = self.env.begin(write=True)
def close(self):
self.txn.commit()
self.env.close()
self.txt_file.close()

213
basicsr/utils/logger.py Normal file
View File

@@ -0,0 +1,213 @@
import datetime
import logging
import time
from .dist_util import get_dist_info, master_only
initialized_logger = {}
class AvgTimer():
def __init__(self, window=200):
self.window = window # average window
self.current_time = 0
self.total_time = 0
self.count = 0
self.avg_time = 0
self.start()
def start(self):
self.start_time = self.tic = time.time()
def record(self):
self.count += 1
self.toc = time.time()
self.current_time = self.toc - self.tic
self.total_time += self.current_time
# calculate average time
self.avg_time = self.total_time / self.count
# reset
if self.count > self.window:
self.count = 0
self.total_time = 0
self.tic = time.time()
def get_current_time(self):
return self.current_time
def get_avg_time(self):
return self.avg_time
class MessageLogger():
"""Message logger for printing.
Args:
opt (dict): Config. It contains the following keys:
name (str): Exp name.
logger (dict): Contains 'print_freq' (str) for logger interval.
train (dict): Contains 'total_iter' (int) for total iters.
use_tb_logger (bool): Use tensorboard logger.
start_iter (int): Start iter. Default: 1.
tb_logger (obj:`tb_logger`): Tensorboard logger. Default None.
"""
def __init__(self, opt, start_iter=1, tb_logger=None):
self.exp_name = opt['name']
self.interval = opt['logger']['print_freq']
self.start_iter = start_iter
self.max_iters = opt['train']['total_iter']
self.use_tb_logger = opt['logger']['use_tb_logger']
self.tb_logger = tb_logger
self.start_time = time.time()
self.logger = get_root_logger()
def reset_start_time(self):
self.start_time = time.time()
@master_only
def __call__(self, log_vars):
"""Format logging message.
Args:
log_vars (dict): It contains the following keys:
epoch (int): Epoch number.
iter (int): Current iter.
lrs (list): List for learning rates.
time (float): Iter time.
data_time (float): Data time for each iter.
"""
# epoch, iter, learning rates
epoch = log_vars.pop('epoch')
current_iter = log_vars.pop('iter')
lrs = log_vars.pop('lrs')
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
for v in lrs:
message += f'{v:.3e},'
message += ')] '
# time and estimated time
if 'time' in log_vars.keys():
iter_time = log_vars.pop('time')
data_time = log_vars.pop('data_time')
total_time = time.time() - self.start_time
time_sec_avg = total_time / (current_iter - self.start_iter + 1)
eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
message += f'[eta: {eta_str}, '
message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
# other items, especially losses
for k, v in log_vars.items():
message += f'{k}: {v:.4e} '
# tensorboard logger
if self.use_tb_logger and 'debug' not in self.exp_name:
if k.startswith('l_'):
self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
else:
self.tb_logger.add_scalar(k, v, current_iter)
self.logger.info(message)
@master_only
def init_tb_logger(log_dir):
from torch.utils.tensorboard import SummaryWriter
tb_logger = SummaryWriter(log_dir=log_dir)
return tb_logger
@master_only
def init_wandb_logger(opt):
"""We now only use wandb to sync tensorboard log."""
import wandb
logger = get_root_logger()
project = opt['logger']['wandb']['project']
resume_id = opt['logger']['wandb'].get('resume_id')
if resume_id:
wandb_id = resume_id
resume = 'allow'
logger.warning(f'Resume wandb logger with id={wandb_id}.')
else:
wandb_id = wandb.util.generate_id()
resume = 'never'
wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
"""Get the root logger.
The logger will be initialized if it has not been initialized. By default a
StreamHandler will be added. If `log_file` is specified, a FileHandler will
also be added.
Args:
logger_name (str): root logger name. Default: 'basicsr'.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the root logger.
log_level (int): The root logger level. Note that only the process of
rank 0 is affected, while other processes will set the level to
"Error" and be silent most of the time.
Returns:
logging.Logger: The root logger.
"""
logger = logging.getLogger(logger_name)
# if the logger has been initialized, just return it
if logger_name in initialized_logger:
return logger
format_str = '%(asctime)s %(levelname)s: %(message)s'
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(format_str))
logger.addHandler(stream_handler)
logger.propagate = False
rank, _ = get_dist_info()
if rank != 0:
logger.setLevel('ERROR')
elif log_file is not None:
logger.setLevel(log_level)
# add file handler
file_handler = logging.FileHandler(log_file, 'w')
file_handler.setFormatter(logging.Formatter(format_str))
file_handler.setLevel(log_level)
logger.addHandler(file_handler)
initialized_logger[logger_name] = True
return logger
def get_env_info():
"""Get environment information.
Currently, only log the software version.
"""
import torch
import torchvision
from basicsr.version import __version__
msg = r"""
____ _ _____ ____
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
/ __ |/ __ `// ___// // ___/\__ \ / /_/ /
/ /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
/_____/ \__,_//____//_/ \___//____//_/ |_|
______ __ __ __ __
/ ____/____ ____ ____/ / / / __ __ _____ / /__ / /
/ / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
"""
msg += ('\nVersion Information: '
f'\n\tBasicSR: {__version__}'
f'\n\tPyTorch: {torch.__version__}'
f'\n\tTorchVision: {torchvision.__version__}')
return msg

View File

@@ -0,0 +1,178 @@
import math
import numpy as np
import torch
def cubic(x):
"""cubic function used for calculate_weights_indices."""
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
(absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
"""Calculate weights and indices, used for imresize function.
Args:
in_length (int): Input length.
out_length (int): Output length.
scale (float): Scale factor.
kernel_width (int): Kernel width.
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
"""
if (scale < 1) and antialiasing:
# Use a modified kernel (larger kernel width) to simultaneously
# interpolate and antialias
kernel_width = kernel_width / scale
# Output-space coordinates
x = torch.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5 + scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
p = math.ceil(kernel_width) + 2
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
out_length, p)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
# apply cubic kernel
if (scale < 1) and antialiasing:
weights = scale * cubic(distance_to_center * scale)
else:
weights = cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = torch.sum(weights, 1).view(out_length, 1)
weights = weights / weights_sum.expand(out_length, p)
# If a column in weights is all zero, get rid of it. only consider the
# first and last column.
weights_zero_tmp = torch.sum((weights == 0), 0)
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
indices = indices.narrow(1, 1, p - 2)
weights = weights.narrow(1, 1, p - 2)
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
indices = indices.narrow(1, 0, p - 2)
weights = weights.narrow(1, 0, p - 2)
weights = weights.contiguous()
indices = indices.contiguous()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
@torch.no_grad()
def imresize(img, scale, antialiasing=True):
"""imresize function same as MATLAB.
It now only supports bicubic.
The same scale applies for both height and width.
Args:
img (Tensor | Numpy array):
Tensor: Input image with shape (c, h, w), [0, 1] range.
Numpy: Input image with shape (h, w, c), [0, 1] range.
scale (float): Scale factor. The same scale applies for both height
and width.
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
Default: True.
Returns:
Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
"""
squeeze_flag = False
if type(img).__module__ == np.__name__: # numpy type
numpy_type = True
if img.ndim == 2:
img = img[:, :, None]
squeeze_flag = True
img = torch.from_numpy(img.transpose(2, 0, 1)).float()
else:
numpy_type = False
if img.ndim == 2:
img = img.unsqueeze(0)
squeeze_flag = True
in_c, in_h, in_w = img.size()
out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
kernel_width = 4
kernel = 'cubic'
# get weights and indices
weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
antialiasing)
weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
antialiasing)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
sym_patch = img[:, :sym_len_hs, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
sym_patch = img[:, -sym_len_he:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(in_c, out_h, in_w)
kernel_width = weights_h.size(1)
for i in range(out_h):
idx = int(indices_h[i][0])
for j in range(in_c):
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
sym_patch = out_1[:, :, :sym_len_ws]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
sym_patch = out_1[:, :, -sym_len_we:]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(in_c, out_h, out_w)
kernel_width = weights_w.size(1)
for i in range(out_w):
idx = int(indices_w[i][0])
for j in range(in_c):
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
if squeeze_flag:
out_2 = out_2.squeeze(0)
if numpy_type:
out_2 = out_2.numpy()
if not squeeze_flag:
out_2 = out_2.transpose(1, 2, 0)
return out_2

141
basicsr/utils/misc.py Normal file
View File

@@ -0,0 +1,141 @@
import numpy as np
import os
import random
import time
import torch
from os import path as osp
from .dist_util import master_only
def set_random_seed(seed):
"""Set random seeds."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_time_str():
return time.strftime('%Y%m%d_%H%M%S', time.localtime())
def mkdir_and_rename(path):
"""mkdirs. If path exists, rename it with timestamp and create a new one.
Args:
path (str): Folder path.
"""
if osp.exists(path):
new_name = path + '_archived_' + get_time_str()
print(f'Path already exists. Rename it to {new_name}', flush=True)
os.rename(path, new_name)
os.makedirs(path, exist_ok=True)
@master_only
def make_exp_dirs(opt):
"""Make dirs for experiments."""
path_opt = opt['path'].copy()
if opt['is_train']:
mkdir_and_rename(path_opt.pop('experiments_root'))
else:
mkdir_and_rename(path_opt.pop('results_root'))
for key, path in path_opt.items():
if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
continue
else:
os.makedirs(path, exist_ok=True)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
full_path (bool, optional): If set to True, include the dir_path.
Default: False.
Returns:
A generator for all the interested files with relative paths.
"""
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
if full_path:
return_path = entry.path
else:
return_path = osp.relpath(entry.path, root)
if suffix is None:
yield return_path
elif return_path.endswith(suffix):
yield return_path
else:
if recursive:
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)
def check_resume(opt, resume_iter):
"""Check resume states and pretrain_network paths.
Args:
opt (dict): Options.
resume_iter (int): Resume iteration.
"""
if opt['path']['resume_state']:
# get all the networks
networks = [key for key in opt.keys() if key.startswith('network_')]
flag_pretrain = False
for network in networks:
if opt['path'].get(f'pretrain_{network}') is not None:
flag_pretrain = True
if flag_pretrain:
print('pretrain_network path will be ignored during resuming.')
# set pretrained model paths
for network in networks:
name = f'pretrain_{network}'
basename = network.replace('network_', '')
if opt['path'].get('ignore_resume_networks') is None or (network
not in opt['path']['ignore_resume_networks']):
opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
print(f"Set {name} to {opt['path'][name]}")
# change param_key to params in resume
param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
for param_key in param_keys:
if opt['path'][param_key] == 'params_ema':
opt['path'][param_key] = 'params'
print(f'Set {param_key} to params')
def sizeof_fmt(size, suffix='B'):
"""Get human readable file size.
Args:
size (int): File size.
suffix (str): Suffix. Default: 'B'.
Return:
str: Formatted file size.
"""
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
if abs(size) < 1024.0:
return f'{size:3.1f} {unit}{suffix}'
size /= 1024.0
return f'{size:3.1f} Y{suffix}'

210
basicsr/utils/options.py Normal file
View File

@@ -0,0 +1,210 @@
import argparse
import os
import random
import torch
import yaml
from collections import OrderedDict
from os import path as osp
from basicsr.utils import set_random_seed
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
def ordered_yaml():
"""Support OrderedDict for yaml.
Returns:
tuple: yaml Loader and Dumper.
"""
try:
from yaml import CDumper as Dumper
from yaml import CLoader as Loader
except ImportError:
from yaml import Dumper, Loader
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
def dict_representer(dumper, data):
return dumper.represent_dict(data.items())
def dict_constructor(loader, node):
return OrderedDict(loader.construct_pairs(node))
Dumper.add_representer(OrderedDict, dict_representer)
Loader.add_constructor(_mapping_tag, dict_constructor)
return Loader, Dumper
def yaml_load(f):
"""Load yaml file or string.
Args:
f (str): File path or a python string.
Returns:
dict: Loaded dict.
"""
if os.path.isfile(f):
with open(f, 'r') as f:
return yaml.load(f, Loader=ordered_yaml()[0])
else:
return yaml.load(f, Loader=ordered_yaml()[0])
def dict2str(opt, indent_level=1):
"""dict to string for printing options.
Args:
opt (dict): Option dict.
indent_level (int): Indent level. Default: 1.
Return:
(str): Option string for printing.
"""
msg = '\n'
for k, v in opt.items():
if isinstance(v, dict):
msg += ' ' * (indent_level * 2) + k + ':['
msg += dict2str(v, indent_level + 1)
msg += ' ' * (indent_level * 2) + ']\n'
else:
msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
return msg
def _postprocess_yml_value(value):
# None
if value == '~' or value.lower() == 'none':
return None
# bool
if value.lower() == 'true':
return True
elif value.lower() == 'false':
return False
# !!float number
if value.startswith('!!float'):
return float(value.replace('!!float', ''))
# number
if value.isdigit():
return int(value)
elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
return float(value)
# list
if value.startswith('['):
return eval(value)
# str
return value
def parse_options(root_path, is_train=True):
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
parser.add_argument('--auto_resume', action='store_true')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
args = parser.parse_args()
# parse yml to dict
opt = yaml_load(args.opt)
# distributed settings
if args.launcher == 'none':
opt['dist'] = False
print('Disable distributed.', flush=True)
else:
opt['dist'] = True
if args.launcher == 'slurm' and 'dist_params' in opt:
init_dist(args.launcher, **opt['dist_params'])
else:
init_dist(args.launcher)
opt['rank'], opt['world_size'] = get_dist_info()
# random seed
seed = opt.get('manual_seed')
if seed is None:
seed = random.randint(1, 10000)
opt['manual_seed'] = seed
set_random_seed(seed + opt['rank'])
# force to update yml options
if args.force_yml is not None:
for entry in args.force_yml:
# now do not support creating new keys
keys, value = entry.split('=')
keys, value = keys.strip(), value.strip()
value = _postprocess_yml_value(value)
eval_str = 'opt'
for key in keys.split(':'):
eval_str += f'["{key}"]'
eval_str += '=value'
# using exec function
exec(eval_str)
opt['auto_resume'] = args.auto_resume
opt['is_train'] = is_train
# debug setting
if args.debug and not opt['name'].startswith('debug'):
opt['name'] = 'debug_' + opt['name']
if opt['num_gpu'] == 'auto':
opt['num_gpu'] = torch.cuda.device_count()
# datasets
for phase, dataset in opt['datasets'].items():
# for multiple datasets, e.g., val_1, val_2; test_1, test_2
phase = phase.split('_')[0]
dataset['phase'] = phase
if 'scale' in opt:
dataset['scale'] = opt['scale']
if dataset.get('dataroot_gt') is not None:
dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
if dataset.get('dataroot_lq') is not None:
dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
# paths
for key, val in opt['path'].items():
if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
opt['path'][key] = osp.expanduser(val)
if is_train:
experiments_root = osp.join(root_path, 'experiments', opt['name'])
opt['path']['experiments_root'] = experiments_root
opt['path']['models'] = osp.join(experiments_root, 'models')
opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
opt['path']['log'] = experiments_root
opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
# change some options for debug mode
if 'debug' in opt['name']:
if 'val' in opt:
opt['val']['val_freq'] = 8
opt['logger']['print_freq'] = 1
opt['logger']['save_checkpoint_freq'] = 8
else: # test
results_root = osp.join(root_path, 'results', opt['name'])
opt['path']['results_root'] = results_root
opt['path']['log'] = results_root
opt['path']['visualization'] = osp.join(results_root, 'visualization')
return opt, args
@master_only
def copy_opt_file(opt_file, experiments_root):
# copy the yml file to the experiment root
import sys
import time
from shutil import copyfile
cmd = ' '.join(sys.argv)
filename = osp.join(experiments_root, osp.basename(opt_file))
copyfile(opt_file, filename)
with open(filename, 'r+') as f:
lines = f.readlines()
lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
f.seek(0)
f.writelines(lines)

View File

@@ -0,0 +1,83 @@
import re
def read_data_from_tensorboard(log_path, tag):
"""Get raw data (steps and values) from tensorboard events.
Args:
log_path (str): Path to the tensorboard log.
tag (str): tag to be read.
"""
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
# tensorboard event
event_acc = EventAccumulator(log_path)
event_acc.Reload()
scalar_list = event_acc.Tags()['scalars']
print('tag list: ', scalar_list)
steps = [int(s.step) for s in event_acc.Scalars(tag)]
values = [s.value for s in event_acc.Scalars(tag)]
return steps, values
def read_data_from_txt_2v(path, pattern, step_one=False):
"""Read data from txt with 2 returned values (usually [step, value]).
Args:
path (str): path to the txt file.
pattern (str): re (regular expression) pattern.
step_one (bool): add 1 to steps. Default: False.
"""
with open(path) as f:
lines = f.readlines()
lines = [line.strip() for line in lines]
steps = []
values = []
pattern = re.compile(pattern)
for line in lines:
match = pattern.match(line)
if match:
steps.append(int(match.group(1)))
values.append(float(match.group(2)))
if step_one:
steps = [v + 1 for v in steps]
return steps, values
def read_data_from_txt_1v(path, pattern):
"""Read data from txt with 1 returned values.
Args:
path (str): path to the txt file.
pattern (str): re (regular expression) pattern.
"""
with open(path) as f:
lines = f.readlines()
lines = [line.strip() for line in lines]
data = []
pattern = re.compile(pattern)
for line in lines:
match = pattern.match(line)
if match:
data.append(float(match.group(1)))
return data
def smooth_data(values, smooth_weight):
""" Smooth data using 1st-order IIR low-pass filter (what tensorflow does).
Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501
Args:
values (list): A list of values to be smoothed.
smooth_weight (float): Smooth weight.
"""
values_sm = []
last_sm_value = values[0]
for value in values:
value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value
values_sm.append(value_sm)
last_sm_value = value_sm
return values_sm

View File

@@ -0,0 +1,293 @@
import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F
# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class RealESRGANer():
"""A helper class for upsampling images with RealESRGAN.
Args:
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
model (nn.Module): The defined network. Default: None.
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
input images into tiles, and then process each of them. Finally, they will be merged into one image.
0 denotes for do not use tile. Default: 0.
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
half (float): Whether to use half precision during inference. Default: False.
"""
def __init__(self,
scale,
model_path,
model=None,
tile=0,
tile_pad=10,
pre_pad=10,
half=False,
device=None,
gpu_id=None):
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.mod_scale = None
self.half = half
# initialize model
if gpu_id:
self.device = torch.device(
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
else:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
if model_path.startswith('https://'):
model_path = load_file_from_url(
url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
# prefer to use params_ema
if 'params_ema' in loadnet:
keyname = 'params_ema'
else:
keyname = 'params'
model.load_state_dict(loadnet[keyname], strict=True)
model.eval()
self.model = model.to(self.device)
if self.half:
self.model = self.model.half()
def pre_process(self, img):
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
"""
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img.unsqueeze(0).to(self.device)
if self.half:
self.img = self.img.half()
# pre_pad
if self.pre_pad != 0:
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
# mod pad for divisible borders
if self.scale == 2:
self.mod_scale = 2
elif self.scale == 1:
self.mod_scale = 4
if self.mod_scale is not None:
self.mod_pad_h, self.mod_pad_w = 0, 0
_, _, h, w = self.img.size()
if (h % self.mod_scale != 0):
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
if (w % self.mod_scale != 0):
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
def process(self):
# model inference
self.output = self.model(self.img)
def tile_process(self):
"""It will first crop input images to tiles, and then process each tile.
Finally, all the processed tiles are merged into one images.
Modified from: https://github.com/ata4/esrgan-launcher
"""
batch, channel, height, width = self.img.shape
output_height = height * self.scale
output_width = width * self.scale
output_shape = (batch, channel, output_height, output_width)
# start with black image
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.tile_size)
tiles_y = math.ceil(height / self.tile_size)
# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * self.tile_size
ofs_y = y * self.tile_size
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + self.tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + self.tile_size, height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
input_end_x_pad = min(input_end_x + self.tile_pad, width)
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
input_end_y_pad = min(input_end_y + self.tile_pad, height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
# upscale tile
try:
with torch.no_grad():
output_tile = self.model(input_tile)
except RuntimeError as error:
print('Error', error)
# print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
# output tile area on total image
output_start_x = input_start_x * self.scale
output_end_x = input_end_x * self.scale
output_start_y = input_start_y * self.scale
output_end_y = input_end_y * self.scale
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
# put tile into output image
self.output[:, :, output_start_y:output_end_y,
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile]
def post_process(self):
# remove extra pad
if self.mod_scale is not None:
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
# remove prepad
if self.pre_pad != 0:
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
return self.output
@torch.no_grad()
def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
h_input, w_input = img.shape[0:2]
# img: numpy
img = img.astype(np.float32)
if np.max(img) > 256: # 16-bit image
max_range = 65535
print('\tInput is a 16-bit image')
else:
max_range = 255
img = img / max_range
if len(img.shape) == 2: # gray image
img_mode = 'L'
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
elif img.shape[2] == 4: # RGBA image with alpha channel
img_mode = 'RGBA'
alpha = img[:, :, 3]
img = img[:, :, 0:3]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if alpha_upsampler == 'realesrgan':
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
else:
img_mode = 'RGB'
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ------------------- process image (without the alpha channel) ------------------- #
self.pre_process(img)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_img = self.post_process()
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
if img_mode == 'L':
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
# ------------------- process the alpha channel if necessary ------------------- #
if img_mode == 'RGBA':
if alpha_upsampler == 'realesrgan':
self.pre_process(alpha)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_alpha = self.post_process()
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
else: # use the cv2 resize for alpha channel
h, w = alpha.shape[0:2]
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
# merge the alpha channel
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
output_img[:, :, 3] = output_alpha
# ------------------------------ return ------------------------------ #
if max_range == 65535: # 16-bit image
output = (output_img * 65535.0).round().astype(np.uint16)
else:
output = (output_img * 255.0).round().astype(np.uint8)
if outscale is not None and outscale != float(self.scale):
output = cv2.resize(
output, (
int(w_input * outscale),
int(h_input * outscale),
), interpolation=cv2.INTER_LANCZOS4)
return output, img_mode
class PrefetchReader(threading.Thread):
"""Prefetch images.
Args:
img_list (list[str]): A image list of image paths to be read.
num_prefetch_queue (int): Number of prefetch queue.
"""
def __init__(self, img_list, num_prefetch_queue):
super().__init__()
self.que = queue.Queue(num_prefetch_queue)
self.img_list = img_list
def run(self):
for img_path in self.img_list:
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
self.que.put(img)
self.que.put(None)
def __next__(self):
next_item = self.que.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
return self
class IOConsumer(threading.Thread):
def __init__(self, opt, que, qid):
super().__init__()
self._queue = que
self.qid = qid
self.opt = opt
def run(self):
while True:
msg = self._queue.get()
if isinstance(msg, str) and msg == 'quit':
break
output = msg['output']
save_path = msg['save_path']
cv2.imwrite(save_path, output)
print(f'IO worker {self.qid} is done.')

88
basicsr/utils/registry.py Normal file
View File

@@ -0,0 +1,88 @@
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
class Registry():
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name):
"""
Args:
name (str): the name of this registry
"""
self._name = name
self._obj_map = {}
def _do_register(self, name, obj, suffix=None):
if isinstance(suffix, str):
name = name + '_' + suffix
assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
f"in '{self._name}' registry!")
self._obj_map[name] = obj
def register(self, obj=None, suffix=None):
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not.
See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class):
name = func_or_class.__name__
self._do_register(name, func_or_class, suffix)
return func_or_class
return deco
# used as a function call
name = obj.__name__
self._do_register(name, obj, suffix)
def get(self, name, suffix='basicsr'):
ret = self._obj_map.get(name)
if ret is None:
ret = self._obj_map.get(name + '_' + suffix)
print(f'Name {name} is not found, use name: {name}_{suffix}!')
if ret is None:
raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
return ret
def __contains__(self, name):
return name in self._obj_map
def __iter__(self):
return iter(self._obj_map.items())
def keys(self):
return self._obj_map.keys()
DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')