mirror of https://github.com/aljazceru/InvSR.git
synced 2025-12-18 14:54:24 +01:00

first commit

BIN  basicsr/.DS_Store  vendored  Normal file
Binary file not shown.

4  basicsr/__init__.py  Normal file
@@ -0,0 +1,4 @@
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .data import *
from .utils import *

BIN  basicsr/data/.DS_Store  vendored  Normal file
Binary file not shown.

101  basicsr/data/__init__.py  Normal file
@@ -0,0 +1,101 @@
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp

from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY

__all__ = ['build_dataset', 'build_dataloader']

# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]


def build_dataset(dataset_opt):
    """Build dataset from options.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name.
            type (str): Dataset type.
    """
    dataset_opt = deepcopy(dataset_opt)
    dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
    logger = get_root_logger()
    logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
    return dataset


def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train' or 'val'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()
    if phase == 'train':
        if dist:  # distributed training
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:  # non-distributed training
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True)
        if sampler is None:
            dataloader_args['shuffle'] = True
        dataloader_args['worker_init_fn'] = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
    elif phase in ['val', 'test']:  # validation
        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
    dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    else:
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)


def worker_init_fn(worker_id, num_workers, rank, seed):
    # Set the worker seed to num_workers * rank + worker_id + seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
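
For orientation, a minimal usage sketch of the two entry points above. The dataset type
'PairedImageDataset' and all option values are illustrative assumptions, not part of this commit:

from basicsr.data import build_dataset, build_dataloader

# hypothetical options dict; keys mirror the docstrings above
dataset_opt = {
    'name': 'demo',
    'type': 'PairedImageDataset',   # assumed to be a registered dataset type
    'phase': 'train',
    'batch_size_per_gpu': 4,
    'num_worker_per_gpu': 2,
    'dataroot_gt': 'datasets/demo/gt',
    'dataroot_lq': 'datasets/demo/lq',
    'io_backend': {'type': 'disk'},
}
train_set = build_dataset(dataset_opt)
train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=0)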

BIN  basicsr/data/__pycache__/__init__.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/__init__.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/data_util.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/data_util.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/degradations.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/degradations.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/ffhq_dataset.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/paired_image_dataset.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/prefetch_dataloader.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/realesrgan_dataset.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/reds_dataset.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/reds_dataset.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/single_image_dataset.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/transforms.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/transforms.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/video_test_dataset.cpython-38.pyc  Normal file
BIN  basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc  Normal file
BIN  basicsr/data/__pycache__/vimeo90k_dataset.cpython-38.pyc  Normal file
(Binary files not shown.)

48  basicsr/data/data_sampler.py  Normal file
@@ -0,0 +1,48 @@
import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    Supports enlarging the dataset for iteration-based training, saving
    time when restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
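
A sketch of how EnlargedSampler is typically wired into build_dataloader for iteration-based
training; the single-process values (num_replicas=1, rank=0) and the epoch counts are illustrative:

from basicsr.data import build_dataset, build_dataloader
from basicsr.data.data_sampler import EnlargedSampler

dataset = build_dataset(dataset_opt)  # dataset_opt as in the earlier sketch
sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=100)  # enlarge the epoch 100x
loader = build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=sampler, seed=0)
for epoch in range(0, 10):
    sampler.set_epoch(epoch)  # reseeds the deterministic shuffle per epoch
    for batch in loader:
        ...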

315  basicsr/data/data_util.py  Normal file
@@ -0,0 +1,315 @@
import cv2
import numpy as np
import torch
from os import path as osp
from torch.nn import functional as F

from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir


def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname (bool): Whether to return image names. Default: False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Returned image name list.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)

    if return_imgname:
        imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
        return imgs, imgnames
    else:
        return imgs


def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate an index list for reading `num_frames` frames from a sequence
    of images.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding modes:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    max_frame_num = max_frame_num - 1  # start from 0
    num_pad = num_frames // 2

    indices = []
    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
        if i < 0:
            if padding == 'replicate':
                pad_idx = 0
            elif padding == 'reflection':
                pad_idx = -i
            elif padding == 'reflection_circle':
                pad_idx = crt_idx + num_pad - i
            else:
                pad_idx = num_frames + i
        elif i > max_frame_num:
            if padding == 'replicate':
                pad_idx = max_frame_num
            elif padding == 'reflection':
                pad_idx = max_frame_num * 2 - i
            elif padding == 'reflection_circle':
                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
            else:
                pad_idx = i - num_frames
        else:
            pad_idx = i
        indices.append(pad_idx)
    return indices


def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    ::

        lq.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records
    1) image name (with extension),
    2) image shape,
    3) compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[dict]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
                         f'format. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')
    # ensure that the two meta_info files are the same
    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
        input_lmdb_keys = [line.split('.')[0] for line in fin]
    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
        gt_lmdb_keys = [line.split('.')[0] for line in fin]
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    else:
        paths = []
        for lmdb_key in sorted(input_lmdb_keys):
            paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
        return paths


def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image names and
    image shape (usually for gt), separated by a white space.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        gt_names = [line.strip().split(' ')[0] for line in fin]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        gt_path = osp.join(gt_folder, gt_name)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths


def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
        gt_path = osp.join(gt_folder, gt_path)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths


def paths_from_folder(folder):
    """Generate paths from folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """

    paths = list(scandir(folder))
    paths = [osp.join(folder, path) for path in paths]
    return paths


def paths_from_lmdb(folder):
    """Generate paths from lmdb.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths


def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel.
    """
    # scipy.ndimage.filters is deprecated/removed in recent scipy; import directly
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return gaussian_filter(kernel, sigma)


def duf_downsample(x, kernel_size=13, scale=4):
    """Downsampling with the Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = False
    if x.ndim == 4:
        squeeze_flag = True
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    x = x.view(-1, 1, h, w)
    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')

    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x
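
The padding behaviour of generate_frame_indices is easiest to see at the sequence boundaries;
a small self-check, assuming the module is importable as basicsr.data.data_util:

from basicsr.data.data_util import generate_frame_indices

# center frame 0 of a 100-frame clip, reading 5 frames (matches the docstring examples)
assert generate_frame_indices(0, 100, 5, padding='replicate') == [0, 0, 0, 1, 2]
assert generate_frame_indices(0, 100, 5, padding='reflection') == [2, 1, 0, 1, 2]
# at the right edge (center frame 99), reflection mirrors backwards
assert generate_frame_indices(99, 100, 5, padding='reflection') == [97, 98, 99, 98, 97]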

765  basicsr/data/degradations.py  Normal file
@@ -0,0 +1,765 @@
import cv2
import math
import numpy as np
import random
import torch
from scipy import special
from scipy.stats import multivariate_normal
# from torchvision.transforms.functional_tensor import rgb_to_grayscale
from torchvision.transforms.functional import rgb_to_grayscale

# -------------------------------------------------------------------- #
# --------------------------- blur kernels --------------------------- #
# -------------------------------------------------------------------- #


# --------------------------- util functions --------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).

    Args:
        sig_x (float): Sigma along the horizontal axis.
        sig_y (float): Sigma along the vertical axis.
        theta (float): Rotation in radians.

    Returns:
        ndarray: Rotated sigma matrix.
    """
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))


def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.

    Args:
        kernel_size (int): Size of the kernel.

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
                                                                           1))).reshape(kernel_size, kernel_size, 2)
    return xy, xx, yy


def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarray): un-normalized kernel.
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel


def cdf2(d_matrix, grid):
    """Calculate the CDF of the standard bivariate Gaussian distribution.
    Used in skewed Gaussian distribution.

    Args:
        d_matrix (ndarray): skew matrix.
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        cdf (ndarray): skewed cdf.
    """
    rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    grid = np.dot(grid, d_matrix)
    cdf = rv.cdf(grid)
    return cdf


def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): Size of the kernel.
        sig_x (float): Sigma along the horizontal axis.
        sig_y (float): Sigma along the vertical axis.
        theta (float): Rotation in radians.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True.

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a bivariate generalized Gaussian kernel.

    ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): Size of the kernel.
        sig_x (float): Sigma along the horizontal axis.
        sig_y (float): Sigma along the vertical axis.
        theta (float): Rotation in radians.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True.

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a plateau-like isotropic or anisotropic kernel.

    1 / (1+x^(beta))

    Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): Size of the kernel.
        sig_x (float): Sigma along the horizontal axis.
        sig_y (float): Sigma along the vertical axis.
        theta (float): Rotation in radians.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True.

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_Gaussian(kernel_size,
                              sigma_x_range,
                              sigma_y_range,
                              rotation_range,
                              noise_range=None,
                              isotropic=True):
    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int): Size of the kernel.
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True.

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          isotropic=True):
    """Randomly generate bivariate generalized Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int): Size of the kernel.
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        beta_range (tuple): [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True.

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # assume beta_range[0] < 1 < beta_range[1]
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_plateau(kernel_size,
                             sigma_x_range,
                             sigma_y_range,
                             rotation_range,
                             beta_range,
                             noise_range=None,
                             isotropic=True):
    """Randomly generate bivariate plateau kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int): Size of the kernel.
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi/2, math.pi/2]
        beta_range (tuple): [1, 4]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True.

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # TODO: this may not be proper
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)

    return kernel


def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         betag_range=(0.5, 8),
                         betap_range=(0.5, 8),
                         noise_range=None):
    """Randomly generate mixed kernels.

    Args:
        kernel_list (tuple): a list of kernel type names, chosen from
            ['iso', 'aniso', 'generalized_iso', 'generalized_aniso',
            'plateau_iso', 'plateau_aniso']
        kernel_prob (tuple): corresponding kernel probability for each
            kernel type
        kernel_size (int): Size of the kernel.
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        betag_range (tuple): shape range for generalized Gaussian kernels, [0.5, 8]
        betap_range (tuple): shape range for plateau kernels, [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
    elif kernel_type == 'generalized_iso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=True)
    elif kernel_type == 'generalized_aniso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=False)
    elif kernel_type == 'plateau_iso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
    return kernel


# the sinc kernel below divides by zero at its center (the value is overwritten afterwards)
np.seterr(divide='ignore', invalid='ignore')


def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
    """2D sinc filter.

    Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter

    Args:
        cutoff (float): cutoff frequency in radians (pi is max)
        kernel_size (int): horizontal and vertical size, must be odd.
        pad_to (int): pad kernel size to desired size, must be odd or zero.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    kernel = np.fromfunction(
        lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
            (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
                (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
    kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
    kernel = kernel / np.sum(kernel)
    if pad_to > kernel_size:
        pad_size = (pad_to - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
    return kernel


# ------------------------------------------------------------- #
# --------------------------- noise --------------------------- #
# ------------------------------------------------------------- #

# ----------------------- Gaussian Noise ----------------------- #


def generate_gaussian_noise(img, sigma=10, gray_noise=False):
    """Generate Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.
        gray_noise (bool): Whether to generate gray noise. Default: False.

    Returns:
        (Numpy array): Generated noise, shape (h, w, c), float32.
    """
    if gray_noise:
        noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
        noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
    else:
        noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
    return noise


def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
    """Add Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
            float32.
    """
    noise = generate_gaussian_noise(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
    """Generate Gaussian noise (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
        sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Generated noise, shape (b, c, h, w), float32.
    """
    b, _, h, w = img.size()
    if not isinstance(sigma, (float, int)):
        sigma = sigma.view(img.size(0), 1, 1, 1)
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0

    if cal_gray_noise:
        noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
        noise_gray = noise_gray.view(b, 1, h, w)

    # always calculate color noise
    noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.

    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    return noise


def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
    """Add Gaussian noise (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
        sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ----------------------- Random Gaussian Noise ----------------------- #
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    return generate_gaussian_noise(img, sigma, gray_noise)


def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
    sigma = torch.rand(
        img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_gaussian_noise_pt(img, sigma, gray_noise)


def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ----------------------- Poisson (Shot) Noise ----------------------- #


def generate_poisson_noise(img, scale=1.0, gray_noise=False):
    """Generate Poisson noise.

    Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        scale (float): Noise scale. Default: 1.0.
        gray_noise (bool): Whether to generate gray noise. Default: False.

    Returns:
        (Numpy array): Generated noise, shape (h, w, c), float32.
    """
    if gray_noise:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # round and clip image for counting vals correctly
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = len(np.unique(img))
    vals = 2**np.ceil(np.log2(vals))
    out = np.float32(np.random.poisson(img * vals) / float(vals))
    noise = out - img
    if gray_noise:
        noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
    return noise * scale


def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
    """Add Poisson noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        scale (float): Noise scale. Default: 1.0.
        gray_noise (bool): Whether to generate gray noise. Default: False.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
            float32.
    """
    noise = generate_poisson_noise(img, scale, gray_noise)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
    """Generate a batch of Poisson noise (PyTorch version).

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Generated noise, shape (b, c, h, w), float32.
    """
    b, _, h, w = img.size()
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0
    if cal_gray_noise:
        img_gray = rgb_to_grayscale(img, num_output_channels=1)
        # round and clip image for counting vals correctly
        img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
        # use for-loop to get the unique values for each sample
        vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
        vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
        vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
        out = torch.poisson(img_gray * vals) / vals
        noise_gray = out - img_gray
        noise_gray = noise_gray.expand(b, 3, h, w)

    # always calculate color noise
    # round and clip image for counting vals correctly
    img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
    # use for-loop to get the unique values for each sample
    vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
    vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
    vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
    out = torch.poisson(img * vals) / vals
    noise = out - img
    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    if not isinstance(scale, (float, int)):
        scale = scale.view(b, 1, 1, 1)
    return noise * scale


def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
    """Add Poisson noise to a batch of images (PyTorch version).

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    noise = generate_poisson_noise_pt(img, scale, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ----------------------- Random Poisson (Shot) Noise ----------------------- #


def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
    scale = np.random.uniform(scale_range[0], scale_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    return generate_poisson_noise(img, scale, gray_noise)


def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
    scale = torch.rand(
        img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_poisson_noise_pt(img, scale, gray_noise)


def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ------------------------------------------------------------------------ #
# --------------------------- JPEG compression --------------------------- #
# ------------------------------------------------------------------------ #


def add_jpg_compression(img, quality=90):
    """Add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality (float): JPG compression quality. 0 for lowest quality, 100 for
            best quality. Default: 90.

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
            float32.
    """
    img = np.clip(img, 0, 1)
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
    _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
    img = np.float32(cv2.imdecode(encimg, 1)) / 255.
    return img


def random_add_jpg_compression(img, quality_range=(90, 100)):
    """Randomly add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality_range (tuple[float] | list[float]): JPG compression quality
            range. 0 for lowest quality, 100 for best quality.
            Default: (90, 100).

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
            float32.
    """
    quality = np.random.uniform(quality_range[0], quality_range[1])
    return add_jpg_compression(img, quality)
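
A sketch of how these generators are commonly combined in a Real-ESRGAN-style degradation
pipeline; the probability and range values below are illustrative, not taken from this commit:

import math
import numpy as np
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels

if np.random.uniform() < 0.1:  # occasionally use a pure sinc (ringing/overshoot) kernel
    omega_c = np.random.uniform(np.pi / 3, np.pi)
    kernel = circular_lowpass_kernel(omega_c, kernel_size=21, pad_to=21)
else:
    kernel = random_mixed_kernels(
        ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
        [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],  # sampling weights for random.choices
        kernel_size=21,
        sigma_x_range=(0.2, 3),
        sigma_y_range=(0.2, 3),
        rotation_range=(-math.pi, math.pi),
        betag_range=(0.5, 4),
        betap_range=(1, 2))
assert kernel.shape == (21, 21)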

80  basicsr/data/ffhq_dataset.py  Normal file
@@ -0,0 +1,80 @@
import random
import time
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwargs.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # FFHQ has 70000 images in total
            self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path)
            except Exception as e:
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read (randint is inclusive on both ends)
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
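
A minimal options dict for instantiating FFHQDataset directly, assuming images on disk;
the paths and normalization values are illustrative:

from basicsr.data.ffhq_dataset import FFHQDataset

opt = {
    'dataroot_gt': 'datasets/ffhq',  # folder with 00000000.png ... 00069999.png
    'io_backend': {'type': 'disk'},
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'use_hflip': True,
}
dataset = FFHQDataset(opt)
sample = dataset[0]  # {'gt': tensor of shape (3, h, w), 'gt_path': ...}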

32592  basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt  Normal file
File diff suppressed because it is too large.

4  basicsr/data/meta_info/meta_info_REDS4_test_GT.txt  Normal file
@@ -0,0 +1,4 @@
000 100 (720,1280,3)
011 100 (720,1280,3)
015 100 (720,1280,3)
020 100 (720,1280,3)
270
basicsr/data/meta_info/meta_info_REDS_GT.txt
Normal file
@@ -0,0 +1,270 @@
000 100 (720,1280,3)
001 100 (720,1280,3)
002 100 (720,1280,3)
003 100 (720,1280,3)
004 100 (720,1280,3)
005 100 (720,1280,3)
006 100 (720,1280,3)
007 100 (720,1280,3)
008 100 (720,1280,3)
009 100 (720,1280,3)
010 100 (720,1280,3)
011 100 (720,1280,3)
012 100 (720,1280,3)
013 100 (720,1280,3)
014 100 (720,1280,3)
015 100 (720,1280,3)
016 100 (720,1280,3)
017 100 (720,1280,3)
018 100 (720,1280,3)
019 100 (720,1280,3)
020 100 (720,1280,3)
021 100 (720,1280,3)
022 100 (720,1280,3)
023 100 (720,1280,3)
024 100 (720,1280,3)
025 100 (720,1280,3)
026 100 (720,1280,3)
027 100 (720,1280,3)
028 100 (720,1280,3)
029 100 (720,1280,3)
030 100 (720,1280,3)
031 100 (720,1280,3)
032 100 (720,1280,3)
033 100 (720,1280,3)
034 100 (720,1280,3)
035 100 (720,1280,3)
036 100 (720,1280,3)
037 100 (720,1280,3)
038 100 (720,1280,3)
039 100 (720,1280,3)
040 100 (720,1280,3)
041 100 (720,1280,3)
042 100 (720,1280,3)
043 100 (720,1280,3)
044 100 (720,1280,3)
045 100 (720,1280,3)
046 100 (720,1280,3)
047 100 (720,1280,3)
048 100 (720,1280,3)
049 100 (720,1280,3)
050 100 (720,1280,3)
051 100 (720,1280,3)
052 100 (720,1280,3)
053 100 (720,1280,3)
054 100 (720,1280,3)
055 100 (720,1280,3)
056 100 (720,1280,3)
057 100 (720,1280,3)
058 100 (720,1280,3)
059 100 (720,1280,3)
060 100 (720,1280,3)
061 100 (720,1280,3)
062 100 (720,1280,3)
063 100 (720,1280,3)
064 100 (720,1280,3)
065 100 (720,1280,3)
066 100 (720,1280,3)
067 100 (720,1280,3)
068 100 (720,1280,3)
069 100 (720,1280,3)
070 100 (720,1280,3)
071 100 (720,1280,3)
072 100 (720,1280,3)
073 100 (720,1280,3)
074 100 (720,1280,3)
075 100 (720,1280,3)
076 100 (720,1280,3)
077 100 (720,1280,3)
078 100 (720,1280,3)
079 100 (720,1280,3)
080 100 (720,1280,3)
081 100 (720,1280,3)
082 100 (720,1280,3)
083 100 (720,1280,3)
084 100 (720,1280,3)
085 100 (720,1280,3)
086 100 (720,1280,3)
087 100 (720,1280,3)
088 100 (720,1280,3)
089 100 (720,1280,3)
090 100 (720,1280,3)
091 100 (720,1280,3)
092 100 (720,1280,3)
093 100 (720,1280,3)
094 100 (720,1280,3)
095 100 (720,1280,3)
096 100 (720,1280,3)
097 100 (720,1280,3)
098 100 (720,1280,3)
099 100 (720,1280,3)
100 100 (720,1280,3)
101 100 (720,1280,3)
102 100 (720,1280,3)
103 100 (720,1280,3)
104 100 (720,1280,3)
105 100 (720,1280,3)
106 100 (720,1280,3)
107 100 (720,1280,3)
108 100 (720,1280,3)
109 100 (720,1280,3)
110 100 (720,1280,3)
111 100 (720,1280,3)
112 100 (720,1280,3)
113 100 (720,1280,3)
114 100 (720,1280,3)
115 100 (720,1280,3)
116 100 (720,1280,3)
117 100 (720,1280,3)
118 100 (720,1280,3)
119 100 (720,1280,3)
120 100 (720,1280,3)
121 100 (720,1280,3)
122 100 (720,1280,3)
123 100 (720,1280,3)
124 100 (720,1280,3)
125 100 (720,1280,3)
126 100 (720,1280,3)
127 100 (720,1280,3)
128 100 (720,1280,3)
129 100 (720,1280,3)
130 100 (720,1280,3)
131 100 (720,1280,3)
132 100 (720,1280,3)
133 100 (720,1280,3)
134 100 (720,1280,3)
135 100 (720,1280,3)
136 100 (720,1280,3)
137 100 (720,1280,3)
138 100 (720,1280,3)
139 100 (720,1280,3)
140 100 (720,1280,3)
141 100 (720,1280,3)
142 100 (720,1280,3)
143 100 (720,1280,3)
144 100 (720,1280,3)
145 100 (720,1280,3)
146 100 (720,1280,3)
147 100 (720,1280,3)
148 100 (720,1280,3)
149 100 (720,1280,3)
150 100 (720,1280,3)
151 100 (720,1280,3)
152 100 (720,1280,3)
153 100 (720,1280,3)
154 100 (720,1280,3)
155 100 (720,1280,3)
156 100 (720,1280,3)
157 100 (720,1280,3)
158 100 (720,1280,3)
159 100 (720,1280,3)
160 100 (720,1280,3)
161 100 (720,1280,3)
162 100 (720,1280,3)
163 100 (720,1280,3)
164 100 (720,1280,3)
165 100 (720,1280,3)
166 100 (720,1280,3)
167 100 (720,1280,3)
168 100 (720,1280,3)
169 100 (720,1280,3)
170 100 (720,1280,3)
171 100 (720,1280,3)
172 100 (720,1280,3)
173 100 (720,1280,3)
174 100 (720,1280,3)
175 100 (720,1280,3)
176 100 (720,1280,3)
177 100 (720,1280,3)
178 100 (720,1280,3)
179 100 (720,1280,3)
180 100 (720,1280,3)
181 100 (720,1280,3)
182 100 (720,1280,3)
183 100 (720,1280,3)
184 100 (720,1280,3)
185 100 (720,1280,3)
186 100 (720,1280,3)
187 100 (720,1280,3)
188 100 (720,1280,3)
189 100 (720,1280,3)
190 100 (720,1280,3)
191 100 (720,1280,3)
192 100 (720,1280,3)
193 100 (720,1280,3)
194 100 (720,1280,3)
195 100 (720,1280,3)
196 100 (720,1280,3)
197 100 (720,1280,3)
198 100 (720,1280,3)
199 100 (720,1280,3)
200 100 (720,1280,3)
201 100 (720,1280,3)
202 100 (720,1280,3)
203 100 (720,1280,3)
204 100 (720,1280,3)
205 100 (720,1280,3)
206 100 (720,1280,3)
207 100 (720,1280,3)
208 100 (720,1280,3)
209 100 (720,1280,3)
210 100 (720,1280,3)
211 100 (720,1280,3)
212 100 (720,1280,3)
213 100 (720,1280,3)
214 100 (720,1280,3)
215 100 (720,1280,3)
216 100 (720,1280,3)
217 100 (720,1280,3)
218 100 (720,1280,3)
219 100 (720,1280,3)
220 100 (720,1280,3)
221 100 (720,1280,3)
222 100 (720,1280,3)
223 100 (720,1280,3)
224 100 (720,1280,3)
225 100 (720,1280,3)
226 100 (720,1280,3)
227 100 (720,1280,3)
228 100 (720,1280,3)
229 100 (720,1280,3)
230 100 (720,1280,3)
231 100 (720,1280,3)
232 100 (720,1280,3)
233 100 (720,1280,3)
234 100 (720,1280,3)
235 100 (720,1280,3)
236 100 (720,1280,3)
237 100 (720,1280,3)
238 100 (720,1280,3)
239 100 (720,1280,3)
240 100 (720,1280,3)
241 100 (720,1280,3)
242 100 (720,1280,3)
243 100 (720,1280,3)
244 100 (720,1280,3)
245 100 (720,1280,3)
246 100 (720,1280,3)
247 100 (720,1280,3)
248 100 (720,1280,3)
249 100 (720,1280,3)
250 100 (720,1280,3)
251 100 (720,1280,3)
252 100 (720,1280,3)
253 100 (720,1280,3)
254 100 (720,1280,3)
255 100 (720,1280,3)
256 100 (720,1280,3)
257 100 (720,1280,3)
258 100 (720,1280,3)
259 100 (720,1280,3)
260 100 (720,1280,3)
261 100 (720,1280,3)
262 100 (720,1280,3)
263 100 (720,1280,3)
264 100 (720,1280,3)
265 100 (720,1280,3)
266 100 (720,1280,3)
267 100 (720,1280,3)
268 100 (720,1280,3)
269 100 (720,1280,3)
@@ -0,0 +1,4 @@
240 100 (720,1280,3)
241 100 (720,1280,3)
246 100 (720,1280,3)
257 100 (720,1280,3)
@@ -0,0 +1,30 @@
240 100 (720,1280,3)
241 100 (720,1280,3)
242 100 (720,1280,3)
243 100 (720,1280,3)
244 100 (720,1280,3)
245 100 (720,1280,3)
246 100 (720,1280,3)
247 100 (720,1280,3)
248 100 (720,1280,3)
249 100 (720,1280,3)
250 100 (720,1280,3)
251 100 (720,1280,3)
252 100 (720,1280,3)
253 100 (720,1280,3)
254 100 (720,1280,3)
255 100 (720,1280,3)
256 100 (720,1280,3)
257 100 (720,1280,3)
258 100 (720,1280,3)
259 100 (720,1280,3)
260 100 (720,1280,3)
261 100 (720,1280,3)
262 100 (720,1280,3)
263 100 (720,1280,3)
264 100 (720,1280,3)
265 100 (720,1280,3)
266 100 (720,1280,3)
267 100 (720,1280,3)
268 100 (720,1280,3)
269 100 (720,1280,3)
7824
basicsr/data/meta_info/meta_info_Vimeo90K_test_GT.txt
Normal file
File diff suppressed because it is too large
1225
basicsr/data/meta_info/meta_info_Vimeo90K_test_fast_GT.txt
Normal file
File diff suppressed because it is too large
4977
basicsr/data/meta_info/meta_info_Vimeo90K_test_medium_GT.txt
Normal file
File diff suppressed because it is too large
1613
basicsr/data/meta_info/meta_info_Vimeo90K_test_slow_GT.txt
Normal file
File diff suppressed because it is too large
64612
basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
Normal file
File diff suppressed because it is too large
106
basicsr/data/paired_image_dataset.py
Normal file
@@ -0,0 +1,106 @@
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc.) and GT image pairs.

    There are three modes:

    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
    2. **meta_info_file**: Use meta information file to generate paths.
       If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. **folder**: Scan folders to generate paths. The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale factor, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        if 'filename_tmpl' in opt:
            self.filename_tmpl = opt['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
            self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
                                                          self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # color space transform
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
            img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]

        # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
        # TODO: it is better to update the datasets, rather than force to crop
        if self.opt['phase'] != 'train':
            img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
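# Illustrative usage sketch (not from the repository): a minimal, hypothetical
# training config for PairedImageDataset in the 'folder' mode described in the
# docstring above. Paths and sizes are placeholders.
opt = {
    'phase': 'train',
    'scale': 4,
    'dataroot_gt': 'datasets/DIV2K/train_HR_sub',          # placeholder
    'dataroot_lq': 'datasets/DIV2K/train_LR_bicubic/X4_sub',  # placeholder
    'io_backend': {'type': 'disk'},
    'filename_tmpl': '{}',
    'gt_size': 128,
    'use_hflip': True,
    'use_rot': True,
}
dataset = PairedImageDataset(opt)
sample = dataset[0]  # {'lq': (3, 32, 32) tensor, 'gt': (3, 128, 128) tensor, ...}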
122
basicsr/data/prefetch_dataloader.py
Normal file
@@ -0,0 +1,122 @@
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Length of the prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
        Need to test on single gpu and ddp (multi-gpu). There is a known issue in ddp.

    Args:
        num_prefetch_queue (int): Length of the prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
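# Illustrative usage sketch (not from the repository): the prefetchers above
# are driven with an explicit next()/reset() loop rather than Python
# iteration. Synthetic data keeps this runnable; a real loader and a
# CUDAPrefetcher(loader, opt) would slot in the same way.
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(8, 3, 4, 4)), batch_size=2)
prefetcher = CPUPrefetcher(loader)
prefetcher.reset()
batch = prefetcher.next()
while batch is not None:
    # one training step on `batch` would go here
    batch = prefetcher.next()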
384
basicsr/data/realesrgan_dataset.py
Normal file
@@ -0,0 +1,384 @@
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from pathlib import Path

import albumentations

import torch.nn.functional as F
from torch.utils import data as data

from basicsr.utils import DiffJPEG
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from basicsr.utils.img_process_util import filter2D
from basicsr.data.transforms import paired_random_crop, random_crop
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt

from utils import util_image


def readline_txt(txt_file):
    txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file
    out = []
    for txt_file_current in txt_file:
        with open(txt_file_current, 'r') as ff:
            out.extend([x[:-1] for x in ff.readlines()])

    return out


@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANDataset(data.Dataset):
    """Dataset used for the Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images and augments them.
    It also generates blur kernels and sinc kernels for synthesizing low-quality images.
    Note that the low-quality images are processed in tensors on GPUs for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
        Please see more options in the codes.
    """

    def __init__(self, opt, mode='training'):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        # collect image/text/moment paths from the configured data sources
        # (opt is expected to support attribute-style access, e.g. an omegaconf config)
        self.image_paths = []
        self.text_paths = []
        self.moment_paths = []
        if opt.get('data_source', None) is not None:
            for ii in range(len(opt['data_source'])):
                configs = opt['data_source'].get(f'source{ii+1}')
                root_path = Path(configs.root_path)
                im_folder = root_path / configs.image_path
                im_ext = configs.im_ext
                image_stems = sorted([x.stem for x in im_folder.glob(f"*.{im_ext}")])
                if configs.get('length', None) is not None:
                    assert configs.length < len(image_stems)
                    image_stems = image_stems[:configs.length]

                if configs.get("text_path", None) is not None:
                    text_folder = root_path / configs.text_path
                    text_stems = [x.stem for x in text_folder.glob("*.txt")]
                    image_stems = sorted(list(set(image_stems).intersection(set(text_stems))))
                    self.text_paths.extend([str(text_folder / f"{x}.txt") for x in image_stems])
                else:
                    self.text_paths.extend([None, ] * len(image_stems))

                self.image_paths.extend([str(im_folder / f"{x}.{im_ext}") for x in image_stems])

                if configs.get("moment_path", None) is not None:
                    moment_folder = root_path / configs.moment_path
                    self.moment_paths.extend([str(moment_folder / f"{x}.npy") for x in image_stems])
                else:
                    self.moment_paths.extend([None, ] * len(image_stems))

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        # candidate odd kernel sizes: 3, 5, ..., blur_kernel_size - 2
        self.kernel_range1 = [x for x in range(3, opt['blur_kernel_size'], 2)]
        self.kernel_range2 = [x for x in range(3, opt['blur_kernel_size2'], 2)]
        # TODO: kernel range is now hard-coded, should be in the configure file
        # convolving with the pulse tensor brings no blurry effect
        self.pulse_tensor = torch.zeros(opt['blur_kernel_size2'], opt['blur_kernel_size2']).float()
        self.pulse_tensor[opt['blur_kernel_size2'] // 2, opt['blur_kernel_size2'] // 2] = 1

        self.mode = mode

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.image_paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
                img_gt = imfrombytes(img_bytes, float32=True)
            except Exception as e:
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change to another file to read
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.image_paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        if self.mode == 'testing':
            if not hasattr(self, 'test_aug'):
                self.test_aug = albumentations.Compose([
                    albumentations.SmallestMaxSize(
                        max_size=self.opt['gt_size'],
                        interpolation=cv2.INTER_AREA,
                    ),
                    albumentations.CenterCrop(self.opt['gt_size'], self.opt['gt_size']),
                ])
            img_gt = self.test_aug(image=img_gt)['image']
        elif self.mode == 'training':
            # -------------------- Do augmentation for training: flip, rotation -------------------- #
            if self.opt['use_hflip'] or self.opt['use_rot']:
                img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

            h, w = img_gt.shape[0:2]
            gt_size = self.opt['gt_size']

            # resize or pad
            if not self.opt['random_crop']:
                if min(h, w) != gt_size:
                    if not hasattr(self, 'smallest_resizer'):
                        self.smallest_resizer = util_image.SmallestMaxSize(
                            max_size=gt_size, pass_resize=False,
                        )
                    img_gt = self.smallest_resizer(img_gt)

                # center crop
                if not hasattr(self, 'center_cropper'):
                    self.center_cropper = albumentations.CenterCrop(gt_size, gt_size)
                img_gt = self.center_cropper(image=img_gt)['image']
            else:
                img_gt = random_crop(img_gt, self.opt['gt_size'])
        else:
            raise ValueError(f'Unexpected value {self.mode} for mode parameter')

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range1)
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (self.blur_kernel_size - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range2)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (self.blur_kernel_size2 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range2)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=self.blur_kernel_size2)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        if self.text_paths[index] is None or self.opt['random_crop']:
            prompt = ""
        else:
            with open(self.text_paths[index], 'r') as ff:
                prompt = ff.read()
            # opt is expected to support attribute access here (e.g. omegaconf)
            if self.opt.max_token_length is not None:
                prompt = prompt[:self.opt.max_token_length]

        return_d = {
            'gt': img_gt,
            'gt_path': gt_path,
            'txt': prompt,
            'kernel1': kernel,
            'kernel2': kernel2,
            'sinc_kernel': sinc_kernel,
        }
        if self.moment_paths[index] is not None and (not self.opt['random_crop']):
            return_d['gt_moment'] = np.load(self.moment_paths[index])

        return return_d

    def __len__(self):
        return len(self.image_paths)

    def degrade_fun(self, conf_degradation, im_gt, kernel1, kernel2, sinc_kernel):
        if not hasattr(self, 'jpeger'):
            self.jpeger = DiffJPEG(differentiable=False)  # simulate JPEG compression artifacts

        ori_h, ori_w = im_gt.size()[2:4]
        sf = conf_degradation.sf

        # ----------------------- The first degradation process ----------------------- #
        # blur
        out = filter2D(im_gt, kernel1)
        # random resize
        updown_type = random.choices(
            ['up', 'down', 'keep'],
            conf_degradation['resize_prob'],
        )[0]
        if updown_type == 'up':
            scale = random.uniform(1, conf_degradation['resize_range'][1])
        elif updown_type == 'down':
            scale = random.uniform(conf_degradation['resize_range'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        # add noise
        gray_noise_prob = conf_degradation['gray_noise_prob']
        if random.random() < conf_degradation['gaussian_noise_prob']:
            out = random_add_gaussian_noise_pt(
                out,
                sigma_range=conf_degradation['noise_range'],
                clip=True,
                rounds=False,
                gray_prob=gray_noise_prob,
            )
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=conf_degradation['poisson_scale_range'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)
        # JPEG compression
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range'])
        out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
        out = self.jpeger(out, quality=jpeg_p)

        # ----------------------- The second degradation process ----------------------- #
        # blur
        if random.random() < conf_degradation['second_order_prob']:
            if random.random() < conf_degradation['second_blur_prob']:
                out = filter2D(out, kernel2)
            # random resize
            updown_type = random.choices(
                ['up', 'down', 'keep'],
                conf_degradation['resize_prob2'],
            )[0]
            if updown_type == 'up':
                scale = random.uniform(1, conf_degradation['resize_range2'][1])
            elif updown_type == 'down':
                scale = random.uniform(conf_degradation['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out,
                size=(int(ori_h / sf * scale), int(ori_w / sf * scale)),
                mode=mode,
            )
            # add noise
            gray_noise_prob = conf_degradation['gray_noise_prob2']
            if random.random() < conf_degradation['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out,
                    sigma_range=conf_degradation['noise_range2'],
                    clip=True,
                    rounds=False,
                    gray_prob=gray_noise_prob,
                )
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=conf_degradation['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False,
                )

        # JPEG compression + the final sinc filter
        # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
        # as one operation.
        # We consider two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, we find other combinations (sinc + JPEG + resize) will introduce twisted lines.
        if random.random() < 0.5:
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out,
                size=(ori_h // sf, ori_w // sf),
                mode=mode,
            )
            out = filter2D(out, sinc_kernel)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
        else:
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out,
                size=(ori_h // sf, ori_w // sf),
                mode=mode,
            )
            out = filter2D(out, sinc_kernel)

        # clamp and round
        im_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

        return {'lq': im_lq.contiguous(), 'gt': im_gt}
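# Illustrative sketch (not from the repository): how degrade_fun() might be
# driven to synthesize an LQ/GT pair. It assumes basicsr and omegaconf are
# installed; every numeric value below is a placeholder hyper-parameter, and
# the identity kernels stand in for the random kernels the dataset normally
# returns per sample.
import torch
from omegaconf import OmegaConf

opt = {
    'io_backend': {'type': 'disk'},
    'blur_kernel_size': 21, 'kernel_list': ['iso', 'aniso'], 'kernel_prob': [0.5, 0.5],
    'blur_sigma': [0.2, 3.0], 'betag_range': [0.5, 4.0], 'betap_range': [1, 2], 'sinc_prob': 0.1,
    'blur_kernel_size2': 11, 'kernel_list2': ['iso', 'aniso'], 'kernel_prob2': [0.5, 0.5],
    'blur_sigma2': [0.2, 1.5], 'betag_range2': [0.5, 4.0], 'betap_range2': [1, 2], 'sinc_prob2': 0.1,
    'final_sinc_prob': 0.8,
}
dataset = RealESRGANDataset(opt)  # no data_source: path lists stay empty

conf_degradation = OmegaConf.create({
    'sf': 4,
    'resize_prob': [0.2, 0.7, 0.1], 'resize_range': [0.15, 1.5],
    'gaussian_noise_prob': 0.5, 'noise_range': [1, 30],
    'poisson_scale_range': [0.05, 3], 'gray_noise_prob': 0.4, 'jpeg_range': [30, 95],
    'second_order_prob': 0.5, 'second_blur_prob': 0.8,
    'resize_prob2': [0.3, 0.4, 0.3], 'resize_range2': [0.3, 1.2],
    'gaussian_noise_prob2': 0.5, 'noise_range2': [1, 25],
    'poisson_scale_range2': [0.05, 2.5], 'gray_noise_prob2': 0.4, 'jpeg_range2': [30, 95],
})

im_gt = torch.rand(1, 3, 256, 256)  # (b, c, h, w), range [0, 1]
kernel1 = torch.zeros(1, 21, 21); kernel1[0, 10, 10] = 1.0  # identity blur
kernel2 = torch.zeros(1, 11, 11); kernel2[0, 5, 5] = 1.0
sinc_kernel = dataset.pulse_tensor.unsqueeze(0)
data = dataset.degrade_fun(conf_degradation, im_gt, kernel1, kernel2, sinc_kernel)
print(data['lq'].shape)  # (1, 3, 64, 64) for sf=4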
106
basicsr/data/realesrgan_paired_dataset.py
Normal file
@@ -0,0 +1,106 @@
import os
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc.) and GT image pairs.

    There are three modes:

    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
    2. **meta_info**: Use meta information file to generate paths.
       If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
    3. **folder**: Scan folders to generate paths. The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale factor, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # mean and std for normalizing the input images
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
            # disk backend with meta_info
            # each line in the meta_info file describes the relative paths to the gt and lq images
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip() for line in fin]
            self.paths = []
            for path in paths:
                gt_path, lq_path = path.split(', ')
                gt_path = os.path.join(self.gt_folder, gt_path)
                lq_path = os.path.join(self.lq_folder, lq_path)
                self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
        else:
            # disk backend
            # it will scan the whole folder to get meta info
            # it will be time-consuming for folders with too many files. It is recommended to use an extra meta txt file
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
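# Illustrative sketch (not from the repository): the meta_info branch above
# expects one "gt, lq" pair of relative paths per line, separated by ", ",
# resolved against dataroot_gt and dataroot_lq respectively. The filenames
# below are placeholders.
with open('meta_info.txt', 'w') as f:
    f.write('0001.png, 0001x4.png\n')
    f.write('0002.png, 0002x4.png\n')
# -> dataroot_gt/0001.png is paired with dataroot_lq/0001x4.png, and so on.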
352
basicsr/data/reds_dataset.py
Normal file
@@ -0,0 +1,352 @@
import numpy as np
import random
import torch
from pathlib import Path
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.flow_util import dequantize_flow
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class REDSDataset(data.Dataset):
    """REDS dataset for training.

    The keys are generated from a meta info txt file:
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwargs.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Randomly reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale factor, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
        assert opt['num_frame'] % 2 == 1, (f'num_frame should be an odd number, but got {opt["num_frame"]}')
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
                             f" Supported ones are ['official', 'REDS4'].")
        self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000
        center_frame_idx = int(frame_name)

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = center_frame_idx - self.num_half_frames * interval
        end_frame_idx = center_frame_idx + self.num_half_frames * interval
        # each clip has 100 frames starting from 0 to 99
        while (start_frame_idx < 0) or (end_frame_idx > 99):
            center_frame_idx = random.randint(0, 99)
            start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
            end_frame_idx = center_frame_idx + self.num_half_frames * interval
        frame_name = f'{center_frame_idx:08d}'
        neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the GT frame (as the center frame)
        if self.is_lmdb:
            img_gt_path = f'{clip_name}/{frame_name}'
        else:
            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # get flows
        if self.flow_root is not None:
            img_flows = []
            # read previous flows
            for i in range(self.num_half_frames, 0, -1):
                if self.is_lmdb:
                    flow_path = f'{clip_name}/{frame_name}_p{i}'
                else:
                    flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
                img_bytes = self.file_client.get(flow_path, 'flow')
                cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
                dx, dy = np.split(cat_flow, 2, axis=0)
                flow = dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.
                img_flows.append(flow)
            # read next flows
            for i in range(1, self.num_half_frames + 1):
                if self.is_lmdb:
                    flow_path = f'{clip_name}/{frame_name}_n{i}'
                else:
                    flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
                img_bytes = self.file_client.get(flow_path, 'flow')
                cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
                dx, dy = np.split(cat_flow, 2, axis=0)
                flow = dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.
                img_flows.append(flow)

            # for random crop, here, img_flows and img_lqs have the same
            # spatial size
            img_lqs.extend(img_flows)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
        if self.flow_root is not None:
            img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        if self.flow_root is not None:
            img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
        else:
            img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        if self.flow_root is not None:
            img_flows = img2tensor(img_flows)
            # add the zero center flow
            img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
            img_flows = torch.stack(img_flows, dim=0)

        # img_lqs: (t, c, h, w)
        # img_flows: (t, 2, h, w)
        # img_gt: (c, h, w)
        # key: str
        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        else:
            return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)


@DATASET_REGISTRY.register()
class REDSRecurrentDataset(data.Dataset):
    """REDS dataset for training recurrent networks.

    The keys are generated from a meta info txt file:
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwargs.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Randomly reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale factor, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSRecurrentDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.num_frame = opt['num_frame']

        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
                             f" Supported ones are ['official', 'REDS4'].")
        if opt['test_mode']:
            self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
        else:
            self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if hasattr(self, 'flow_root') and self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt.get('interval_list', [1])
        self.random_reverse = opt.get('random_reverse', False)
        interval_str = ','.join(str(x) for x in self.interval_list)
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = int(frame_name)
        if start_frame_idx > 100 - self.num_frame * interval:
            start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
        end_frame_idx = start_frame_idx + self.num_frame * interval

        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
                img_gt_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
                img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'

            # get LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

            # get GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)
            img_gts.append(img_gt)

        # randomly crop
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
        img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
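# Illustrative sketch (not from the repository): how keys and the validation
# split interact in the two REDS datasets above. Every (clip, frame) pair
# becomes a "clip/frame" key, and the REDS4 clips are held out of training.
keys = [f'{clip:03d}/{i:08d}' for clip in range(270) for i in range(100)]
val_partition = ['000', '011', '015', '020']  # the REDS4 split
train_keys = [k for k in keys if k.split('/')[0] not in val_partition]
print(len(keys), len(train_keys))  # 27000 26600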
68
basicsr/data/single_image_dataset.py
Normal file
@@ -0,0 +1,68 @@
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc.) images.

    There are two modes:
    1. 'meta_info_file': Use meta information file to generate paths.
    2. 'folder': Scan folders to generate paths.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # color space transform
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
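# Illustrative usage sketch (not from the repository): typical test-phase use
# is to point SingleImageDataset at a folder of LQ inputs. The path below is
# a placeholder.
opt = {'dataroot_lq': 'datasets/test_lq', 'io_backend': {'type': 'disk'}}
dataset = SingleImageDataset(opt)
for i in range(len(dataset)):
    sample = dataset[i]  # {'lq': CHW float tensor, 'lq_path': str}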
207
basicsr/data/transforms.py
Normal file
@@ -0,0 +1,207 @@
import cv2
import random
import torch


def mod_crop(img, scale):
    """Mod crop images, used during testing.

    Args:
        img (ndarray): Input image.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.
    """
    img = img.copy()
    if img.ndim in (2, 3):
        h, w = img.shape[0], img.shape[1]
        h_remainder, w_remainder = h % scale, w % scale
        img = img[:h - h_remainder, :w - w_remainder, ...]
    else:
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    return img
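

# --- Minimal sketch (illustrative, not part of the original file): mod_crop
# trims a 101x103 image so both sides become divisible by scale=4. ---
def _demo_mod_crop():
    import numpy as np  # used only by this sketch
    img = np.zeros((101, 103, 3), dtype=np.float32)
    assert mod_crop(img, 4).shape[:2] == (100, 100)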


def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Paired random crop. Supports Numpy array and Tensor inputs.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray | list[Tensor] | Tensor): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth. Default: None.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If the returned results
        only have one element, just return ndarray.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # determine input type: Numpy array or Tensor
    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

    if input_type == 'Tensor':
        h_lq, w_lq = img_lqs[0].size()[-2:]
        h_gt, w_gt = img_gts[0].size()[-2:]
    else:
        h_lq, w_lq = img_lqs[0].shape[0:2]
        h_gt, w_gt = img_gts[0].shape[0:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    if input_type == 'Tensor':
        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
    else:
        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    if input_type == 'Tensor':
        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
    else:
        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
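

# --- Usage sketch (illustrative, not part of the original file): crop a 128px
# GT patch and the matching 32px LQ patch from a synthetic 4x pair. ---
def _demo_paired_random_crop():
    import numpy as np  # used only by this sketch
    gt = np.random.rand(256, 256, 3).astype(np.float32)
    lq = np.random.rand(64, 64, 3).astype(np.float32)
    gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=128, scale=4)
    assert gt_patch.shape[:2] == (128, 128) and lq_patch.shape[:2] == (32, 32)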


def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use a vertical flip and transpose for the rotation implementation.
    All the images in the list use the same augmentation.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If the returned
        results only have one element, just return ndarray.
    """
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs
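

# --- Usage sketch (illustrative, not part of the original file): the same
# random flip/rotation is applied to every image in the list. ---
def _demo_augment():
    import numpy as np  # used only by this sketch
    pair = [np.random.rand(32, 32, 3).astype(np.float32) for _ in range(2)]
    pair, (hflip, vflip, rot90) = augment(pair, hflip=True, rotation=True, return_status=True)
    print(f'hflip={hflip}, vflip={vflip}, rot90={rot90}')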


def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.
    """
    (h, w) = img.shape[:2]

    if center is None:
        center = (w // 2, h // 2)

    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    rotated_img = cv2.warpAffine(img, matrix, (w, h))
    return rotated_img


def random_crop(im, pch_size):
    """Randomly crop a patch from the given image."""
    h, w = im.shape[:2]
    # padding if necessary; reflect padding can at most double each side,
    # hence the min(..., h) / min(..., w) caps
    if h < pch_size or w < pch_size:
        pad_h = min(max(0, pch_size - h), h)
        pad_w = min(max(0, pch_size - w), w)
        im = cv2.copyMakeBorder(im, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)

    h, w = im.shape[:2]
    if h == pch_size:
        ind_h = 0
    elif h > pch_size:
        ind_h = random.randint(0, h - pch_size)
    else:
        raise ValueError('Image height is smaller than the patch size')
    if w == pch_size:
        ind_w = 0
    elif w > pch_size:
        ind_w = random.randint(0, w - pch_size)
    else:
        raise ValueError('Image width is smaller than the patch size')

    im_pch = im[ind_h:ind_h + pch_size, ind_w:ind_w + pch_size]

    return im_pch
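

# --- Minimal sketch (illustrative, not part of the original file): reflect
# padding lets random_crop return a full patch even from a smaller image. ---
def _demo_random_crop():
    import numpy as np  # used only by this sketch
    small = np.random.rand(20, 48, 3).astype(np.float32)
    assert random_crop(small, 32).shape[:2] == (32, 32)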
283
basicsr/data/video_test_dataset.py
Normal file
283
basicsr/data/video_test_dataset.py
Normal file
@@ -0,0 +1,283 @@
import glob
import torch
from os import path as osp
from torch.utils import data as data

from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
    """Video test dataset.

    Supported datasets: Vid4, REDS4, REDSofficial.
    More generally, it supports testing datasets with the following structure:

    ::

        dataroot
        ├── subfolder1
            ├── frame000
            ├── frame001
            ├── ...
        ├── subfolder2
            ├── frame000
            ├── frame001
            ├── ...
        ├── ...

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for the test dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwargs.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test folders. If not provided,
                all the folders in the dataroot will be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        self.imgs_lq, self.imgs_gt = {}, {}
        if 'meta_info_file' in opt:
            with open(opt['meta_info_file'], 'r') as fin:
                subfolders = [line.split(' ')[0] for line in fin]
                subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
                subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
            for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
                # get frame list for lq and gt
                subfolder_name = osp.basename(subfolder_lq)
                img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
                img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))

                max_idx = len(img_paths_lq)
                assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
                                                      f' and gt folders ({len(img_paths_gt)})')

                self.data_info['lq_path'].extend(img_paths_lq)
                self.data_info['gt_path'].extend(img_paths_gt)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append(f'{i}/{max_idx}')
                border_l = [0] * max_idx
                for i in range(self.opt['num_frame'] // 2):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)

                # cache data or save the frame list
                if self.cache_data:
                    logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                    self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                    self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
                else:
                    self.imgs_lq[subfolder_name] = img_paths_lq
                    self.imgs_gt[subfolder_name] = img_paths_gt
        else:
            raise ValueError(f'Non-supported video test dataset: {opt["name"]}')

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])
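

# --- Usage sketch (illustrative, not part of the original file): the paths
# below are assumed; 'disk' reads frames straight from the filesystem. ---
def _demo_video_test_dataset():
    opt = {
        'name': 'Vid4',
        'dataroot_gt': './datasets/Vid4/GT',
        'dataroot_lq': './datasets/Vid4/BIx4',
        'io_backend': {'type': 'disk'},
        'cache_data': False,
        'num_frame': 7,
        'padding': 'reflection_circle',
    }
    dataset = VideoTestDataset(opt)
    item = dataset[0]
    print(item['lq'].shape, item['gt'].shape, item['idx'])  # (7, 3, h, w), (3, H, W), '0/N'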


@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
    """Video test dataset for the Vimeo90K-Test dataset.

    It only keeps the center frame for testing.
    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for the test dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwargs.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test folders. If not provided,
                all the folders in the dataroot will be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestVimeo90KDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        if self.cache_data:
            raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestVimeo90KDataset - {opt["name"]}')
        with open(opt['meta_info_file'], 'r') as fin:
            subfolders = [line.split(' ')[0] for line in fin]
        for idx, subfolder in enumerate(subfolders):
            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
            self.data_info['gt_path'].append(gt_path)
            lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
            self.data_info['lq_path'].append(lq_paths)
            self.data_info['folder'].append('vimeo90k')
            self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
            self.data_info['border'].append(0)

    def __getitem__(self, index):
        lq_path = self.data_info['lq_path'][index]
        gt_path = self.data_info['gt_path'][index]
        imgs_lq = read_img_seq(lq_path)
        img_gt = read_img_seq([gt_path])
        img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': self.data_info['folder'][index],  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/843
            'border': self.data_info['border'][index],  # 0 for non-border
            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])
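

# --- Minimal sketch (illustrative, not part of the original file): the
# neighbor_list above centers the input window on frame im4 of each clip. ---
def _demo_neighbor_list():
    for num_frame in (1, 3, 5, 7):
        print(num_frame, [i + (9 - num_frame) // 2 for i in range(num_frame)])
    # 1 -> [4]; 3 -> [3, 4, 5]; 5 -> [2, 3, 4, 5, 6]; 7 -> [1, 2, 3, 4, 5, 6, 7]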


@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
    """Video test dataset for the DUF dataset.

    Args:
        opt (dict): Config for the test dataset. Most of the keys are the same as in VideoTestDataset.
            It has the following extra keys:
            use_duf_downsampling (bool): Whether to use DUF downsampling to generate low-resolution frames.
            scale (int): Scale factor, which will be added automatically.
    """

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            if self.opt['use_duf_downsampling']:
                # read imgs_gt to generate low-resolution frames
                imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            if self.opt['use_duf_downsampling']:
                img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
                # read imgs_gt to generate low-resolution frames
                imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
                imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }


@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
    """Video test dataset for recurrent architectures, which takes LR video
    frames as input and outputs corresponding HR video frames.

    Args:
        opt (dict): Same as VideoTestDataset. Unused opt:
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoRecurrentTestDataset, self).__init__(opt)
        # find unique folder strings
        self.folders = sorted(list(set(self.data_info['folder'])))

    def __getitem__(self, index):
        folder = self.folders[index]

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder]
            imgs_gt = self.imgs_gt[folder]
        else:
            raise NotImplementedError('Without cache_data is not implemented.')

        return {
            'lq': imgs_lq,
            'gt': imgs_gt,
            'folder': folder,
        }

    def __len__(self):
        return len(self.folders)
199
basicsr/data/vimeo90k_dataset.py
Normal file
199
basicsr/data/vimeo90k_dataset.py
Normal file
@@ -0,0 +1,199 @@
import random
import torch
from pathlib import Path
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file:
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains the following items, separated by a white space:

    1. clip name;
    2. frame number;
    3. image shape

    Examples:

    ::

        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    - Key examples: "00001/0001"
    - GT (gt): Ground-Truth;
    - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:

    ::

        num_frame | frame list
        1         | 4
        3         | 3,4,5
        5         | 2,3,4,5,6
        7         | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for the train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for the meta information file.
            io_backend (dict): IO backend type and other kwargs.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale factor, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])

        with open(opt['meta_info_file'], 'r') as fin:
            self.keys = [line.split(' ')[0] for line in fin]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images
        self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        logger = get_root_logger()
        logger.info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png)
        if self.is_lmdb:
            img_gt_path = f'{key}/im4'
        else:
            img_gt_path = self.gt_root / clip / seq / 'im4.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w)
        # img_gt: (c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
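

# --- Usage sketch (illustrative, not part of the original file): a training
# config for 4x SR; the LMDB paths and meta file below are assumed. ---
def _demo_vimeo90k_dataset():
    opt = {
        'dataroot_gt': 'datasets/vimeo90k/vimeo90k_train_GT.lmdb',
        'dataroot_lq': 'datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb',
        'meta_info_file': 'basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt',
        'io_backend': {'type': 'lmdb'},
        'num_frame': 7, 'gt_size': 256, 'scale': 4,
        'random_reverse': True, 'use_hflip': True, 'use_rot': True,
    }
    dataset = Vimeo90KDataset(opt)
    sample = dataset[0]
    print(sample['lq'].shape, sample['gt'].shape, sample['key'])  # (7, 3, 64, 64), (3, 256, 256), '00001/0001'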


@DATASET_REGISTRY.register()
class Vimeo90KRecurrentDataset(Vimeo90KDataset):

    def __init__(self, opt):
        super(Vimeo90KRecurrentDataset, self).__init__(opt)

        self.flip_sequence = opt['flip_sequence']
        self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
                img_gt_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
                img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
            # LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            # GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)

            img_lqs.append(img_lq)
            img_gts.append(img_gt)

        # randomly crop
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[:7], dim=0)
        img_gts = torch.stack(img_results[7:], dim=0)

        if self.flip_sequence:  # flip the sequence: 7 frames to 14 frames
            img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
            img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
47
basicsr/utils/__init__.py
Normal file
47
basicsr/utils/__init__.py
Normal file
@@ -0,0 +1,47 @@
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
from .options import yaml_load

__all__ = [
    # color_util.py
    'bgr2ycbcr',
    'rgb2ycbcr',
    'rgb2ycbcr_pt',
    'ycbcr2bgr',
    'ycbcr2rgb',
    # file_client.py
    'FileClient',
    # img_util.py
    'img2tensor',
    'tensor2img',
    'imfrombytes',
    'imwrite',
    'crop_border',
    # logger.py
    'MessageLogger',
    'AvgTimer',
    'init_tb_logger',
    'init_wandb_logger',
    'get_root_logger',
    'get_env_info',
    # misc.py
    'set_random_seed',
    'get_time_str',
    'mkdir_and_rename',
    'make_exp_dirs',
    'scandir',
    'check_resume',
    'sizeof_fmt',
    # diffjpeg
    'DiffJPEG',
    # img_process_util
    'USMSharp',
    'usm_sharp',
    # options
    'yaml_load'
]
BIN
basicsr/utils/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/color_util.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/color_util.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/color_util.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/color_util.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/diffjpeg.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/diffjpeg.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/diffjpeg.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/diffjpeg.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/dist_util.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/dist_util.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/dist_util.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/dist_util.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/file_client.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/file_client.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/file_client.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/file_client.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/flow_util.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/flow_util.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/flow_util.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/flow_util.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/img_process_util.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/img_process_util.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/img_process_util.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/img_process_util.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/img_util.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/img_util.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/img_util.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/img_util.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/logger.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/logger.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/logger.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/logger.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/misc.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/misc.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/misc.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/misc.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/options.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/options.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/options.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/options.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/registry.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/registry.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/registry.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/registry.cpython-38.pyc
Normal file
Binary file not shown.
208
basicsr/utils/color_util.py
Normal file
208
basicsr/utils/color_util.py
Normal file
@@ -0,0 +1,208 @@
import numpy as np
import torch


def rgb2ycbcr(img, y_only=False):
    """Convert an RGB image to a YCbCr image.

    This function produces the same results as Matlab's `rgb2ycbcr` function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return the Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
        and range as the input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        out_img = np.matmul(
            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img
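

# --- Minimal sketch (illustrative, not part of the original file): pure white
# maps to Y = (16 + 65.481 + 128.553 + 24.966) / 255 = 235/255 under BT.601. ---
def _demo_rgb2ycbcr():
    white = np.ones((1, 1, 3), dtype=np.float32)
    print(rgb2ycbcr(white, y_only=True))  # ~[[0.9216]], i.e. 235/255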


def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to a YCbCr image.

    The BGR version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return the Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
        and range as the input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        out_img = np.matmul(
            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2rgb(img):
    """Convert a YCbCr image to an RGB image.

    This function produces the same results as Matlab's ycbcr2rgb function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted RGB image. The output image has the same type
        and range as the input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                              [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]  # noqa: E126
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2bgr(img):
    """Convert a YCbCr image to a BGR image.

    The BGR version of ycbcr2rgb.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted BGR image. The output image has the same type
        and range as the input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
                              [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]  # noqa: E126
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def _convert_input_type_range(img):
    """Convert the type and range of the input image.

    It converts the input image to np.float32 type and range of [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    conversion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
        [0, 1].
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        pass
    elif img_type == np.uint8:
        img /= 255.
    else:
        raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
    return img


def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    It converts the image to the desired type and range. If `dst_type` is np.uint8,
    images will be converted to np.uint8 type with range [0, 255]. If
    `dst_type` is np.float32, it converts the image to np.float32 type with
    range [0, 1].
    It is mainly used for post-processing images in colorspace conversion
    functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255].
        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
            converts the image to np.uint8 type with range [0, 255]. If
            dst_type is np.float32, it converts the image to np.float32 type
            with range [0, 1].

    Returns:
        (ndarray): The converted image with desired type and range.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)


def rgb2ycbcr_pt(img, y_only=False):
    """Convert RGB images to YCbCr images (PyTorch version).

    It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): Images with shape (n, 3, h, w), range [0, 1], float, RGB format.
        y_only (bool): Whether to only return the Y channel. Default: False.

    Returns:
        (Tensor): Converted images with shape (n, 3/1, h, w), range [0, 1], float.
    """
    if y_only:
        weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
    else:
        weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                               [24.966, 112.0, -18.214]]).to(img)
        bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias

    out_img = out_img / 255.
    return out_img
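

# --- Usage sketch (illustrative, not part of the original file): the PyTorch
# version works on batched NCHW tensors, e.g. when computing Y-channel PSNR. ---
def _demo_rgb2ycbcr_pt():
    batch = torch.rand(2, 3, 8, 8)
    y = rgb2ycbcr_pt(batch, y_only=True)
    print(y.shape)  # torch.Size([2, 1, 8, 8]), values roughly in [16/255, 235/255]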
515
basicsr/utils/diffjpeg.py
Normal file
515
basicsr/utils/diffjpeg.py
Normal file
@@ -0,0 +1,515 @@
"""
Modified from https://github.com/mlomnitz/DiffJPEG

For images not divisible by 8
https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
"""
import itertools
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

# ------------------------ utils ------------------------#
y_table = np.array(
    [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
     [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
     [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
    dtype=np.float32).T
y_table = nn.Parameter(torch.from_numpy(y_table))
c_table = np.empty((8, 8), dtype=np.float32)
c_table.fill(99)
c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
c_table = nn.Parameter(torch.from_numpy(c_table))


def diff_round(x):
    """Differentiable rounding function."""
    return torch.round(x) + (x - torch.round(x))**3
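

# --- Minimal sketch (illustrative, not part of the original file): unlike
# torch.round, diff_round has a nonzero gradient, 3 * (x - round(x))**2. ---
def _demo_diff_round():
    x = torch.tensor([0.3], requires_grad=True)
    diff_round(x).backward()
    print(x.grad)  # tensor([0.2700]) = 3 * 0.3**2; torch.round alone gives 0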


def quality_to_factor(quality):
    """Calculate the factor corresponding to a JPEG quality.

    Args:
        quality (float): Quality for JPEG compression.

    Returns:
        float: Compression factor.
    """
    if quality < 50:
        quality = 5000. / quality
    else:
        quality = 200. - quality * 2
    return quality / 100.
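

# --- Minimal sketch (illustrative, not part of the original file): the factor
# scales the quantization tables, so lower quality means a larger factor. ---
def _demo_quality_to_factor():
    for q in (10, 50, 90):
        print(q, quality_to_factor(q))  # 10 -> 5.0, 50 -> 1.0, 90 -> 0.2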


# ------------------------ compression ------------------------#
class RGB2YCbCrJpeg(nn.Module):
    """Converts an RGB image to YCbCr."""

    def __init__(self):
        super(RGB2YCbCrJpeg, self).__init__()
        matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
                          dtype=np.float32).T
        self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        """
        Args:
            image (Tensor): batch x 3 x height x width

        Returns:
            Tensor: batch x height x width x 3
        """
        image = image.permute(0, 2, 3, 1)
        result = torch.tensordot(image, self.matrix, dims=1) + self.shift
        return result.view(image.shape)


class ChromaSubsampling(nn.Module):
    """Chroma subsampling on the CbCr channels."""

    def __init__(self):
        super(ChromaSubsampling, self).__init__()

    def forward(self, image):
        """
        Args:
            image (tensor): batch x height x width x 3

        Returns:
            y (tensor): batch x height x width
            cb (tensor): batch x height/2 x width/2
            cr (tensor): batch x height/2 x width/2
        """
        image_2 = image.permute(0, 3, 1, 2).clone()
        cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
        cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
        cb = cb.permute(0, 2, 3, 1)
        cr = cr.permute(0, 2, 3, 1)
        return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)


class BlockSplitting(nn.Module):
    """Splitting an image into 8x8 patches."""

    def __init__(self):
        super(BlockSplitting, self).__init__()
        self.k = 8

    def forward(self, image):
        """
        Args:
            image (tensor): batch x height x width

        Returns:
            Tensor: batch x h*w/64 x 8 x 8
        """
        height, _ = image.shape[1:3]
        batch_size = image.shape[0]
        image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
        return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)


class DCT8x8(nn.Module):
    """Discrete Cosine Transformation."""

    def __init__(self):
        super(DCT8x8, self).__init__()
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
        self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())

    def forward(self, image):
        """
        Args:
            image (tensor): batch x num_blocks x 8 x 8

        Returns:
            Tensor: batch x num_blocks x 8 x 8
        """
        image = image - 128
        result = self.scale * torch.tensordot(image, self.tensor, dims=2)
        result = result.view(image.shape)
        return result


class YQuantize(nn.Module):
    """JPEG quantization for the Y channel.

    Args:
        rounding (function): rounding function to use
    """

    def __init__(self, rounding):
        super(YQuantize, self).__init__()
        self.rounding = rounding
        self.y_table = y_table

    def forward(self, image, factor=1):
        """
        Args:
            image (tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            image = image.float() / (self.y_table * factor)
        else:
            b = factor.size(0)
            table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            image = image.float() / table
        image = self.rounding(image)
        return image


class CQuantize(nn.Module):
    """JPEG quantization for the CbCr channels.

    Args:
        rounding (function): rounding function to use
    """

    def __init__(self, rounding):
        super(CQuantize, self).__init__()
        self.rounding = rounding
        self.c_table = c_table

    def forward(self, image, factor=1):
        """
        Args:
            image (tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            image = image.float() / (self.c_table * factor)
        else:
            b = factor.size(0)
            table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            image = image.float() / table
        image = self.rounding(image)
        return image


class CompressJpeg(nn.Module):
    """Full JPEG compression algorithm.

    Args:
        rounding (function): rounding function to use
    """

    def __init__(self, rounding=torch.round):
        super(CompressJpeg, self).__init__()
        self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
        self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
        self.c_quantize = CQuantize(rounding=rounding)
        self.y_quantize = YQuantize(rounding=rounding)

    def forward(self, image, factor=1):
        """
        Args:
            image (tensor): batch x 3 x height x width

        Returns:
            (Tensor, Tensor, Tensor): Quantized y, cb and cr components, each of
            shape batch x num_blocks x 8 x 8 (cb/cr at quarter resolution).
        """
        y, cb, cr = self.l1(image * 255)
        components = {'y': y, 'cb': cb, 'cr': cr}
        for k in components.keys():
            comp = self.l2(components[k])
            if k in ('cb', 'cr'):
                comp = self.c_quantize(comp, factor=factor)
            else:
                comp = self.y_quantize(comp, factor=factor)

            components[k] = comp

        return components['y'], components['cb'], components['cr']


# ------------------------ decompression ------------------------#


class YDequantize(nn.Module):
    """Dequantize the Y channel."""

    def __init__(self):
        super(YDequantize, self).__init__()
        self.y_table = y_table

    def forward(self, image, factor=1):
        """
        Args:
            image (tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            out = image * (self.y_table * factor)
        else:
            b = factor.size(0)
            table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            out = image * table
        return out


class CDequantize(nn.Module):
    """Dequantize the CbCr channels."""

    def __init__(self):
        super(CDequantize, self).__init__()
        self.c_table = c_table

    def forward(self, image, factor=1):
        """
        Args:
            image (tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            out = image * (self.c_table * factor)
        else:
            b = factor.size(0)
            table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            out = image * table
        return out


class iDCT8x8(nn.Module):
    """Inverse discrete Cosine Transformation."""

    def __init__(self):
        super(iDCT8x8, self).__init__()
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())

    def forward(self, image):
        """
        Args:
            image (tensor): batch x num_blocks x 8 x 8

        Returns:
            Tensor: batch x num_blocks x 8 x 8
        """
        image = image * self.alpha
        result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
        result = result.view(image.shape)
        return result


class BlockMerging(nn.Module):
    """Merge 8x8 patches back into an image."""

    def __init__(self):
        super(BlockMerging, self).__init__()

    def forward(self, patches, height, width):
        """
        Args:
            patches (tensor): batch x height*width/64 x 8 x 8
            height (int)
            width (int)

        Returns:
            Tensor: batch x height x width
        """
        k = 8
        batch_size = patches.shape[0]
        image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
        return image_transposed.contiguous().view(batch_size, height, width)


class ChromaUpsampling(nn.Module):
    """Upsample the chroma layers."""

    def __init__(self):
        super(ChromaUpsampling, self).__init__()

    def forward(self, y, cb, cr):
        """
        Args:
            y (tensor): y channel image
            cb (tensor): cb channel
            cr (tensor): cr channel

        Returns:
            Tensor: batch x height x width x 3
        """

        def repeat(x, k=2):
            height, width = x.shape[1:3]
            x = x.unsqueeze(-1)
            x = x.repeat(1, 1, k, k)
            x = x.view(-1, height * k, width * k)
            return x

        cb = repeat(cb)
        cr = repeat(cr)
        return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)


class YCbCr2RGBJpeg(nn.Module):
    """Converts a YCbCr image to RGB (JPEG variant)."""

    def __init__(self):
        super(YCbCr2RGBJpeg, self).__init__()

        matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
        self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        """
        Args:
            image (tensor): batch x height x width x 3

        Returns:
            Tensor: batch x 3 x height x width
        """
        result = torch.tensordot(image + self.shift, self.matrix, dims=1)
        return result.view(image.shape).permute(0, 3, 1, 2)


class DeCompressJpeg(nn.Module):
    """Full JPEG decompression algorithm.

    Args:
        rounding (function): rounding function to use
    """

    def __init__(self, rounding=torch.round):
        super(DeCompressJpeg, self).__init__()
        self.c_dequantize = CDequantize()
        self.y_dequantize = YDequantize()
        self.idct = iDCT8x8()
        self.merging = BlockMerging()
        self.chroma = ChromaUpsampling()
        self.colors = YCbCr2RGBJpeg()

    def forward(self, y, cb, cr, imgh, imgw, factor=1):
        """
        Args:
            y, cb, cr (tensor): quantized components, batch x num_blocks x 8 x 8
            imgh (int): output image height
            imgw (int): output image width
            factor (float): quantization factor

        Returns:
            Tensor: batch x 3 x height x width
        """
        components = {'y': y, 'cb': cb, 'cr': cr}
        for k in components.keys():
            if k in ('cb', 'cr'):
                comp = self.c_dequantize(components[k], factor=factor)
                height, width = int(imgh / 2), int(imgw / 2)
            else:
                comp = self.y_dequantize(components[k], factor=factor)
                height, width = imgh, imgw
            comp = self.idct(comp)
            components[k] = self.merging(comp, height, width)

        image = self.chroma(components['y'], components['cb'], components['cr'])
        image = self.colors(image)

        image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
        return image / 255


# ------------------------ main DiffJPEG ------------------------ #


class DiffJPEG(nn.Module):
    """This JPEG algorithm result is slightly different from cv2.
    DiffJPEG supports batch processing.

    Args:
        differentiable (bool): If True, uses a custom differentiable rounding function;
            if False, uses the standard torch.round.
    """

    def __init__(self, differentiable=True):
        super(DiffJPEG, self).__init__()
        if differentiable:
            rounding = diff_round
        else:
            rounding = torch.round

        self.compress = CompressJpeg(rounding=rounding)
        self.decompress = DeCompressJpeg(rounding=rounding)

    def forward(self, x, quality):
        """
        Args:
            x (Tensor): Input image, bchw, rgb, [0, 1]
            quality (float): Quality factor for the JPEG compression scheme.
        """
        factor = quality
        if isinstance(factor, (int, float)):
            factor = quality_to_factor(factor)
        else:
            for i in range(factor.size(0)):
                factor[i] = quality_to_factor(factor[i])
        h, w = x.size()[-2:]
        h_pad, w_pad = 0, 0
        # pad to a multiple of 16: chroma is subsampled by 2 and DCT blocks are
        # 8x8, so each chroma block covers a 16x16 luma region
        if h % 16 != 0:
            h_pad = 16 - h % 16
        if w % 16 != 0:
            w_pad = 16 - w % 16
        x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)

        y, cb, cr = self.compress(x, factor=factor)
        recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
        recovered = recovered[:, :, 0:h, 0:w]
        return recovered


if __name__ == '__main__':
    import cv2

    from basicsr.utils import img2tensor, tensor2img

    img_gt = cv2.imread('test.png') / 255.

    # -------------- cv2 -------------- #
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
    _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
    img_lq = np.float32(cv2.imdecode(encimg, 1))
    cv2.imwrite('cv2_JPEG_20.png', img_lq)

    # -------------- DiffJPEG -------------- #
    jpeger = DiffJPEG(differentiable=False).cuda()
    img_gt = img2tensor(img_gt)
    img_gt = torch.stack([img_gt, img_gt]).cuda()
    quality = img_gt.new_tensor([20, 40])
    out = jpeger(img_gt, quality=quality)

    cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
    cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
82
basicsr/utils/dist_util.py
Normal file
82
basicsr/utils/dist_util.py
Normal file
@@ -0,0 +1,82 @@
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py  # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend, port=None):
|
||||
"""Initialize slurm distributed training environment.
|
||||
|
||||
If argument ``port`` is not specified, then the master port will be system
|
||||
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
|
||||
environment variable, then a default port ``29500`` will be used.
|
||||
|
||||
Args:
|
||||
backend (str): Backend of torch.distributed.
|
||||
port (int, optional): Master port. Defaults to None.
|
||||
"""
|
||||
proc_id = int(os.environ['SLURM_PROCID'])
|
||||
ntasks = int(os.environ['SLURM_NTASKS'])
|
||||
node_list = os.environ['SLURM_NODELIST']
|
||||
num_gpus = torch.cuda.device_count()
|
||||
torch.cuda.set_device(proc_id % num_gpus)
|
||||
addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
|
||||
# specify master port
|
||||
if port is not None:
|
||||
os.environ['MASTER_PORT'] = str(port)
|
||||
elif 'MASTER_PORT' in os.environ:
|
||||
pass # use MASTER_PORT in the environment variable
|
||||
else:
|
||||
# 29500 is torch.distributed default port
|
||||
os.environ['MASTER_PORT'] = '29500'
|
||||
os.environ['MASTER_ADDR'] = addr
|
||||
os.environ['WORLD_SIZE'] = str(ntasks)
|
||||
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
|
||||
os.environ['RANK'] = str(proc_id)
|
||||
dist.init_process_group(backend=backend)
|
||||
|
||||
|
||||
def get_dist_info():
|
||||
if dist.is_available():
|
||||
initialized = dist.is_initialized()
|
||||
else:
|
||||
initialized = False
|
||||
if initialized:
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
rank = 0
|
||||
world_size = 1
|
||||
return rank, world_size
|
||||
|
||||
|
||||
def master_only(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
rank, _ = get_dist_info()
|
||||
if rank == 0:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
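

if __name__ == '__main__':
    # A minimal usage sketch (not part of the original file). get_dist_info()
    # falls back to rank 0 / world_size 1 when torch.distributed was never
    # initialized, so this also runs on a single machine without a launcher.
    rank, world_size = get_dist_info()
    print(f'rank={rank}, world_size={world_size}')

    @master_only
    def log_once(msg):
        # the decorated function executes only on the master process (rank 0)
        print(msg)

    log_once('only the master process prints this')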
98
basicsr/utils/download_util.py
Normal file
@@ -0,0 +1,98 @@
import math
import os
import requests
from torch.hub import download_url_to_file, get_dir
from tqdm import tqdm
from urllib.parse import urlparse

from .misc import sizeof_fmt


def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """

    session = requests.Session()
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # get file size
    response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    if 'Content-Range' in response_file_size.headers:
        file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
    else:
        file_size = None

    save_response_content(response, save_path, file_size)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response, destination, file_size=None, chunk_size=32768):
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')

        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            downloaded_size += len(chunk)  # count actual bytes; the last chunk may be short
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
        if pbar is not None:
            pbar.close()


def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load a file from an http url, downloading the model if necessary.

    Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
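

if __name__ == '__main__':
    # A minimal usage sketch (not part of the original file). The URL is a
    # placeholder; load_file_from_url caches into torch hub's checkpoints
    # directory and skips the download when the file is already cached.
    url = 'https://example.com/models/net_g_latest.pth'  # hypothetical URL
    local_path = load_file_from_url(url, model_dir=None, progress=True)
    print(f'model cached at: {local_path}')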
167
basicsr/utils/file_client.py
Normal file
@@ -0,0 +1,167 @@
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py  # noqa: E501
from abc import ABCMeta, abstractmethod


class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract class of storage backends.

    All backends need to implement two apis: ``get()`` and ``get_text()``.
    ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
    as texts.
    """

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass


class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError('Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
        # mc.pyvector serves as a pointer to a memory cache
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        filepath = str(filepath)
        import mc
        self._client.Get(filepath, self._mc_buffer)
        value_buf = mc.ConvertBuffer(self._mc_buffer)
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError


class HardDiskBackend(BaseStorageBackend):
    """Raw hard disks storage backend."""

    def get(self, filepath):
        filepath = str(filepath)
        with open(filepath, 'rb') as f:
            value_buf = f.read()
        return value_buf

    def get_text(self, filepath):
        filepath = str(filepath)
        with open(filepath, 'r') as f:
            value_buf = f.read()
        return value_buf


class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | list[str]): Lmdb database paths.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.

    Attributes:
        db_paths (list): Lmdb database path.
        _client (list): A list of several lmdb envs.
    """

    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]

        if isinstance(db_paths, list):
            self.db_paths = [str(v) for v in db_paths]
        elif isinstance(db_paths, str):
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {}
        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing different lmdb envs.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError


class FileClient(object):
    """A general file client to access files in different backends.

    The client loads a file or text from a specified backend given its path
    and returns it as a binary stream. It can also register other backend
    accessors with a given name and backend class.

    Attributes:
        backend (str): The storage backend type. Options are "disk",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        if backend not in self._backends:
            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                             f' are {list(self._backends.keys())}')
        self.backend = backend
        self.client = self._backends[backend](**kwargs)

    def get(self, filepath, client_key='default'):
        # client_key is used only for lmdb, where different fileclients have
        # different lmdb environments.
        if self.backend == 'lmdb':
            return self.client.get(filepath, client_key)
        else:
            return self.client.get(filepath)

    def get_text(self, filepath):
        return self.client.get_text(filepath)
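

if __name__ == '__main__':
    # A minimal usage sketch (not part of the original file): read raw bytes
    # from disk; 'test.png' is a placeholder path.
    file_client = FileClient(backend='disk')
    img_bytes = file_client.get('test.png')
    print(f'read {len(img_bytes)} bytes')
    # For lmdb, pass db_paths/client_keys instead and query by lmdb key:
    #   file_client = FileClient('lmdb', db_paths='example.lmdb', client_keys='gt')
    #   img_bytes = file_client.get('000_00000000', 'gt')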
170
basicsr/utils/flow_util.py
Normal file
@@ -0,0 +1,170 @@
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py  # noqa: E501
import cv2
import numpy as np
import os


def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
    """Read an optical flow map.

    Args:
        flow_path (ndarray or str): Flow path.
        quantize (bool): Whether to read a quantized pair. If set to True,
            remaining args will be passed to :func:`dequantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated,
            can be either 0 or 1. Ignored if quantize is False.

    Returns:
        ndarray: Optical flow represented as a (h, w, 2) numpy array.
    """
    if quantize:
        assert concat_axis in [0, 1]
        cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
        if cat_flow.ndim != 2:
            raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
        assert cat_flow.shape[concat_axis] % 2 == 0
        dx, dy = np.split(cat_flow, 2, axis=concat_axis)
        flow = dequantize_flow(dx, dy, *args, **kwargs)
    else:
        with open(flow_path, 'rb') as f:
            try:
                header = f.read(4).decode('utf-8')
            except Exception:
                raise IOError(f'Invalid flow file: {flow_path}')
            else:
                if header != 'PIEH':
                    raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')

            w = np.fromfile(f, np.int32, 1).squeeze()
            h = np.fromfile(f, np.int32, 1).squeeze()
            flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))

    return flow.astype(np.float32)


def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
    """Write optical flow to file.

    If the flow is not quantized, it will be saved as a .flo file losslessly;
    otherwise, as a jpeg image, which is lossy but of much smaller size. (dx
    and dy will be concatenated into a single image if quantize is True.)

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        filename (str): Output filepath.
        quantize (bool): Whether to quantize the flow and save it to 2 jpeg
            images. If set to True, remaining args will be passed to
            :func:`quantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated,
            can be either 0 or 1. Ignored if quantize is False.
    """
    if not quantize:
        with open(filename, 'wb') as f:
            f.write('PIEH'.encode('utf-8'))
            np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
            flow = flow.astype(np.float32)
            flow.tofile(f)
            f.flush()
    else:
        assert concat_axis in [0, 1]
        dx, dy = quantize_flow(flow, *args, **kwargs)
        dxdy = np.concatenate((dx, dy), axis=concat_axis)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        cv2.imwrite(filename, dxdy)


def quantize_flow(flow, max_val=0.02, norm=True):
    """Quantize flow to [0, 255].

    After this step, the size of flow will be much smaller, and can be
    dumped as jpeg images.

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        max_val (float): Maximum value of flow; values beyond
            [-max_val, max_val] will be truncated.
        norm (bool): Whether to divide flow values by image width/height.

    Returns:
        tuple[ndarray]: Quantized dx and dy.
    """
    h, w, _ = flow.shape
    dx = flow[..., 0]
    dy = flow[..., 1]
    if norm:
        dx = dx / w  # avoid inplace operations
        dy = dy / h
    # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
    flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
    return tuple(flow_comps)


def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
    """Recover from quantized flow.

    Args:
        dx (ndarray): Quantized dx.
        dy (ndarray): Quantized dy.
        max_val (float): Maximum value used when quantizing.
        denorm (bool): Whether to multiply flow values with width/height.

    Returns:
        ndarray: Dequantized flow.
    """
    assert dx.shape == dy.shape
    assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)

    dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]

    if denorm:
        dx *= dx.shape[1]
        dy *= dx.shape[0]
    flow = np.dstack((dx, dy))
    return flow


def quantize(arr, min_val, max_val, levels, dtype=np.int64):
    """Quantize an array of (-inf, inf) to [0, levels-1].

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels.
        dtype (np.type): The type of the quantized array.

    Returns:
        ndarray: Quantized array.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(f'levels must be a positive integer, but got {levels}')
    if min_val >= max_val:
        raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    arr = np.clip(arr, min_val, max_val) - min_val
    quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)

    return quantized_arr


def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
    """Dequantize an array.

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels.
        dtype (np.type): The type of the dequantized array.

    Returns:
        ndarray: Dequantized array.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(f'levels must be a positive integer, but got {levels}')
    if min_val >= max_val:
        raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val

    return dequantized_arr
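

if __name__ == '__main__':
    # A minimal round-trip sketch (not part of the original file): with 255
    # levels over [-max_val, max_val], dequantization recovers each component
    # to within roughly (2 * max_val) / 255.
    flow = np.random.uniform(-0.01, 0.01, size=(4, 5, 2)).astype(np.float32)
    dx, dy = quantize_flow(flow, max_val=0.02, norm=False)
    restored = dequantize_flow(dx, dy, max_val=0.02, denorm=False)
    print('max abs error:', np.abs(restored - flow).max())  # ~0.04 / 255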
83
basicsr/utils/img_process_util.py
Normal file
@@ -0,0 +1,83 @@
import cv2
import numpy as np
import torch
from torch.nn import functional as F


def filter2D(img, kernel):
    """PyTorch version of cv2.filter2D.

    Args:
        img (Tensor): (b, c, h, w)
        kernel (Tensor): (b, k, k)
    """
    k = kernel.size(-1)
    b, c, h, w = img.size()
    if k % 2 == 1:
        img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
    else:
        raise ValueError('Only odd kernel sizes are supported.')

    ph, pw = img.size()[-2:]

    if kernel.size(0) == 1:
        # apply the same kernel to all batch images
        img = img.view(b * c, 1, ph, pw)
        kernel = kernel.view(1, 1, k, k)
        return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
    else:
        img = img.view(1, b * c, ph, pw)
        kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
        return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)


def usm_sharp(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening.

    Input image: I; Blurry image: B.
    1. sharp = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur the mask to get the soft mask.
    4. Out = Mask * sharp + (1 - Mask) * I

    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): Residual threshold (on the 0-255 scale); only pixels
            whose residual exceeds it are sharpened. Default: 10.
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    sharp = img + weight * residual
    sharp = np.clip(sharp, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img


class USMSharp(torch.nn.Module):

    def __init__(self, radius=50, sigma=0):
        super(USMSharp, self).__init__()
        if radius % 2 == 0:
            radius += 1
        self.radius = radius
        kernel = cv2.getGaussianKernel(radius, sigma)
        kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
        self.register_buffer('kernel', kernel)

    def forward(self, img, weight=0.5, threshold=10):
        blur = filter2D(img, self.kernel)
        residual = img - blur

        mask = torch.abs(residual) * 255 > threshold
        mask = mask.float()
        soft_mask = filter2D(mask, self.kernel)
        sharp = img + weight * residual
        sharp = torch.clip(sharp, 0, 1)
        return soft_mask * sharp + (1 - soft_mask) * img
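

if __name__ == '__main__':
    # A minimal usage sketch (not part of the original file): sharpen a random
    # batch with the PyTorch module and a single image with the numpy helper.
    sharpener = USMSharp(radius=50)
    batch = torch.rand(2, 3, 64, 64)  # bchw, [0, 1]
    print(sharpener(batch, weight=0.5, threshold=10).shape)  # (2, 3, 64, 64)

    img = np.random.rand(64, 64, 3).astype(np.float32)  # hwc, [0, 1]
    print(usm_sharp(img, weight=0.5).shape)  # (64, 64, 3)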
172
basicsr/utils/img_util.py
Normal file
@@ -0,0 +1,172 @@
import cv2
import math
import numpy as np
import os
import torch
from torchvision.utils import make_grid


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If the returned results only
            have one element, just return the tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
            ndarray of shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    single_input = torch.is_tensor(tensor)  # remember whether the input was a bare tensor
    if single_input:
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1 and single_input:
        result = result[0]
    return result


def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
    """This implementation is slightly faster than tensor2img.
    It now only supports torch tensor with shape (1, c, h, w).

    Args:
        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
        min_max (tuple[int]): min and max values for clamp.
    """
    output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
    output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
    output = output.type(torch.uint8).cpu().numpy()
    if rgb2bgr:
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
    return output


def imfrombytes(content, flag='color', float32=False):
    """Read an image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, will also
            normalize to [0, 1]. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    img_np = np.frombuffer(content, np.uint8)
    imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    img = cv2.imdecode(img_np, imread_flags[flag])
    if float32:
        img = img.astype(np.float32) / 255.
    return img


def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Raises:
        IOError: If the image fails to be written.
    """
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    ok = cv2.imwrite(file_path, img, params)
    if not ok:
        raise IOError('Failed in writing images.')


def crop_border(imgs, crop_border):
    """Crop borders of images.

    Args:
        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
        crop_border (int): Crop border for each end of height and width.

    Returns:
        list[ndarray]: Cropped images.
    """
    if crop_border == 0:
        return imgs
    else:
        if isinstance(imgs, list):
            return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
        else:
            return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
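

if __name__ == '__main__':
    # A minimal round-trip sketch (not part of the original file): BGR numpy
    # image -> RGB float tensor -> BGR uint8 image. Values survive up to the
    # uint8 quantization introduced by tensor2img.
    img = np.random.rand(8, 8, 3).astype(np.float32)  # hwc, bgr, [0, 1]
    t = img2tensor(img, bgr2rgb=True, float32=True)   # chw, rgb
    back = tensor2img(t, rgb2bgr=True)                # hwc, bgr, uint8
    print(np.abs(back / 255. - img).max())            # <= ~0.5 / 255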
199
basicsr/utils/lmdb_util.py
Normal file
@@ -0,0 +1,199 @@
import cv2
import lmdb
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm


def make_lmdb_from_imgs(data_path,
                        lmdb_path,
                        img_path_list,
                        keys,
                        batch=5000,
                        compress_level=1,
                        multiprocessing_read=False,
                        n_thread=40,
                        map_size=None):
    """Make lmdb from images.

    Contents of lmdb. The file structure is:

    ::

        example.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records 1) image name (with extension),
    2) image shape, and 3) compression level, separated by a white space.

    For example, the meta information could be:
    `000_00000000.png (720,1280,3) 1`, which means:
    1) image name (with extension): 000_00000000.png;
    2) image shape: (720,1280,3);
    3) compression level: 1

    We use the image name without extension as the lmdb key.

    If `multiprocessing_read` is True, it will read all the images to memory
    using multiprocessing. Thus, your server needs to have enough memory.

    Args:
        data_path (str): Data path for reading images.
        lmdb_path (str): Lmdb save path.
        img_path_list (str): Image path list.
        keys (str): Used for lmdb keys.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
        multiprocessing_read (bool): Whether use multiprocessing to read all
            the images to memory. Default: False.
        n_thread (int): For multiprocessing.
        map_size (int | None): Map size for lmdb env. If None, use the
            estimated size from images. Default: None
    """

    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
                                             f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Total images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):
        print(f'Folder {lmdb_path} already exists. Exit.')
        sys.exit(1)

    if multiprocessing_read:
        # read all the images to memory (multiprocessing)
        dataset = {}  # use dict to keep the order for multiprocessing
        shapes = {}
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = tqdm(total=len(img_path_list), unit='image')

        def callback(arg):
            """get the image data and update pbar."""
            key, dataset[key], shapes[key] = arg
            pbar.update(1)
            pbar.set_description(f'Read {key}')

        pool = Pool(n_thread)
        for path, key in zip(img_path_list, keys):
            pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
        pool.close()
        pool.join()
        pbar.close()
        print(f'Finish reading {len(img_path_list)} images.')

    # create lmdb environment
    if map_size is None:
        # obtain data size for one image
        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)
        map_size = data_size * 10

    env = lmdb.open(lmdb_path, map_size=map_size)

    # write data to lmdb
    pbar = tqdm(total=len(img_path_list), unit='chunk')
    txn = env.begin(write=True)
    txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
    for idx, (path, key) in enumerate(zip(img_path_list, keys)):
        pbar.update(1)
        pbar.set_description(f'Write {key}')
        key_byte = key.encode('ascii')
        if multiprocessing_read:
            img_byte = dataset[key]
            h, w, c = shapes[key]
        else:
            _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
            h, w, c = img_shape

        txn.put(key_byte, img_byte)
        # write meta information
        txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
        if idx % batch == 0:
            txn.commit()
            txn = env.begin(write=True)
    pbar.close()
    txn.commit()
    env.close()
    txt_file.close()
    print('\nFinish writing lmdb.')


def read_img_worker(path, key, compress_level):
    """Read image worker.

    Args:
        path (str): Image path.
        key (str): Image key.
        compress_level (int): Compress level when encoding images.

    Returns:
        str: Image key.
        byte: Image byte.
        tuple[int]: Image shape.
    """

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img.ndim == 2:
        h, w = img.shape
        c = 1
    else:
        h, w, c = img.shape
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (h, w, c))


class LmdbMaker():
    """LMDB Maker.

    Args:
        lmdb_path (str): Lmdb save path.
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
    """

    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        self.txn = self.env.begin(write=True)
        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
        self.counter = 0

    def put(self, img_byte, key, img_shape):
        self.counter += 1
        key_byte = key.encode('ascii')
        self.txn.put(key_byte, img_byte)
        # write meta information
        h, w, c = img_shape
        self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
        if self.counter % self.batch == 0:
            self.txn.commit()
            self.txn = self.env.begin(write=True)

    def close(self):
        self.txn.commit()
        self.env.close()
        self.txt_file.close()
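

if __name__ == '__main__':
    # A minimal usage sketch (not part of the original file): write two random
    # images into a fresh lmdb with LmdbMaker. 'example.lmdb' is a placeholder
    # path and must not exist yet.
    import numpy as np

    maker = LmdbMaker('example.lmdb', batch=1)
    for key in ['000_00000000', '000_00000001']:
        img = (np.random.rand(16, 16, 3) * 255).astype('uint8')
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, 1])
        maker.put(img_byte, key, img.shape)
    maker.close()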
213
basicsr/utils/logger.py
Normal file
@@ -0,0 +1,213 @@
import datetime
import logging
import time

from .dist_util import get_dist_info, master_only

initialized_logger = {}


class AvgTimer():

    def __init__(self, window=200):
        self.window = window  # average window
        self.current_time = 0
        self.total_time = 0
        self.count = 0
        self.avg_time = 0
        self.start()

    def start(self):
        self.start_time = self.tic = time.time()

    def record(self):
        self.count += 1
        self.toc = time.time()
        self.current_time = self.toc - self.tic
        self.total_time += self.current_time
        # calculate average time
        self.avg_time = self.total_time / self.count

        # reset
        if self.count > self.window:
            self.count = 0
            self.total_time = 0

        self.tic = time.time()

    def get_current_time(self):
        return self.current_time

    def get_avg_time(self):
        return self.avg_time


class MessageLogger():
    """Message logger for printing.

    Args:
        opt (dict): Config. It contains the following keys:
            name (str): Exp name.
            logger (dict): Contains 'print_freq' (int) for the logger interval.
            train (dict): Contains 'total_iter' (int) for total iters.
            use_tb_logger (bool): Use tensorboard logger.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """

    def __init__(self, opt, start_iter=1, tb_logger=None):
        self.exp_name = opt['name']
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.use_tb_logger = opt['logger']['use_tb_logger']
        self.tb_logger = tb_logger
        self.start_time = time.time()
        self.logger = get_root_logger()

    def reset_start_time(self):
        self.start_time = time.time()

    @master_only
    def __call__(self, log_vars):
        """Format logging message.

        Args:
            log_vars (dict): It contains the following keys:
                epoch (int): Epoch number.
                iter (int): Current iter.
                lrs (list): List for learning rates.

                time (float): Iter time.
                data_time (float): Data time for each iter.
        """
        # epoch, iter, learning rates
        epoch = log_vars.pop('epoch')
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
        for v in lrs:
            message += f'{v:.3e},'
        message += ')] '

        # time and estimated time
        if 'time' in log_vars.keys():
            iter_time = log_vars.pop('time')
            data_time = log_vars.pop('data_time')

            total_time = time.time() - self.start_time
            time_sec_avg = total_time / (current_iter - self.start_iter + 1)
            eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
            message += f'[eta: {eta_str}, '
            message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '

        # other items, especially losses
        for k, v in log_vars.items():
            message += f'{k}: {v:.4e} '
            # tensorboard logger
            if self.use_tb_logger and 'debug' not in self.exp_name:
                if k.startswith('l_'):
                    self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
                else:
                    self.tb_logger.add_scalar(k, v, current_iter)
        self.logger.info(message)


@master_only
def init_tb_logger(log_dir):
    from torch.utils.tensorboard import SummaryWriter
    tb_logger = SummaryWriter(log_dir=log_dir)
    return tb_logger


@master_only
def init_wandb_logger(opt):
    """We now only use wandb to sync tensorboard log."""
    import wandb
    logger = get_root_logger()

    project = opt['logger']['wandb']['project']
    resume_id = opt['logger']['wandb'].get('resume_id')
    if resume_id:
        wandb_id = resume_id
        resume = 'allow'
        logger.warning(f'Resume wandb logger with id={wandb_id}.')
    else:
        wandb_id = wandb.util.generate_id()
        resume = 'never'

    wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)

    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')


def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    logger.addHandler(stream_handler)
    logger.propagate = False
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    elif log_file is not None:
        logger.setLevel(log_level)
        # add file handler
        file_handler = logging.FileHandler(log_file, 'w')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)
    initialized_logger[logger_name] = True
    return logger


def get_env_info():
    """Get environment information.

    Currently, only log the software version.
    """
    import torch
    import torchvision

    from basicsr.version import __version__
    msg = r"""
                ____                _       _____  ____
               / __ ) ____ _ _____ (_)_____/ ___/ / __ \
              / __  |/ __ `// ___// // ___/\__ \ / /_/ /
             / /_/ // /_/ /(__  )/ // /__ ___/ // _, _/
            /_____/ \__,_//____//_/ \___//____//_/ |_|
     ______                   __   __                 __      __
    / ____/____   ____   ____/ /  / /   __  __ _____ / /__   / /
   / / __ / __ \ / __ \ / __  /  / /   / / / // ___// //_/  / /
  / /_/ // /_/ // /_/ // /_/ /  / /___/ /_/ // /__ / /<    /_/
  \____/ \____/ \____/ \____/  /_____/\____/ \___//_/|_|  (_)
    """
    msg += ('\nVersion Information: '
            f'\n\tBasicSR: {__version__}'
            f'\n\tPyTorch: {torch.__version__}'
            f'\n\tTorchVision: {torchvision.__version__}')
    return msg
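

if __name__ == '__main__':
    # A minimal usage sketch (not part of the original file): the first call
    # configures the 'basicsr' logger (here with an extra file handler);
    # later calls return the cached instance unchanged.
    logger = get_root_logger(log_level=logging.INFO, log_file='example.log')
    logger.info('training starts')
    assert get_root_logger() is logger  # cached by name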
178
basicsr/utils/matlab_functions.py
Normal file
@@ -0,0 +1,178 @@
import math
import numpy as np
import torch


def cubic(x):
    """cubic function used for calculate_weights_indices."""
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * (
        (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
                                                                                     (absx <= 2)).type_as(absx))


def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Calculate weights and indices, used for the imresize function.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel (str): Kernel type. Currently only 'cubic' is used.
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. Only consider the
    # first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


@torch.no_grad()
def imresize(img, scale, antialiasing=True):
    """imresize function same as MATLAB.

    It now only supports bicubic.
    The same scale applies for both height and width.

    Args:
        img (Tensor | Numpy array):
            Tensor: Input image with shape (c, h, w), [0, 1] range.
            Numpy: Input image with shape (h, w, c), [0, 1] range.
        scale (float): Scale factor. The same scale applies for both height
            and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
            Default: True.

    Returns:
        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
    """
    squeeze_flag = False
    if type(img).__module__ == np.__name__:  # numpy type
        numpy_type = True
        if img.ndim == 2:
            img = img[:, :, None]
            squeeze_flag = True
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    else:
        numpy_type = False
        if img.ndim == 2:
            img = img.unsqueeze(0)
            squeeze_flag = True

    in_c, in_h, in_w = img.size()
    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
    kernel_width = 4
    kernel = 'cubic'

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
                                                                             antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
                                                                             antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)

    sym_patch = img[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    if squeeze_flag:
        out_2 = out_2.squeeze(0)
    if numpy_type:
        out_2 = out_2.numpy()
        if not squeeze_flag:
            out_2 = out_2.transpose(1, 2, 0)

    return out_2
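

if __name__ == '__main__':
    # A minimal usage sketch (not part of the original file): MATLAB-style
    # bicubic resizing of a numpy image and of a tensor image.
    img_np = np.random.rand(32, 32, 3).astype(np.float32)  # hwc, [0, 1]
    print(imresize(img_np, 0.5).shape)   # (16, 16, 3)

    img_t = torch.rand(3, 32, 32)        # chw, [0, 1]
    print(imresize(img_t, 2.0).shape)    # torch.Size([3, 64, 64])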
141
basicsr/utils/misc.py
Normal file
@@ -0,0 +1,141 @@
import numpy as np
import os
import random
import time
import torch
from os import path as osp

from .dist_util import master_only


def set_random_seed(seed):
    """Set random seeds."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_time_str():
    return time.strftime('%Y%m%d_%H%M%S', time.localtime())


def mkdir_and_rename(path):
    """mkdirs. If path exists, rename it with timestamp and create a new one.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        new_name = path + '_archived_' + get_time_str()
        print(f'Path already exists. Rename it to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)


@master_only
def make_exp_dirs(opt):
    """Make dirs for experiments."""
    path_opt = opt['path'].copy()
    if opt['is_train']:
        mkdir_and_rename(path_opt.pop('experiments_root'))
    else:
        mkdir_and_rename(path_opt.pop('results_root'))
    for key, path in path_opt.items():
        if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
            continue
        else:
            os.makedirs(path, exist_ok=True)


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            else:
                # skip hidden files; only recurse into directories
                if recursive and entry.is_dir():
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)


def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    if opt['path']['resume_state']:
        # get all the networks
        networks = [key for key in opt.keys() if key.startswith('network_')]
        flag_pretrain = False
        for network in networks:
            if opt['path'].get(f'pretrain_{network}') is not None:
                flag_pretrain = True
        if flag_pretrain:
            print('pretrain_network path will be ignored during resuming.')
        # set pretrained model paths
        for network in networks:
            name = f'pretrain_{network}'
            basename = network.replace('network_', '')
            if opt['path'].get('ignore_resume_networks') is None or (network
                                                                     not in opt['path']['ignore_resume_networks']):
                opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
                print(f"Set {name} to {opt['path'][name]}")

        # change param_key to params in resume
        param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
        for param_key in param_keys:
            if opt['path'][param_key] == 'params_ema':
                opt['path'][param_key] = 'params'
                print(f'Set {param_key} to params')


def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size.
    """
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024.0:
            return f'{size:3.1f} {unit}{suffix}'
        size /= 1024.0
    return f'{size:3.1f} Y{suffix}'
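

if __name__ == '__main__':
    # A minimal usage sketch (not part of the original file): list every
    # python file under the current directory with a human-readable size.
    for path in scandir('.', suffix='.py', recursive=True):
        print(f'{path}: {sizeof_fmt(os.path.getsize(path))}')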
210
basicsr/utils/options.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
from os import path as osp
|
||||
|
||||
from basicsr.utils import set_random_seed
|
||||
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
|
||||
|
||||
|
||||
def ordered_yaml():
|
||||
"""Support OrderedDict for yaml.
|
||||
|
||||
Returns:
|
||||
tuple: yaml Loader and Dumper.
|
||||
"""
|
||||
try:
|
||||
from yaml import CDumper as Dumper
|
||||
from yaml import CLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import Dumper, Loader
|
||||
|
||||
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
|
||||
|
||||
def dict_representer(dumper, data):
|
||||
return dumper.represent_dict(data.items())
|
||||
|
||||
def dict_constructor(loader, node):
|
||||
return OrderedDict(loader.construct_pairs(node))
|
||||
|
||||
Dumper.add_representer(OrderedDict, dict_representer)
|
||||
Loader.add_constructor(_mapping_tag, dict_constructor)
|
||||
return Loader, Dumper
|
||||
|
||||
|
||||
def yaml_load(f):
|
||||
"""Load yaml file or string.
|
||||
|
||||
Args:
|
||||
f (str): File path or a python string.
|
||||
|
||||
Returns:
|
||||
dict: Loaded dict.
|
||||
"""
|
||||
if os.path.isfile(f):
|
||||
with open(f, 'r') as f:
|
||||
return yaml.load(f, Loader=ordered_yaml()[0])
|
||||
else:
|
||||
return yaml.load(f, Loader=ordered_yaml()[0])
|
||||
|
||||
|
||||
def dict2str(opt, indent_level=1):
|
||||
"""dict to string for printing options.
|
||||
|
||||
Args:
|
||||
opt (dict): Option dict.
|
||||
indent_level (int): Indent level. Default: 1.
|
||||
|
||||
Return:
|
||||
(str): Option string for printing.
|
||||
"""
|
||||
msg = '\n'
|
||||
for k, v in opt.items():
|
||||
if isinstance(v, dict):
|
||||
msg += ' ' * (indent_level * 2) + k + ':['
|
||||
msg += dict2str(v, indent_level + 1)
|
||||
msg += ' ' * (indent_level * 2) + ']\n'
|
||||
else:
|
||||
msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
|
||||
return msg
|
||||
|
||||
|
||||
def _postprocess_yml_value(value):
    # None
    if value == '~' or value.lower() == 'none':
        return None
    # bool
    if value.lower() == 'true':
        return True
    elif value.lower() == 'false':
        return False
    # !!float number
    if value.startswith('!!float'):
        return float(value.replace('!!float', ''))
    # number
    if value.isdigit():
        return int(value)
    elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
        return float(value)
    # list
    if value.startswith('['):
        return eval(value)
    # str
    return value
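
# Illustrative behaviour (a minimal sketch): command-line override strings are
# coerced back into typed Python values.
assert _postprocess_yml_value('none') is None
assert _postprocess_yml_value('true') is True
assert _postprocess_yml_value('!!float 1e-4') == 1e-4
assert _postprocess_yml_value('[1, 2]') == [1, 2]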
def parse_options(root_path, is_train=True):
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
    args = parser.parse_args()

    # parse yml to dict
    opt = yaml_load(args.opt)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    # force to update yml options
    if args.force_yml is not None:
        for entry in args.force_yml:
            # now do not support creating new keys
            keys, value = entry.split('=')
            keys, value = keys.strip(), value.strip()
            value = _postprocess_yml_value(value)
            eval_str = 'opt'
            for key in keys.split(':'):
                eval_str += f'["{key}"]'
            eval_str += '=value'
            # using exec function
            exec(eval_str)

    opt['auto_resume'] = args.auto_resume
    opt['is_train'] = is_train

    # debug setting
    if args.debug and not opt['name'].startswith('debug'):
        opt['name'] = 'debug_' + opt['name']

    if opt['num_gpu'] == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()

    # datasets
    for phase, dataset in opt['datasets'].items():
        # for multiple datasets, e.g., val_1, val_2; test_1, test_2
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # paths
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')

        # change some options for debug mode
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(root_path, 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt, args
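
# Illustrative override (a minimal sketch): each --force_yml entry is split on
# '=' and the key path on ':', so a shell invocation such as (the script path
# is a placeholder for whatever entry script calls parse_options)
#   python train.py -opt options/train.yml --force_yml train:ema_decay=0.999
# ends up executing opt["train"]["ema_decay"] = 0.999.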
@master_only
def copy_opt_file(opt_file, experiments_root):
    # copy the yml file to the experiment root
    import sys
    import time
    from shutil import copyfile
    cmd = ' '.join(sys.argv)
    filename = osp.join(experiments_root, osp.basename(opt_file))
    copyfile(opt_file, filename)

    with open(filename, 'r+') as f:
        lines = f.readlines()
        lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
        f.seek(0)
        f.writelines(lines)
83
basicsr/utils/plot_util.py
Normal file
@@ -0,0 +1,83 @@
import re


def read_data_from_tensorboard(log_path, tag):
    """Get raw data (steps and values) from tensorboard events.

    Args:
        log_path (str): Path to the tensorboard log.
        tag (str): tag to be read.
    """
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    # tensorboard event
    event_acc = EventAccumulator(log_path)
    event_acc.Reload()
    scalar_list = event_acc.Tags()['scalars']
    print('tag list: ', scalar_list)
    steps = [int(s.step) for s in event_acc.Scalars(tag)]
    values = [s.value for s in event_acc.Scalars(tag)]
    return steps, values
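
# Illustrative usage (a minimal sketch; 'tb_logger/demo' and the 'l_pix' tag
# are placeholders for a real event directory and a real scalar tag):
steps, values = read_data_from_tensorboard('tb_logger/demo', 'l_pix')
print(steps[:3], values[:3])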
def read_data_from_txt_2v(path, pattern, step_one=False):
    """Read data from txt with 2 returned values (usually [step, value]).

    Args:
        path (str): path to the txt file.
        pattern (str): re (regular expression) pattern.
        step_one (bool): add 1 to steps. Default: False.
    """
    with open(path) as f:
        lines = f.readlines()
    lines = [line.strip() for line in lines]
    steps = []
    values = []

    pattern = re.compile(pattern)
    for line in lines:
        match = pattern.match(line)
        if match:
            steps.append(int(match.group(1)))
            values.append(float(match.group(2)))
    if step_one:
        steps = [v + 1 for v in steps]
    return steps, values
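
# Illustrative usage (a minimal sketch; 'train.log' and the log format are
# placeholders): extract (iter, psnr) pairs from lines such as
# 'iter: 1000, psnr: 28.53' via the two regex groups.
_steps, _vals = read_data_from_txt_2v('train.log', r'iter: (\d+), psnr: ([\d.]+)')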
def read_data_from_txt_1v(path, pattern):
    """Read data from txt with 1 returned value.

    Args:
        path (str): path to the txt file.
        pattern (str): re (regular expression) pattern.
    """
    with open(path) as f:
        lines = f.readlines()
    lines = [line.strip() for line in lines]
    data = []

    pattern = re.compile(pattern)
    for line in lines:
        match = pattern.match(line)
        if match:
            data.append(float(match.group(1)))
    return data
def smooth_data(values, smooth_weight):
    """Smooth data using 1st-order IIR low-pass filter (what tensorflow does).

    Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704  # noqa: E501

    Args:
        values (list): A list of values to be smoothed.
        smooth_weight (float): Smooth weight.
    """
    values_sm = []
    last_sm_value = values[0]
    for value in values:
        value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value
        values_sm.append(value_sm)
        last_sm_value = value_sm
    return values_sm
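
# Illustrative behaviour (a minimal sketch): each output is an exponential
# moving average, y[t] = w * y[t-1] + (1 - w) * x[t], seeded with x[0].
assert smooth_data([0.0, 1.0, 1.0], 0.5) == [0.0, 0.5, 0.75]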
293
basicsr/utils/realesrgan_utils.py
Normal file
@@ -0,0 +1,293 @@
import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F

# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can also be a URL (the model will first be
            downloaded automatically).
        model (nn.Module): The defined network. Default: None.
        tile (int): Since overly large images can exhaust GPU memory, this tile option first crops the input
            image into tiles, processes each tile, and finally merges them back into one image.
            0 means tiling is disabled. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
    """

    def __init__(self,
                 scale,
                 model_path,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if gpu_id:
            self.device = torch.device(
                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        # if the model_path starts with https, it will first download models to the folder: weights/realesrgan
        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer to use params_ema
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()
    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the image sizes are divisible."""
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        # model inference
        self.output = self.model(self.img)
    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]
    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output
    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
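
# Illustrative usage (a minimal sketch): RRDBNet is the standard BasicSR x4
# architecture; the weight path and input image below are placeholders.
from basicsr.archs.rrdbnet_arch import RRDBNet

_net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
_upsampler = RealESRGANer(scale=4, model_path='weights/RealESRGAN_x4plus.pth', model=_net, tile=400)
_sr, _mode = _upsampler.enhance(cv2.imread('input.png', cv2.IMREAD_UNCHANGED), outscale=4)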
class PrefetchReader(threading.Thread):
    """Prefetch images.

    Args:
        img_list (list[str]): A list of image paths to be read.
        num_prefetch_queue (int): Maximum size of the prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            self.que.put(img)

        self.que.put(None)

    def __next__(self):
        next_item = self.que.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
class IOConsumer(threading.Thread):

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break

            output = msg['output']
            save_path = msg['save_path']
            cv2.imwrite(save_path, output)
        print(f'IO worker {self.qid} is done.')
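
# Illustrative wiring (a minimal sketch; the image paths are placeholders):
# a PrefetchReader thread reads inputs ahead of time while an IOConsumer
# thread drains a save queue; 'quit' stops the consumer.
_reader = PrefetchReader(['a.png', 'b.png'], num_prefetch_queue=4)
_reader.start()
_save_que = queue.Queue()
_consumer = IOConsumer(opt={}, que=_save_que, qid=0)
_consumer.start()
for _img in _reader:
    _save_que.put({'output': _img, 'save_path': 'out.png'})
_save_que.put('quit')
_consumer.join()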
88
basicsr/utils/registry.py
Normal file
@@ -0,0 +1,88 @@
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py  # noqa: E501


class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}
    def _do_register(self, name, obj, suffix=None):
        if isinstance(suffix, str):
            name = name + '_' + suffix

        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
                                             f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None, suffix=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class, suffix)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj, suffix)

    def get(self, name, suffix='basicsr'):
        ret = self._obj_map.get(name)
        if ret is None:
            ret = self._obj_map.get(name + '_' + suffix)
            print(f'Name {name} is not found, use name: {name}_{suffix}!')
        if ret is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
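
# Illustrative usage (a minimal sketch): registering a class under a suffix
# and resolving it through get(), which falls back to '<name>_<suffix>'.
DEMO_REGISTRY = Registry('demo')


@DEMO_REGISTRY.register(suffix='basicsr')
class DemoArch():
    pass


assert DEMO_REGISTRY.get('DemoArch') is DemoArch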