first commit

This commit is contained in:
zsyOAOA
2024-12-11 18:46:36 +08:00
parent 9e65255d34
commit 27f2eb7dc3
847 changed files with 377076 additions and 2 deletions

35
LICENSE Normal file

@@ -0,0 +1,35 @@
S-Lab License 1.0
Copyright 2024 S-Lab
Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4. In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.

README.md

@@ -1,2 +1,95 @@
# InvSR
Arbitrary-steps Image Super-resolution via Diffusion Inversion
# Arbitrary-steps Image Super-resolution via Diffusion Inversion
[Zongsheng Yue](https://zsyoaoa.github.io/), [Kang Liao](https://kangliao929.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
<!--[Paper](https://arxiv.org/abs/2307.12348) | [Project Page](https://zsyoaoa.github.io/projects/resshift/) | [Demo](https://www.youtube.com/watch?v=8DB-6Xvvl5o)-->
<!--<a href="https://colab.research.google.com/drive/1CL8aJO7a_RA4MetanrCLqQO5H7KWO8KI?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/cjwbw/resshift) ![visitors](https://visitor-badge.laobi.icu/badge?page_id=zsyOAOA/ResShift) -->
:star: If InvSR is helpful to your research or projects, please help star this repo. Thanks! :hugs:
---
>This study presents a new image super-resolution (SR) technique based on diffusion inversion, aiming at harnessing the rich image priors encapsulated in large pre-trained diffusion models to improve SR performance. We design a *Partial noise Prediction* strategy to construct an intermediate state of the diffusion model, which serves as the starting sampling point. Central to our approach is a deep noise predictor to estimate the optimal noise maps for the forward diffusion process. Once trained, this noise predictor can be used to initialize the sampling process partially along the diffusion trajectory, generating the desirable high-resolution result. Compared to existing approaches, our method offers a flexible and efficient sampling mechanism that supports an arbitrary number of sampling steps, ranging from one to five. Even with a single sampling step, our method demonstrates superior or comparable performance to recent state-of-the-art approaches.
><img src="./assets/framework.png" align="middle" width="800">
---
## Update
- **2024.12.11**: Create this repo.
## Requirements
* Python 3.10, PyTorch 2.4.0, [xformers](https://github.com/facebookresearch/xformers) 0.0.27.post2
* For more details, see [environment.yaml](environment.yaml).
A suitable [conda](https://conda.io/) environment named `invsr` can be created and activated with:
```
conda create -n invsr python=3.10
conda activate invsr
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
pip install -U xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu121
pip install -e ".[torch]"
pip install -r requirements.txt
```
## Applications
### :point_right: Real-world Image Super-resolution
[<img src="assets/real-7.png" height="280"/>](https://imgsli.com/MzI2MTU5) [<img src="assets/real-1.png" height="280"/>](https://imgsli.com/MzI2MTUx) [<img src="assets/real-2.png" height="280"/>](https://imgsli.com/MzI2MTUy) <!--[<img src="assets/real-3.png" height="256"/>](https://imgsli.com/MzI2MTUz)-->
[<img src="assets/real-4.png" height="430"/>](https://imgsli.com/MzI2MTU0) [<img src="assets/real-6.png" height="430"/>](https://imgsli.com/MzI2MTU3) [<img src="assets/real-5.png" height="430"/>](https://imgsli.com/MzI2MTU1)
### :point_right: General Image Enhancement
[<img src="assets/enhance-1.png" height="294"/>](https://imgsli.com/MzI2MTYw) [<img src="assets/enhance-2.png" height="294"/>](https://imgsli.com/MzI2MTYy)
[<img src="assets/enhance-3.png" height="247.5"/>](https://imgsli.com/MzI2MjAx) [<img src="assets/enhance-4.png" height="247.5"/>](https://imgsli.com/MzI2MjAz) [<img src="assets/enhance-5.png" height="247.5"/>](https://imgsli.com/MzI2MjA0)
### :point_right: AIGC Image Enhancement
[<img src="assets/sdxl-1.png" height="324"/>](https://imgsli.com/MzI2MjQy) [<img src="assets/sdxl-2.png" height="324"/>](https://imgsli.com/MzI2MjQ1) [<img src="assets/sdxl-3.png" height="324"/>](https://imgsli.com/MzI2MjQ3)
[<img src="assets/flux-1.png" height="324"/>](https://imgsli.com/MzI2MjQ5) [<img src="assets/flux-2.png" height="324"/>](https://imgsli.com/MzI2MjUw) [<img src="assets/flux-3.png" height="324"/>](https://imgsli.com/MzI2MjUx)
<!--## Online Demo-->
<!--You can try our method through an online demo:-->
<!--```-->
<!--python app.py-->
<!--```-->
## Inference
### :rocket: Fast testing
```
python inference_invsr.py -i [image folder/image path] -o [result folder] --num_steps 1
```
1. This script automatically downloads the pre-trained [noise predictor](https://huggingface.co/OAOA/InvSR/tree/main) and [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo/tree/main). If you have pre-downloaded them manually, please specify their locations via ``--started_ckpt_path`` and ``--sd_path``.
2. You can freely adjust the number of sampling steps via ``--num_steps``; see the sketch below for scripting the command.
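For batch experiments, the same command can also be driven from Python. Below is a minimal sketch, assuming hypothetical local paths for the pre-downloaded weights; only the flags documented above are used.
```
import subprocess

# The input folder reuses a path shipped in this repo; the commented-out
# checkpoint locations are hypothetical placeholders, not released filenames.
cmd = [
    "python", "inference_invsr.py",
    "-i", "testdata/RealSet80",   # image folder or single image path
    "-o", "results/RealSet80",    # result folder
    "--num_steps", "1",           # number of sampling steps (1 to 5)
    # "--started_ckpt_path", "weights/noise_predictor.pth",  # hypothetical path
    # "--sd_path", "weights/sd-turbo",                       # hypothetical path
]
subprocess.run(cmd, check=True)
```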
### :airplane: Reproducing our paper results
+ Synthetic dataset of ImageNet-Test: [Google Drive](https://drive.google.com/file/d/1PRGrujx3OFilgJ7I6nW7ETIR00wlAl2m/view?usp=sharing).
+ Real data for image super-resolution: [RealSRV3](https://github.com/csjcai/RealSR) | [RealSet80](testdata/RealSet80)
+ To reproduce the quantitative results on ImageNet-Test and RealSRV3, please add the color-fixing option ``--color_fix wavelet``.
## Training
### :turtle: Preparing stage
1. Download the finetuned LPIPS model from this [link](https://huggingface.co/OAOA/InvSR/resolve/main/vgg16_sdturbo_lpips.pth?download=true) and put it in the "weights" folder.
2. Prepare the [config](configs/sd-turbo-sr-ldis.yaml) file:
+ SD-Turbo path: configs.sd_pipe.params.cache_dir.
+ Training data path: data.train.params.data_source.
+ Validation data path: data.val.params.dir_path (low-quality images) and data.val.params.extra_dir_path (high-quality images).
+ Batch size: configs.train.batch and configs.train.microbatch (effective batch size = microbatch × #GPUs × num_grad_accumulation; see the sketch below).
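As a quick sanity check of the batch-size relation above (the numbers are illustrative, not recommended settings):
```
# Illustrative values only: 4 GPUs, micro-batch of 4, 2 gradient-accumulation steps.
num_gpus = 4
microbatch = 4
num_grad_accumulation = 2

# Effective batch size seen by each optimizer step.
total_batch = microbatch * num_gpus * num_grad_accumulation
print(total_batch)  # 32
```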
### :dolphin: Begin training
```
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 --nnodes=1 main.py --save_dir [Logging Folder]
```
### :whale: Resume from interruption
```
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 --nnodes=1 main.py --save_dir [Logging Folder] --resume save_dir/ckpts/model_xx.pth
```
## License
This project is licensed under [NTU S-Lab License 1.0](LICENSE). Redistribution and use should follow this license.
## Acknowledgement
This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR) and [diffusers](https://github.com/huggingface/diffusers). Thanks for their awesome work.
### Contact
If you have any questions, please feel free to contact me via `zsyzam@gmail.com`.

BIN assets/.DS_Store (vendored, new)
BIN assets/enhance-1.png (new; 4.0 MiB)
BIN assets/enhance-2.png (new; 763 KiB)
BIN assets/enhance-3.png (new; 1.7 MiB)
BIN assets/enhance-4.png (new; 2.5 MiB)
BIN assets/enhance-5.png (new; 2.6 MiB)
BIN assets/flux-1.png (new; 2.6 MiB)
BIN assets/flux-2.png (new; 1.8 MiB)
BIN assets/flux-3.png (new; 3.4 MiB)
BIN assets/framework.png (new; 419 KiB)
BIN assets/real-1.png (new; 3.2 MiB)
BIN assets/real-2.png (new; 1.9 MiB)
BIN assets/real-3.png (new; 1.9 MiB)
BIN assets/real-4.png (new; 1.9 MiB)
BIN assets/real-5.png (new; 1.4 MiB)
BIN assets/real-6.png (new; 2.1 MiB)
BIN assets/real-7.png (new; 2.6 MiB)
BIN assets/sdxl-1.png (new; 2.0 MiB)
BIN assets/sdxl-2.png (new; 2.4 MiB)
BIN assets/sdxl-3.png (new; 2.9 MiB)
BIN basicsr/.DS_Store (vendored, new)

4
basicsr/__init__.py Normal file

@@ -0,0 +1,4 @@
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .data import *
from .utils import *

BIN basicsr/data/.DS_Store (vendored, new)

101
basicsr/data/__init__.py Normal file

@@ -0,0 +1,101 @@
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY
__all__ = ['build_dataset', 'build_dataloader']
# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
def build_dataset(dataset_opt):
"""Build dataset from options.
Args:
dataset_opt (dict): Configuration for dataset. It must contain:
name (str): Dataset name.
type (str): Dataset type.
"""
dataset_opt = deepcopy(dataset_opt)
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
logger = get_root_logger()
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
return dataset
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
"""Build dataloader.
Args:
dataset (torch.utils.data.Dataset): Dataset.
dataset_opt (dict): Dataset options. It contains the following keys:
phase (str): 'train' or 'val'.
num_worker_per_gpu (int): Number of workers for each GPU.
batch_size_per_gpu (int): Training batch size for each GPU.
num_gpu (int): Number of GPUs. Used only in the train phase.
Default: 1.
dist (bool): Whether in distributed training. Used only in the train
phase. Default: False.
sampler (torch.utils.data.sampler): Data sampler. Default: None.
seed (int | None): Seed. Default: None
"""
phase = dataset_opt['phase']
rank, _ = get_dist_info()
if phase == 'train':
if dist: # distributed training
batch_size = dataset_opt['batch_size_per_gpu']
num_workers = dataset_opt['num_worker_per_gpu']
else: # non-distributed training
multiplier = 1 if num_gpu == 0 else num_gpu
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
dataloader_args = dict(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
drop_last=True)
if sampler is None:
dataloader_args['shuffle'] = True
dataloader_args['worker_init_fn'] = partial(
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
elif phase in ['val', 'test']: # validation
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
else:
raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
prefetch_mode = dataset_opt.get('prefetch_mode')
if prefetch_mode == 'cpu': # CPUPrefetcher
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
logger = get_root_logger()
logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
else:
# prefetch_mode=None: Normal dataloader
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
return torch.utils.data.DataLoader(**dataloader_args)
def worker_init_fn(worker_id, num_workers, rank, seed):
# Set the worker seed to num_workers * rank + worker_id + seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
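# A usage sketch (not part of the original file): wiring the two helpers together.
# The dataset `type` must name a class registered in DATASET_REGISTRY;
# 'FFHQDataset' (defined elsewhere in this package) is used purely for
# illustration, and 'datasets/ffhq' is a placeholder path.
if __name__ == '__main__':
    demo_opt = {
        'name': 'demo',
        'type': 'FFHQDataset',
        'phase': 'train',
        'dataroot_gt': 'datasets/ffhq',  # placeholder path
        'io_backend': {'type': 'disk'},
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'use_hflip': True,
        'batch_size_per_gpu': 4,
        'num_worker_per_gpu': 2,
    }
    demo_set = build_dataset(demo_opt)
    demo_loader = build_dataloader(demo_set, demo_opt, num_gpu=1, dist=False, seed=0)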


48
basicsr/data/data_sampler.py Normal file

@@ -0,0 +1,48 @@
import math
import torch
from torch.utils.data.sampler import Sampler
class EnlargedSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
Modified from torch.utils.data.distributed.DistributedSampler
Support enlarging the dataset for iteration-based training, for saving
time when restart the dataloader after each epoch
Args:
dataset (torch.utils.data.Dataset): Dataset used for sampling.
num_replicas (int | None): Number of processes participating in
the training. It is usually the world_size.
rank (int | None): Rank of the current process within num_replicas.
ratio (int): Enlarging ratio. Default: 1.
"""
def __init__(self, dataset, num_replicas, rank, ratio=1):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(self.total_size, generator=g).tolist()
dataset_size = len(self.dataset)
indices = [v % dataset_size for v in indices]
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
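# A usage sketch (not part of the original file): single-process sampling with
# the dataset virtually enlarged 2x, so one pass of the sampler covers the data twice.
if __name__ == '__main__':
    from torch.utils.data import TensorDataset

    demo_set = TensorDataset(torch.arange(10))
    sampler = EnlargedSampler(demo_set, num_replicas=1, rank=0, ratio=2)
    sampler.set_epoch(0)  # deterministically reseeds the shuffle
    indices = list(iter(sampler))
    assert len(indices) == 20  # ceil(10 * 2 / 1)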

315
basicsr/data/data_util.py Normal file

@@ -0,0 +1,315 @@
import cv2
import numpy as np
import torch
from os import path as osp
from torch.nn import functional as F
from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
"""Read a sequence of images from a given folder path.
Args:
path (list[str] | str): List of image paths or image folder path.
require_mod_crop (bool): Require mod crop for each image.
Default: False.
scale (int): Scale factor for mod_crop. Default: 1.
return_imgname(bool): Whether return image names. Default False.
Returns:
Tensor: size (t, c, h, w), RGB, [0, 1].
list[str]: Returned image name list.
"""
if isinstance(path, list):
img_paths = path
else:
img_paths = sorted(list(scandir(path, full_path=True)))
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
if require_mod_crop:
imgs = [mod_crop(img, scale) for img in imgs]
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
imgs = torch.stack(imgs, dim=0)
if return_imgname:
imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
return imgs, imgnames
else:
return imgs
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
"""Generate an index list for reading `num_frames` frames from a sequence
of images.
Args:
crt_idx (int): Current center index.
max_frame_num (int): Max number of the sequence of images (from 1).
num_frames (int): Reading num_frames frames.
padding (str): Padding mode, one of
'replicate' | 'reflection' | 'reflection_circle' | 'circle'
Examples: current_idx = 0, num_frames = 5
The generated frame indices under different padding mode:
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
reflection_circle: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
list[int]: A list of indices.
"""
assert num_frames % 2 == 1, 'num_frames should be an odd number.'
assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
max_frame_num = max_frame_num - 1 # start from 0
num_pad = num_frames // 2
indices = []
for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
if i < 0:
if padding == 'replicate':
pad_idx = 0
elif padding == 'reflection':
pad_idx = -i
elif padding == 'reflection_circle':
pad_idx = crt_idx + num_pad - i
else:
pad_idx = num_frames + i
elif i > max_frame_num:
if padding == 'replicate':
pad_idx = max_frame_num
elif padding == 'reflection':
pad_idx = max_frame_num * 2 - i
elif padding == 'reflection_circle':
pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
else:
pad_idx = i - num_frames
else:
pad_idx = i
indices.append(pad_idx)
return indices
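# Example (matching the docstring above), at the start of a 100-frame sequence:
#   generate_frame_indices(0, max_frame_num=100, num_frames=5, padding='replicate')  -> [0, 0, 0, 1, 2]
#   generate_frame_indices(0, max_frame_num=100, num_frames=5, padding='reflection') -> [2, 1, 0, 1, 2]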
def paired_paths_from_lmdb(folders, keys):
"""Generate paired paths from lmdb files.
Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
::
lq.lmdb
├── data.mdb
├── lock.mdb
├── meta_info.txt
The data.mdb and lock.mdb are standard lmdb files and you can refer to
https://lmdb.readthedocs.io/en/release/ for more details.
The meta_info.txt is a specified txt file to record the meta information
of our datasets. It will be automatically created when preparing
datasets by our provided dataset tools.
Each line in the txt file records
1)image name (with extension),
2)image shape,
3)compression level, separated by a white space.
Example: `baboon.png (120,125,3) 1`
We use the image name without extension as the lmdb key.
Note that we use the same key for the corresponding lq and gt images.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
Note that this key is different from lmdb keys.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
input_folder, gt_folder = folders
input_key, gt_key = keys
if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
f'formats. But received {input_key}: {input_folder}; '
f'{gt_key}: {gt_folder}')
# ensure that the two meta_info files are the same
with open(osp.join(input_folder, 'meta_info.txt')) as fin:
input_lmdb_keys = [line.split('.')[0] for line in fin]
with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
gt_lmdb_keys = [line.split('.')[0] for line in fin]
if set(input_lmdb_keys) != set(gt_lmdb_keys):
raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
else:
paths = []
for lmdb_key in sorted(input_lmdb_keys):
paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
return paths
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
"""Generate paired paths from an meta information file.
Each line in the meta information file contains the image names and
image shape (usually for gt), separated by a white space.
Example of an meta information file:
```
0001_s001.png (480,480,3)
0001_s002.png (480,480,3)
```
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
meta_info_file (str): Path to the meta information file.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
input_folder, gt_folder = folders
input_key, gt_key = keys
with open(meta_info_file, 'r') as fin:
gt_names = [line.strip().split(' ')[0] for line in fin]
paths = []
for gt_name in gt_names:
basename, ext = osp.splitext(osp.basename(gt_name))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = osp.join(input_folder, input_name)
gt_path = osp.join(gt_folder, gt_name)
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
return paths
def paired_paths_from_folder(folders, keys, filename_tmpl):
"""Generate paired paths from folders.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
input_folder, gt_folder = folders
input_key, gt_key = keys
input_paths = list(scandir(input_folder))
gt_paths = list(scandir(gt_folder))
assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
f'{len(input_paths)}, {len(gt_paths)}.')
paths = []
for gt_path in gt_paths:
basename, ext = osp.splitext(osp.basename(gt_path))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = osp.join(input_folder, input_name)
assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
gt_path = osp.join(gt_folder, gt_path)
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
return paths
def paths_from_folder(folder):
"""Generate paths from folder.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""
paths = list(scandir(folder))
paths = [osp.join(folder, path) for path in paths]
return paths
def paths_from_lmdb(folder):
"""Generate paths from lmdb.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""
if not folder.endswith('.lmdb'):
raise ValueError(f'Folder {folder} should be in lmdb format.')
with open(osp.join(folder, 'meta_info.txt')) as fin:
paths = [line.split('.')[0] for line in fin]
return paths
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
"""Generate Gaussian kernel used in `duf_downsample`.
Args:
kernel_size (int): Kernel size. Default: 13.
sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
Returns:
np.array: The Gaussian kernel.
"""
from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters was removed in recent SciPy
kernel = np.zeros((kernel_size, kernel_size))
# set element at the middle to one, a dirac delta
kernel[kernel_size // 2, kernel_size // 2] = 1
# gaussian-smooth the dirac, resulting in a gaussian filter
return gaussian_filter(kernel, sigma)
def duf_downsample(x, kernel_size=13, scale=4):
"""Downsamping with Gaussian kernel used in the DUF official code.
Args:
x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
kernel_size (int): Kernel size. Default: 13.
scale (int): Downsampling factor. Supported scale: (2, 3, 4).
Default: 4.
Returns:
Tensor: DUF downsampled frames.
"""
assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
squeeze_flag = False
if x.ndim == 4:
squeeze_flag = True
x = x.unsqueeze(0)
b, t, c, h, w = x.size()
x = x.view(-1, 1, h, w)
pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
x = F.conv2d(x, gaussian_filter, stride=scale)
x = x[:, :, 2:-2, 2:-2]
x = x.view(b, t, c, x.size(2), x.size(3))
if squeeze_flag:
x = x.squeeze(0)
return x
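# A usage sketch (not part of the original file): 4x DUF downsampling of a
# 5-frame clip given as (t, c, h, w); 4-D inputs are batched internally and
# squeezed back afterwards.
if __name__ == '__main__':
    frames = torch.rand(5, 3, 64, 64)
    lr = duf_downsample(frames, kernel_size=13, scale=4)
    print(lr.shape)  # torch.Size([5, 3, 16, 16])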

765
basicsr/data/degradations.py Normal file

@@ -0,0 +1,765 @@
import cv2
import math
import numpy as np
import random
import torch
from scipy import special
from scipy.stats import multivariate_normal
# from torchvision.transforms.functional_tensor import rgb_to_grayscale
from torchvision.transforms.functional import rgb_to_grayscale
# -------------------------------------------------------------------- #
# --------------------------- blur kernels --------------------------- #
# -------------------------------------------------------------------- #
# --------------------------- util functions --------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
"""Calculate the rotated sigma matrix (two dimensional matrix).
Args:
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
Returns:
ndarray: Rotated sigma matrix.
"""
d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
def mesh_grid(kernel_size):
"""Generate the mesh grid, centering at zero.
Args:
kernel_size (int):
Returns:
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
xx (ndarray): with the shape (kernel_size, kernel_size)
yy (ndarray): with the shape (kernel_size, kernel_size)
"""
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
xx, yy = np.meshgrid(ax, ax)
xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
1))).reshape(kernel_size, kernel_size, 2)
return xy, xx, yy
def pdf2(sigma_matrix, grid):
"""Calculate PDF of the bivariate Gaussian distribution.
Args:
sigma_matrix (ndarray): with the shape (2, 2)
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
kernel (ndarray): un-normalized kernel.
"""
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
return kernel
def cdf2(d_matrix, grid):
"""Calculate the CDF of the standard bivariate Gaussian distribution.
Used in skewed Gaussian distribution.
Args:
d_matrix (ndarray): skew matrix.
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
cdf (ndarray): skewed cdf.
"""
rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
grid = np.dot(grid, d_matrix)
cdf = rv.cdf(grid)
return cdf
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
isotropic (bool):
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
kernel = pdf2(sigma_matrix, grid)
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
"""Generate a bivariate generalized Gaussian kernel.
``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
"""Generate a plateau-like anisotropic kernel.
1 / (1+x^(beta))
Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_Gaussian(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation_range (tuple): [-math.pi, math.pi]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_generalized_Gaussian(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate generalized Gaussian kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation_range (tuple): [-math.pi, math.pi]
beta_range (tuple): [0.5, 8]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
# assume beta_range[0] < 1 < beta_range[1]
if np.random.uniform() < 0.5:
beta = np.random.uniform(beta_range[0], 1)
else:
beta = np.random.uniform(1, beta_range[1])
kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_plateau(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate plateau kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation_range (tuple): [-math.pi/2, math.pi/2]
beta_range (tuple): [1, 4]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
# TODO: this may be not proper
if np.random.uniform() < 0.5:
beta = np.random.uniform(beta_range[0], 1)
else:
beta = np.random.uniform(1, beta_range[1])
kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_mixed_kernels(kernel_list,
kernel_prob,
kernel_size=21,
sigma_x_range=(0.6, 5),
sigma_y_range=(0.6, 5),
rotation_range=(-math.pi, math.pi),
betag_range=(0.5, 8),
betap_range=(0.5, 8),
noise_range=None):
"""Randomly generate mixed kernels.
Args:
kernel_list (tuple): a list name of kernel types,
support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
'plateau_aniso']
kernel_prob (tuple): corresponding kernel probability for each
kernel type
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation_range (tuple): [-math.pi, math.pi]
betag_range (tuple): [0.5, 8]
betap_range (tuple): [0.5, 8]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
kernel_type = random.choices(kernel_list, kernel_prob)[0]
if kernel_type == 'iso':
kernel = random_bivariate_Gaussian(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
elif kernel_type == 'aniso':
kernel = random_bivariate_Gaussian(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
elif kernel_type == 'generalized_iso':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
betag_range,
noise_range=noise_range,
isotropic=True)
elif kernel_type == 'generalized_aniso':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
betag_range,
noise_range=noise_range,
isotropic=False)
elif kernel_type == 'plateau_iso':
kernel = random_bivariate_plateau(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
elif kernel_type == 'plateau_aniso':
kernel = random_bivariate_plateau(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
return kernel
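# Example: draw a single 21x21 blur kernel, picking an isotropic or anisotropic
# Gaussian with equal probability (kernel_prob gives the relative weights):
#   kernel = random_mixed_kernels(['iso', 'aniso'], [0.5, 0.5], kernel_size=21,
#                                 sigma_x_range=(0.6, 5), sigma_y_range=(0.6, 5),
#                                 rotation_range=(-math.pi, math.pi))
# Every branch returns a normalized kernel, so kernel.sum() is ~1.0.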
np.seterr(divide='ignore', invalid='ignore')
def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
"""2D sinc filter
Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
Args:
cutoff (float): cutoff frequency in radians (pi is max)
kernel_size (int): horizontal and vertical size, must be odd.
pad_to (int): pad kernel size to desired size, must be odd or zero.
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
kernel = np.fromfunction(
lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
kernel = kernel / np.sum(kernel)
if pad_to > kernel_size:
pad_size = (pad_to - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
return kernel
# ------------------------------------------------------------- #
# --------------------------- noise --------------------------- #
# ------------------------------------------------------------- #
# ----------------------- Gaussian Noise ----------------------- #
def generate_gaussian_noise(img, sigma=10, gray_noise=False):
"""Generate Gaussian noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
sigma (float): Noise scale (measured in range 255). Default: 10.
Returns:
(Numpy array): Generated noise, shape (h, w, c), float32.
"""
if gray_noise:
noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
else:
noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
return noise
def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
"""Add Gaussian noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
sigma (float): Noise scale (measured in range 255). Default: 10.
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
noise = generate_gaussian_noise(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
"""Add Gaussian noise (PyTorch version).
Args:
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
scale (float | Tensor): Noise scale. Default: 1.0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
b, _, h, w = img.size()
if not isinstance(sigma, (float, int)):
sigma = sigma.view(img.size(0), 1, 1, 1)
if isinstance(gray_noise, (float, int)):
cal_gray_noise = gray_noise > 0
else:
gray_noise = gray_noise.view(b, 1, 1, 1)
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
noise_gray = noise_gray.view(b, 1, h, w)
# always calculate color noise
noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
return noise
def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
"""Add Gaussian noise (PyTorch version).
Args:
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ----------------------- Random Gaussian Noise ----------------------- #
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
if np.random.uniform() < gray_prob:
gray_noise = True
else:
gray_noise = False
return generate_gaussian_noise(img, sigma, gray_noise)
def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
sigma = torch.rand(
img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_gaussian_noise_pt(img, sigma, gray_noise)
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
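# Example: corrupt a (b, c, h, w) batch in [0, 1] with a per-image noise level
# drawn from [1, 30] (on the 0-255 scale) and a 40% chance of grayscale noise:
#   noisy = random_add_gaussian_noise_pt(torch.rand(4, 3, 32, 32),
#                                        sigma_range=(1, 30), gray_prob=0.4, clip=True)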
# ----------------------- Poisson (Shot) Noise ----------------------- #
def generate_poisson_noise(img, scale=1.0, gray_noise=False):
"""Generate poisson noise.
Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
scale (float): Noise scale. Default: 1.0.
gray_noise (bool): Whether generate gray noise. Default: False.
Returns:
(Numpy array): Generated noise, shape (h, w, c), float32.
"""
if gray_noise:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# round and clip image for counting vals correctly
img = np.clip((img * 255.0).round(), 0, 255) / 255.
vals = len(np.unique(img))
vals = 2**np.ceil(np.log2(vals))
out = np.float32(np.random.poisson(img * vals) / float(vals))
noise = out - img
if gray_noise:
noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
return noise * scale
def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
"""Add poisson noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
scale (float): Noise scale. Default: 1.0.
gray_noise (bool): Whether generate gray noise. Default: False.
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
noise = generate_poisson_noise(img, scale, gray_noise)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
"""Generate a batch of poisson noise (PyTorch version)
Args:
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
Default: 1.0.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
(Tensor): Generated noise, shape (b, c, h, w), float32.
"""
b, _, h, w = img.size()
if isinstance(gray_noise, (float, int)):
cal_gray_noise = gray_noise > 0
else:
gray_noise = gray_noise.view(b, 1, 1, 1)
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
img_gray = rgb_to_grayscale(img, num_output_channels=1)
# round and clip image for counting vals correctly
img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img_gray * vals) / vals
noise_gray = out - img_gray
noise_gray = noise_gray.expand(b, 3, h, w)
# always calculate color noise
# round and clip image for counting vals correctly
img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img * vals) / vals
noise = out - img
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
if not isinstance(scale, (float, int)):
scale = scale.view(b, 1, 1, 1)
return noise * scale
def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
"""Add poisson noise to a batch of images (PyTorch version).
Args:
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
Default: 1.0.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
noise = generate_poisson_noise_pt(img, scale, gray_noise)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ----------------------- Random Poisson (Shot) Noise ----------------------- #
def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
scale = np.random.uniform(scale_range[0], scale_range[1])
if np.random.uniform() < gray_prob:
gray_noise = True
else:
gray_noise = False
return generate_poisson_noise(img, scale, gray_noise)
def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_poisson_noise(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
scale = torch.rand(
img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_poisson_noise_pt(img, scale, gray_noise)
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ------------------------------------------------------------------------ #
# --------------------------- JPEG compression --------------------------- #
# ------------------------------------------------------------------------ #
def add_jpg_compression(img, quality=90):
"""Add JPG compression artifacts.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
quality (float): JPG compression quality. 0 for lowest quality, 100 for
best quality. Default: 90.
Returns:
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
float32.
"""
img = np.clip(img, 0, 1)
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
_, encimg = cv2.imencode('.jpg', img * 255., encode_param)
img = np.float32(cv2.imdecode(encimg, 1)) / 255.
return img
def random_add_jpg_compression(img, quality_range=(90, 100)):
"""Randomly add JPG compression artifacts.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
quality_range (tuple[float] | list[float]): JPG compression quality
range. 0 for lowest quality, 100 for best quality.
Default: (90, 100).
Returns:
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
float32.
"""
quality = np.random.uniform(quality_range[0], quality_range[1])
return add_jpg_compression(img, quality)
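# A usage sketch (not part of the original file): produce a degraded copy of a
# float32 BGR image in [0, 1] via JPEG compression at a random quality in [30, 60].
if __name__ == '__main__':
    demo_img = np.random.rand(64, 64, 3).astype(np.float32)
    degraded = random_add_jpg_compression(demo_img, quality_range=(30, 60))
    print(degraded.shape, degraded.dtype)  # (64, 64, 3) float32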

80
basicsr/data/ffhq_dataset.py Normal file

@@ -0,0 +1,80 @@
import random
import time
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):
"""FFHQ dataset for StyleGAN.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
io_backend (dict): IO backend type and other kwarg.
mean (list | tuple): Image mean.
std (list | tuple): Image std.
use_hflip (bool): Whether to horizontally flip.
"""
def __init__(self, opt):
super(FFHQDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt']
self.mean = opt['mean']
self.std = opt['std']
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = self.gt_folder
if not self.gt_folder.endswith('.lmdb'):
raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
# FFHQ has 70000 images in total
self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load gt image
gt_path = self.paths[index]
# avoid errors caused by high latency in reading files
retry = 3
while retry > 0:
try:
img_bytes = self.file_client.get(gt_path)
except Exception as e:
logger = get_root_logger()
logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
# randomly switch to another file to read
index = random.randint(0, self.__len__() - 1)  # randint is inclusive on both ends
gt_path = self.paths[index]
time.sleep(1) # sleep 1s for occasional server congestion
else:
break
finally:
retry -= 1
img_gt = imfrombytes(img_bytes, float32=True)
# random horizontal flip
img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
# normalize
normalize(img_gt, self.mean, self.std, inplace=True)
return {'gt': img_gt, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)
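# A usage sketch (not part of the original file): a disk-backend configuration
# using the keys documented above. 'datasets/ffhq' is a placeholder; images are
# expected to be named 00000000.png ... 00069999.png.
if __name__ == '__main__':
    demo_opt = {
        'dataroot_gt': 'datasets/ffhq',
        'io_backend': {'type': 'disk'},
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'use_hflip': True,
    }
    ffhq = FFHQDataset(demo_opt)
    sample = ffhq[0]  # {'gt': normalized CHW float tensor, 'gt_path': str}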

File diff suppressed because it is too large.

4
basicsr/data/meta_info/meta_info_REDS4_test_GT.txt Normal file

@@ -0,0 +1,4 @@
000 100 (720,1280,3)
011 100 (720,1280,3)
015 100 (720,1280,3)
020 100 (720,1280,3)

270
basicsr/data/meta_info/meta_info_REDS_GT.txt Normal file

@@ -0,0 +1,270 @@
000 100 (720,1280,3)
001 100 (720,1280,3)
002 100 (720,1280,3)
003 100 (720,1280,3)
004 100 (720,1280,3)
005 100 (720,1280,3)
006 100 (720,1280,3)
007 100 (720,1280,3)
008 100 (720,1280,3)
009 100 (720,1280,3)
010 100 (720,1280,3)
011 100 (720,1280,3)
012 100 (720,1280,3)
013 100 (720,1280,3)
014 100 (720,1280,3)
015 100 (720,1280,3)
016 100 (720,1280,3)
017 100 (720,1280,3)
018 100 (720,1280,3)
019 100 (720,1280,3)
020 100 (720,1280,3)
021 100 (720,1280,3)
022 100 (720,1280,3)
023 100 (720,1280,3)
024 100 (720,1280,3)
025 100 (720,1280,3)
026 100 (720,1280,3)
027 100 (720,1280,3)
028 100 (720,1280,3)
029 100 (720,1280,3)
030 100 (720,1280,3)
031 100 (720,1280,3)
032 100 (720,1280,3)
033 100 (720,1280,3)
034 100 (720,1280,3)
035 100 (720,1280,3)
036 100 (720,1280,3)
037 100 (720,1280,3)
038 100 (720,1280,3)
039 100 (720,1280,3)
040 100 (720,1280,3)
041 100 (720,1280,3)
042 100 (720,1280,3)
043 100 (720,1280,3)
044 100 (720,1280,3)
045 100 (720,1280,3)
046 100 (720,1280,3)
047 100 (720,1280,3)
048 100 (720,1280,3)
049 100 (720,1280,3)
050 100 (720,1280,3)
051 100 (720,1280,3)
052 100 (720,1280,3)
053 100 (720,1280,3)
054 100 (720,1280,3)
055 100 (720,1280,3)
056 100 (720,1280,3)
057 100 (720,1280,3)
058 100 (720,1280,3)
059 100 (720,1280,3)
060 100 (720,1280,3)
061 100 (720,1280,3)
062 100 (720,1280,3)
063 100 (720,1280,3)
064 100 (720,1280,3)
065 100 (720,1280,3)
066 100 (720,1280,3)
067 100 (720,1280,3)
068 100 (720,1280,3)
069 100 (720,1280,3)
070 100 (720,1280,3)
071 100 (720,1280,3)
072 100 (720,1280,3)
073 100 (720,1280,3)
074 100 (720,1280,3)
075 100 (720,1280,3)
076 100 (720,1280,3)
077 100 (720,1280,3)
078 100 (720,1280,3)
079 100 (720,1280,3)
080 100 (720,1280,3)
081 100 (720,1280,3)
082 100 (720,1280,3)
083 100 (720,1280,3)
084 100 (720,1280,3)
085 100 (720,1280,3)
086 100 (720,1280,3)
087 100 (720,1280,3)
088 100 (720,1280,3)
089 100 (720,1280,3)
090 100 (720,1280,3)
091 100 (720,1280,3)
092 100 (720,1280,3)
093 100 (720,1280,3)
094 100 (720,1280,3)
095 100 (720,1280,3)
096 100 (720,1280,3)
097 100 (720,1280,3)
098 100 (720,1280,3)
099 100 (720,1280,3)
100 100 (720,1280,3)
101 100 (720,1280,3)
102 100 (720,1280,3)
103 100 (720,1280,3)
104 100 (720,1280,3)
105 100 (720,1280,3)
106 100 (720,1280,3)
107 100 (720,1280,3)
108 100 (720,1280,3)
109 100 (720,1280,3)
110 100 (720,1280,3)
111 100 (720,1280,3)
112 100 (720,1280,3)
113 100 (720,1280,3)
114 100 (720,1280,3)
115 100 (720,1280,3)
116 100 (720,1280,3)
117 100 (720,1280,3)
118 100 (720,1280,3)
119 100 (720,1280,3)
120 100 (720,1280,3)
121 100 (720,1280,3)
122 100 (720,1280,3)
123 100 (720,1280,3)
124 100 (720,1280,3)
125 100 (720,1280,3)
126 100 (720,1280,3)
127 100 (720,1280,3)
128 100 (720,1280,3)
129 100 (720,1280,3)
130 100 (720,1280,3)
131 100 (720,1280,3)
132 100 (720,1280,3)
133 100 (720,1280,3)
134 100 (720,1280,3)
135 100 (720,1280,3)
136 100 (720,1280,3)
137 100 (720,1280,3)
138 100 (720,1280,3)
139 100 (720,1280,3)
140 100 (720,1280,3)
141 100 (720,1280,3)
142 100 (720,1280,3)
143 100 (720,1280,3)
144 100 (720,1280,3)
145 100 (720,1280,3)
146 100 (720,1280,3)
147 100 (720,1280,3)
148 100 (720,1280,3)
149 100 (720,1280,3)
150 100 (720,1280,3)
151 100 (720,1280,3)
152 100 (720,1280,3)
153 100 (720,1280,3)
154 100 (720,1280,3)
155 100 (720,1280,3)
156 100 (720,1280,3)
157 100 (720,1280,3)
158 100 (720,1280,3)
159 100 (720,1280,3)
160 100 (720,1280,3)
161 100 (720,1280,3)
162 100 (720,1280,3)
163 100 (720,1280,3)
164 100 (720,1280,3)
165 100 (720,1280,3)
166 100 (720,1280,3)
167 100 (720,1280,3)
168 100 (720,1280,3)
169 100 (720,1280,3)
170 100 (720,1280,3)
171 100 (720,1280,3)
172 100 (720,1280,3)
173 100 (720,1280,3)
174 100 (720,1280,3)
175 100 (720,1280,3)
176 100 (720,1280,3)
177 100 (720,1280,3)
178 100 (720,1280,3)
179 100 (720,1280,3)
180 100 (720,1280,3)
181 100 (720,1280,3)
182 100 (720,1280,3)
183 100 (720,1280,3)
184 100 (720,1280,3)
185 100 (720,1280,3)
186 100 (720,1280,3)
187 100 (720,1280,3)
188 100 (720,1280,3)
189 100 (720,1280,3)
190 100 (720,1280,3)
191 100 (720,1280,3)
192 100 (720,1280,3)
193 100 (720,1280,3)
194 100 (720,1280,3)
195 100 (720,1280,3)
196 100 (720,1280,3)
197 100 (720,1280,3)
198 100 (720,1280,3)
199 100 (720,1280,3)
200 100 (720,1280,3)
201 100 (720,1280,3)
202 100 (720,1280,3)
203 100 (720,1280,3)
204 100 (720,1280,3)
205 100 (720,1280,3)
206 100 (720,1280,3)
207 100 (720,1280,3)
208 100 (720,1280,3)
209 100 (720,1280,3)
210 100 (720,1280,3)
211 100 (720,1280,3)
212 100 (720,1280,3)
213 100 (720,1280,3)
214 100 (720,1280,3)
215 100 (720,1280,3)
216 100 (720,1280,3)
217 100 (720,1280,3)
218 100 (720,1280,3)
219 100 (720,1280,3)
220 100 (720,1280,3)
221 100 (720,1280,3)
222 100 (720,1280,3)
223 100 (720,1280,3)
224 100 (720,1280,3)
225 100 (720,1280,3)
226 100 (720,1280,3)
227 100 (720,1280,3)
228 100 (720,1280,3)
229 100 (720,1280,3)
230 100 (720,1280,3)
231 100 (720,1280,3)
232 100 (720,1280,3)
233 100 (720,1280,3)
234 100 (720,1280,3)
235 100 (720,1280,3)
236 100 (720,1280,3)
237 100 (720,1280,3)
238 100 (720,1280,3)
239 100 (720,1280,3)
240 100 (720,1280,3)
241 100 (720,1280,3)
242 100 (720,1280,3)
243 100 (720,1280,3)
244 100 (720,1280,3)
245 100 (720,1280,3)
246 100 (720,1280,3)
247 100 (720,1280,3)
248 100 (720,1280,3)
249 100 (720,1280,3)
250 100 (720,1280,3)
251 100 (720,1280,3)
252 100 (720,1280,3)
253 100 (720,1280,3)
254 100 (720,1280,3)
255 100 (720,1280,3)
256 100 (720,1280,3)
257 100 (720,1280,3)
258 100 (720,1280,3)
259 100 (720,1280,3)
260 100 (720,1280,3)
261 100 (720,1280,3)
262 100 (720,1280,3)
263 100 (720,1280,3)
264 100 (720,1280,3)
265 100 (720,1280,3)
266 100 (720,1280,3)
267 100 (720,1280,3)
268 100 (720,1280,3)
269 100 (720,1280,3)

View File

@@ -0,0 +1,4 @@
240 100 (720,1280,3)
241 100 (720,1280,3)
246 100 (720,1280,3)
257 100 (720,1280,3)

View File

@@ -0,0 +1,30 @@
240 100 (720,1280,3)
241 100 (720,1280,3)
242 100 (720,1280,3)
243 100 (720,1280,3)
244 100 (720,1280,3)
245 100 (720,1280,3)
246 100 (720,1280,3)
247 100 (720,1280,3)
248 100 (720,1280,3)
249 100 (720,1280,3)
250 100 (720,1280,3)
251 100 (720,1280,3)
252 100 (720,1280,3)
253 100 (720,1280,3)
254 100 (720,1280,3)
255 100 (720,1280,3)
256 100 (720,1280,3)
257 100 (720,1280,3)
258 100 (720,1280,3)
259 100 (720,1280,3)
260 100 (720,1280,3)
261 100 (720,1280,3)
262 100 (720,1280,3)
263 100 (720,1280,3)
264 100 (720,1280,3)
265 100 (720,1280,3)
266 100 (720,1280,3)
267 100 (720,1280,3)
268 100 (720,1280,3)
269 100 (720,1280,3)

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,106 @@
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
"""Paired image dataset for image restoration.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
There are three modes:
1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
2. **meta_info_file**: Use meta information file to generate paths. \
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
3. **folder**: Scan folders to generate paths. The rest.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
meta_info_file (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (implemented via vertical flips plus transposing h and w).
            scale (int): Scale factor, which will be added automatically.
            phase (str): 'train' or 'val'.
"""
def __init__(self, opt):
super(PairedImageDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
if 'filename_tmpl' in opt:
self.filename_tmpl = opt['filename_tmpl']
else:
self.filename_tmpl = '{}'
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
self.opt['meta_info_file'], self.filename_tmpl)
else:
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
lq_path = self.paths[index]['lq_path']
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
# color space transform
if 'color' in self.opt and self.opt['color'] == 'y':
img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]
# crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
# TODO: It is better to update the datasets, rather than force to crop
if self.opt['phase'] != 'train':
img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)
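Below is a minimal usage sketch for PairedImageDataset with the disk backend; the dataset roots, batch size, and patch sizes are illustrative assumptions, not values fixed by this repo.

# Hypothetical usage sketch (disk backend; paths and sizes are assumptions).
from torch.utils.data import DataLoader

opt = {
    'phase': 'train',
    'scale': 4,
    'gt_size': 256,
    'use_hflip': True,
    'use_rot': True,
    'io_backend': {'type': 'disk'},
    'dataroot_gt': 'datasets/DIV2K/GT',
    'dataroot_lq': 'datasets/DIV2K/LQ_X4',
}
dataset = PairedImageDataset(opt)
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch['gt'].shape, batch['lq'].shape)  # [16, 3, 256, 256] and [16, 3, 64, 64]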

View File

@@ -0,0 +1,122 @@
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader
class PrefetchGenerator(threading.Thread):
"""A general prefetch generator.
Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
Args:
generator: Python generator.
        num_prefetch_queue (int): Size of the prefetch queue.
"""
def __init__(self, generator, num_prefetch_queue):
threading.Thread.__init__(self)
self.queue = Queue.Queue(num_prefetch_queue)
self.generator = generator
self.daemon = True
self.start()
def run(self):
for item in self.generator:
self.queue.put(item)
self.queue.put(None)
def __next__(self):
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
return self
class PrefetchDataLoader(DataLoader):
"""Prefetch version of dataloader.
Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
TODO:
Need to test on single gpu and ddp (multi-gpu). There is a known issue in
ddp.
Args:
        num_prefetch_queue (int): Size of the prefetch queue.
kwargs (dict): Other arguments for dataloader.
"""
def __init__(self, num_prefetch_queue, **kwargs):
self.num_prefetch_queue = num_prefetch_queue
super(PrefetchDataLoader, self).__init__(**kwargs)
def __iter__(self):
return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
class CPUPrefetcher():
"""CPU prefetcher.
Args:
loader: Dataloader.
"""
def __init__(self, loader):
self.ori_loader = loader
self.loader = iter(loader)
def next(self):
try:
return next(self.loader)
except StopIteration:
return None
def reset(self):
self.loader = iter(self.ori_loader)
class CUDAPrefetcher():
"""CUDA prefetcher.
Reference: https://github.com/NVIDIA/apex/issues/304#
It may consume more GPU memory.
Args:
loader: Dataloader.
opt (dict): Options.
"""
def __init__(self, loader, opt):
self.ori_loader = loader
self.loader = iter(loader)
self.opt = opt
self.stream = torch.cuda.Stream()
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.preload()
def preload(self):
try:
self.batch = next(self.loader) # self.batch is a dict
except StopIteration:
self.batch = None
return None
# put tensors to gpu
with torch.cuda.stream(self.stream):
for k, v in self.batch.items():
if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch
self.preload()
return batch
def reset(self):
self.loader = iter(self.ori_loader)
self.preload()
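A short sketch of how CUDAPrefetcher is typically driven in a training loop; `loader` stands for any PyTorch DataLoader yielding dict batches, and the one-key opt is an assumption (only 'num_gpu' is read above).

# Hypothetical training-loop sketch for CUDAPrefetcher.
prefetcher = CUDAPrefetcher(loader, opt={'num_gpu': 1})
prefetcher.reset()
batch = prefetcher.next()
while batch is not None:
    lq, gt = batch['lq'], batch['gt']  # already on GPU, copied on a side stream
    # ... forward / backward / optimizer step ...
    batch = prefetcher.next()  # returns None once the epoch is exhausted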

View File

@@ -0,0 +1,384 @@
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from pathlib import Path
import albumentations
import torch.nn.functional as F
from torch.utils import data as data
from basicsr.utils import DiffJPEG
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from basicsr.utils.img_process_util import filter2D
from basicsr.data.transforms import paired_random_crop, random_crop
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from utils import util_image
def readline_txt(txt_file):
txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file
out = []
for txt_file_current in txt_file:
with open(txt_file_current, 'r') as ff:
out.extend([x[:-1] for x in ff.readlines()])
return out
@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANDataset(data.Dataset):
"""Dataset used for Real-ESRGAN model:
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It loads gt (Ground-Truth) images, and augments them.
It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed as tensors on GPUs for faster processing.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
Please see more options in the codes.
"""
def __init__(self, opt, mode='training'):
super(RealESRGANDataset, self).__init__()
self.opt = opt
self.file_client = None
self.io_backend_opt = opt['io_backend']
# file client (lmdb io backend)
self.image_paths = []
self.text_paths = []
self.moment_paths = []
if opt.get('data_source', None) is not None:
for ii in range(len(opt['data_source'])):
configs = opt['data_source'].get(f'source{ii+1}')
root_path = Path(configs.root_path)
im_folder = root_path / configs.image_path
im_ext = configs.im_ext
image_stems = sorted([x.stem for x in im_folder.glob(f"*.{im_ext}")])
if configs.get('length', None) is not None:
assert configs.length < len(image_stems)
image_stems = image_stems[:configs.length]
if configs.get("text_path", None) is not None:
text_folder = root_path / configs.text_path
text_stems = [x.stem for x in text_folder.glob("*.txt")]
image_stems = sorted(list(set(image_stems).intersection(set(text_stems))))
self.text_paths.extend([str(text_folder / f"{x}.txt") for x in image_stems])
else:
self.text_paths.extend([None, ] * len(image_stems))
self.image_paths.extend([str(im_folder / f"{x}.{im_ext}") for x in image_stems])
if configs.get("moment_path", None) is not None:
moment_folder = root_path / configs.moment_path
self.moment_paths.extend([str(moment_folder / f"{x}.npy") for x in image_stems])
else:
self.moment_paths.extend([None, ] * len(image_stems))
# blur settings for the first degradation
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
self.blur_sigma = opt['blur_sigma']
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
# blur settings for the second degradation
self.blur_kernel_size2 = opt['blur_kernel_size2']
self.kernel_list2 = opt['kernel_list2']
self.kernel_prob2 = opt['kernel_prob2']
self.blur_sigma2 = opt['blur_sigma2']
self.betag_range2 = opt['betag_range2']
self.betap_range2 = opt['betap_range2']
self.sinc_prob2 = opt['sinc_prob2']
# a final sinc filter
self.final_sinc_prob = opt['final_sinc_prob']
        self.kernel_range1 = [x for x in range(3, opt['blur_kernel_size'], 2)]  # odd kernel sizes from 3 up to blur_kernel_size
        self.kernel_range2 = [x for x in range(3, opt['blur_kernel_size2'], 2)]  # odd kernel sizes from 3 up to blur_kernel_size2
        # TODO: the kernel range is hard-coded here; it should be moved into the config file
# convolving with pulse tensor brings no blurry effect
self.pulse_tensor = torch.zeros(opt['blur_kernel_size2'], opt['blur_kernel_size2']).float()
self.pulse_tensor[opt['blur_kernel_size2']//2, opt['blur_kernel_size2']//2] = 1
self.mode = mode
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# -------------------------------- Load gt images -------------------------------- #
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
gt_path = self.image_paths[index]
# avoid errors caused by high latency in reading files
retry = 3
while retry > 0:
try:
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
            except Exception:
                # random.randint is inclusive on both ends, so subtract 1 to stay in range
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.image_paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break  # successful read, stop retrying
            finally:
                retry -= 1
if self.mode == 'testing':
if not hasattr(self, 'test_aug'):
self.test_aug = albumentations.Compose([
albumentations.SmallestMaxSize(
max_size=self.opt['gt_size'],
interpolation=cv2.INTER_AREA,
),
albumentations.CenterCrop(self.opt['gt_size'], self.opt['gt_size']),
])
img_gt = self.test_aug(image=img_gt)['image']
elif self.mode == 'training':
# -------------------- Do augmentation for training: flip, rotation -------------------- #
if self.opt['use_hflip'] or self.opt['use_rot']:
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
h, w = img_gt.shape[0:2]
gt_size = self.opt['gt_size']
# resize or pad
if not self.opt['random_crop']:
if not min(h, w) == gt_size:
if not hasattr(self, 'smallest_resizer'):
self.smallest_resizer = util_image.SmallestMaxSize(
max_size=gt_size, pass_resize=False,
)
img_gt = self.smallest_resizer(img_gt)
# center crop
if not hasattr(self, 'center_cropper'):
self.center_cropper = albumentations.CenterCrop(gt_size, gt_size)
img_gt = self.center_cropper(image=img_gt)['image']
else:
img_gt = random_crop(img_gt, self.opt['gt_size'])
else:
raise ValueError(f'Unexpected value {self.mode} for mode parameter')
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range1)
if np.random.uniform() < self.opt['sinc_prob']:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None)
# pad kernel
pad_size = (self.blur_kernel_size - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range2)
if np.random.uniform() < self.opt['sinc_prob2']:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None)
# pad kernel
pad_size = (self.blur_kernel_size2 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------------------- the final sinc kernel ------------------------------------- #
if np.random.uniform() < self.opt['final_sinc_prob']:
kernel_size = random.choice(self.kernel_range2)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=self.blur_kernel_size2)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
kernel = torch.FloatTensor(kernel)
kernel2 = torch.FloatTensor(kernel2)
if self.text_paths[index] is None or self.opt['random_crop']:
prompt = ""
else:
with open(self.text_paths[index], 'r') as ff:
prompt = ff.read()
if self.opt.max_token_length is not None:
prompt = prompt[:self.opt.max_token_length]
return_d = {
'gt': img_gt,
'gt_path': gt_path,
'txt': prompt,
'kernel1': kernel,
'kernel2': kernel2,
'sinc_kernel': sinc_kernel,
}
if self.moment_paths[index] is not None and (not self.opt['random_crop']):
return_d['gt_moment'] = np.load(self.moment_paths[index])
return return_d
def __len__(self):
return len(self.image_paths)
def degrade_fun(self, conf_degradation, im_gt, kernel1, kernel2, sinc_kernel):
if not hasattr(self, 'jpeger'):
self.jpeger = DiffJPEG(differentiable=False) # simulate JPEG compression artifacts
ori_h, ori_w = im_gt.size()[2:4]
sf = conf_degradation.sf
# ----------------------- The first degradation process ----------------------- #
# blur
out = filter2D(im_gt, kernel1)
# random resize
updown_type = random.choices(
['up', 'down', 'keep'],
conf_degradation['resize_prob'],
)[0]
if updown_type == 'up':
scale = random.uniform(1, conf_degradation['resize_range'][1])
elif updown_type == 'down':
scale = random.uniform(conf_degradation['resize_range'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, scale_factor=scale, mode=mode)
# add noise
gray_noise_prob = conf_degradation['gray_noise_prob']
if random.random() < conf_degradation['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt(
out,
sigma_range=conf_degradation['noise_range'],
clip=True,
rounds=False,
gray_prob=gray_noise_prob,
)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=conf_degradation['poisson_scale_range'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range'])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
# blur
if random.random() < conf_degradation['second_order_prob']:
if random.random() < conf_degradation['second_blur_prob']:
out = filter2D(out, kernel2)
# random resize
updown_type = random.choices(
['up', 'down', 'keep'],
conf_degradation['resize_prob2'],
)[0]
if updown_type == 'up':
scale = random.uniform(1, conf_degradation['resize_range2'][1])
elif updown_type == 'down':
scale = random.uniform(conf_degradation['resize_range2'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(int(ori_h / sf * scale), int(ori_w / sf * scale)),
mode=mode,
)
# add noise
gray_noise_prob = conf_degradation['gray_noise_prob2']
if random.random() < conf_degradation['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt(
out,
sigma_range=conf_degradation['noise_range2'],
clip=True,
rounds=False,
gray_prob=gray_noise_prob,
)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=conf_degradation['poisson_scale_range2'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False,
)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if random.random() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(ori_h // sf, ori_w // sf),
mode=mode,
)
out = filter2D(out, sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(ori_h // sf, ori_w // sf),
mode=mode,
)
out = filter2D(out, sinc_kernel)
# clamp and round
im_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
return {'lq':im_lq.contiguous(), 'gt':im_gt}
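The sketch below exercises degrade_fun on a single sample. Only the config key names are fixed by the method above; the numeric values are illustrative assumptions, and the kernels are unsqueezed to carry the batch dimension that filter2D expects.

# Hypothetical sketch of the two-stage degradation pipeline.
from omegaconf import OmegaConf

conf = OmegaConf.create({
    'sf': 4,
    'resize_prob': [0.2, 0.7, 0.1], 'resize_range': [0.15, 1.5],
    'gaussian_noise_prob': 0.5, 'noise_range': [1, 30],
    'poisson_scale_range': [0.05, 3.0], 'gray_noise_prob': 0.4,
    'jpeg_range': [30, 95],
    'second_order_prob': 0.5, 'second_blur_prob': 0.8,
    'resize_prob2': [0.3, 0.4, 0.3], 'resize_range2': [0.3, 1.2],
    'gaussian_noise_prob2': 0.5, 'noise_range2': [1, 25],
    'poisson_scale_range2': [0.05, 2.5], 'gray_noise_prob2': 0.4,
    'jpeg_range2': [30, 95],
})
item = dataset[0]  # dataset: a constructed RealESRGANDataset (assumed)
out = dataset.degrade_fun(
    conf,
    im_gt=item['gt'].unsqueeze(0),            # add the batch dimension
    kernel1=item['kernel1'].unsqueeze(0),
    kernel2=item['kernel2'].unsqueeze(0),
    sinc_kernel=item['sinc_kernel'].unsqueeze(0),
)
lq, gt = out['lq'], out['gt']  # lq is sf times smaller than gt on each side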

View File

@@ -0,0 +1,106 @@
import os
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANPairedDataset(data.Dataset):
"""Paired image dataset for image restoration.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
There are three modes:
1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
2. **meta_info_file**: Use meta information file to generate paths. \
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
3. **folder**: Scan folders to generate paths. The rest.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (implemented via vertical flips plus transposing h and w).
            scale (int): Scale factor, which will be added automatically.
            phase (str): 'train' or 'val'.
"""
def __init__(self, opt):
super(RealESRGANPairedDataset, self).__init__()
self.opt = opt
self.file_client = None
self.io_backend_opt = opt['io_backend']
# mean and std for normalizing the input images
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
# file client (lmdb io backend)
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
# disk backend with meta_info
# Each line in the meta_info describes the relative path to an image
with open(self.opt['meta_info']) as fin:
paths = [line.strip() for line in fin]
self.paths = []
for path in paths:
gt_path, lq_path = path.split(', ')
gt_path = os.path.join(self.gt_folder, gt_path)
lq_path = os.path.join(self.lq_folder, lq_path)
self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
else:
            # disk backend
            # it scans the whole folder to collect meta info, which can be
            # time-consuming for large folders; using an extra meta txt file is recommended
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
lq_path = self.paths[index]['lq_path']
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)
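For the disk backend with 'meta_info', each line of the file pairs a GT path and an LQ path, relative to the two data roots and separated by ', ' (matching the path.split(', ') above). An illustrative file:

gt/0001.png, lq/0001.png
gt/0002.png, lq/0002.png
gt/0003.png, lq/0003.png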

View File

@@ -0,0 +1,352 @@
import numpy as np
import random
import torch
from pathlib import Path
from torch.utils import data as data
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.flow_util import dequantize_flow
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class REDSDataset(data.Dataset):
"""REDS dataset for training.
The keys are generated from a meta info txt file.
basicsr/data/meta_info/meta_info_REDS_GT.txt
Each line contains:
1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
a white space.
Examples:
000 100 (720,1280,3)
001 100 (720,1280,3)
...
Key examples: "000/00000000"
GT (gt): Ground-Truth;
LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
dataroot_flow (str, optional): Data root path for flow.
meta_info_file (str): Path for meta information file.
val_partition (str): Validation partition types. 'REDS4' or 'official'.
io_backend (dict): IO backend type and other kwarg.
num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Randomly reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (implemented via vertical flips plus transposing h and w).
            scale (int): Scale factor, which will be added automatically.
"""
def __init__(self, opt):
super(REDSDataset, self).__init__()
self.opt = opt
self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
self.num_frame = opt['num_frame']
self.num_half_frames = opt['num_frame'] // 2
self.keys = []
with open(opt['meta_info_file'], 'r') as fin:
for line in fin:
folder, frame_num, _ = line.split(' ')
self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
# remove the video clips used in validation
if opt['val_partition'] == 'REDS4':
val_partition = ['000', '011', '015', '020']
elif opt['val_partition'] == 'official':
val_partition = [f'{v:03d}' for v in range(240, 270)]
else:
raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
f"Supported ones are ['official', 'REDS4'].")
self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.is_lmdb = False
if self.io_backend_opt['type'] == 'lmdb':
self.is_lmdb = True
if self.flow_root is not None:
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
else:
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
# temporal augmentation configs
self.interval_list = opt['interval_list']
self.random_reverse = opt['random_reverse']
interval_str = ','.join(str(x) for x in opt['interval_list'])
logger = get_root_logger()
logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
f'random reverse is {self.random_reverse}.')
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
gt_size = self.opt['gt_size']
key = self.keys[index]
clip_name, frame_name = key.split('/') # key example: 000/00000000
center_frame_idx = int(frame_name)
# determine the neighboring frames
interval = random.choice(self.interval_list)
# ensure not exceeding the borders
start_frame_idx = center_frame_idx - self.num_half_frames * interval
end_frame_idx = center_frame_idx + self.num_half_frames * interval
# each clip has 100 frames starting from 0 to 99
while (start_frame_idx < 0) or (end_frame_idx > 99):
center_frame_idx = random.randint(0, 99)
start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
end_frame_idx = center_frame_idx + self.num_half_frames * interval
frame_name = f'{center_frame_idx:08d}'
neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
# random reverse
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')
# get the GT frame (as the center frame)
if self.is_lmdb:
img_gt_path = f'{clip_name}/{frame_name}'
else:
img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
img_bytes = self.file_client.get(img_gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
# get the neighboring LQ frames
img_lqs = []
for neighbor in neighbor_list:
if self.is_lmdb:
img_lq_path = f'{clip_name}/{neighbor:08d}'
else:
img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
img_bytes = self.file_client.get(img_lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
img_lqs.append(img_lq)
# get flows
if self.flow_root is not None:
img_flows = []
# read previous flows
for i in range(self.num_half_frames, 0, -1):
if self.is_lmdb:
flow_path = f'{clip_name}/{frame_name}_p{i}'
else:
flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
img_bytes = self.file_client.get(flow_path, 'flow')
cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
dx, dy = np.split(cat_flow, 2, axis=0)
flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
img_flows.append(flow)
# read next flows
for i in range(1, self.num_half_frames + 1):
if self.is_lmdb:
flow_path = f'{clip_name}/{frame_name}_n{i}'
else:
flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
img_bytes = self.file_client.get(flow_path, 'flow')
cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
dx, dy = np.split(cat_flow, 2, axis=0)
flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
img_flows.append(flow)
            # img_flows and img_lqs share the same spatial size here,
            # so they can be concatenated and cropped together
            img_lqs.extend(img_flows)
# randomly crop
img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
if self.flow_root is not None:
img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]
# augmentation - flip, rotate
img_lqs.append(img_gt)
if self.flow_root is not None:
img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
else:
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
img_results = img2tensor(img_results)
img_lqs = torch.stack(img_results[0:-1], dim=0)
img_gt = img_results[-1]
if self.flow_root is not None:
img_flows = img2tensor(img_flows)
# add the zero center flow
img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
img_flows = torch.stack(img_flows, dim=0)
# img_lqs: (t, c, h, w)
# img_flows: (t, 2, h, w)
# img_gt: (c, h, w)
# key: str
if self.flow_root is not None:
return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
else:
return {'lq': img_lqs, 'gt': img_gt, 'key': key}
def __len__(self):
return len(self.keys)
@DATASET_REGISTRY.register()
class REDSRecurrentDataset(data.Dataset):
"""REDS dataset for training recurrent networks.
The keys are generated from a meta info txt file.
basicsr/data/meta_info/meta_info_REDS_GT.txt
Each line contains:
1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
a white space.
Examples:
000 100 (720,1280,3)
001 100 (720,1280,3)
...
Key examples: "000/00000000"
GT (gt): Ground-Truth;
LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
dataroot_flow (str, optional): Data root path for flow.
meta_info_file (str): Path for meta information file.
val_partition (str): Validation partition types. 'REDS4' or 'official'.
io_backend (dict): IO backend type and other kwarg.
num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Randomly reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (implemented via vertical flips plus transposing h and w).
            scale (int): Scale factor, which will be added automatically.
"""
def __init__(self, opt):
super(REDSRecurrentDataset, self).__init__()
self.opt = opt
self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
self.num_frame = opt['num_frame']
self.keys = []
with open(opt['meta_info_file'], 'r') as fin:
for line in fin:
folder, frame_num, _ = line.split(' ')
self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
# remove the video clips used in validation
if opt['val_partition'] == 'REDS4':
val_partition = ['000', '011', '015', '020']
elif opt['val_partition'] == 'official':
val_partition = [f'{v:03d}' for v in range(240, 270)]
else:
raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
f"Supported ones are ['official', 'REDS4'].")
if opt['test_mode']:
self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
else:
self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.is_lmdb = False
if self.io_backend_opt['type'] == 'lmdb':
self.is_lmdb = True
if hasattr(self, 'flow_root') and self.flow_root is not None:
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
else:
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
# temporal augmentation configs
self.interval_list = opt.get('interval_list', [1])
self.random_reverse = opt.get('random_reverse', False)
interval_str = ','.join(str(x) for x in self.interval_list)
logger = get_root_logger()
logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
f'random reverse is {self.random_reverse}.')
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
gt_size = self.opt['gt_size']
key = self.keys[index]
clip_name, frame_name = key.split('/') # key example: 000/00000000
# determine the neighboring frames
interval = random.choice(self.interval_list)
# ensure not exceeding the borders
start_frame_idx = int(frame_name)
if start_frame_idx > 100 - self.num_frame * interval:
start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
end_frame_idx = start_frame_idx + self.num_frame * interval
neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
# random reverse
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
# get the neighboring LQ and GT frames
img_lqs = []
img_gts = []
for neighbor in neighbor_list:
if self.is_lmdb:
img_lq_path = f'{clip_name}/{neighbor:08d}'
img_gt_path = f'{clip_name}/{neighbor:08d}'
else:
img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'
# get LQ
img_bytes = self.file_client.get(img_lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
img_lqs.append(img_lq)
# get GT
img_bytes = self.file_client.get(img_gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
img_gts.append(img_gt)
# randomly crop
img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
# augmentation - flip, rotate
img_lqs.extend(img_gts)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
img_results = img2tensor(img_results)
img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)
# img_lqs: (t, c, h, w)
# img_gts: (t, c, h, w)
# key: str
return {'lq': img_lqs, 'gt': img_gts, 'key': key}
def __len__(self):
return len(self.keys)
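A construction sketch for REDSDataset follows; the option keys mirror the docstring above, while the concrete paths and sizes are assumptions.

# Hypothetical REDSDataset setup (disk backend, REDS4 clips held out).
opt = {
    'dataroot_gt': 'datasets/REDS/train_sharp',
    'dataroot_lq': 'datasets/REDS/train_sharp_bicubic/X4',
    'dataroot_flow': None,
    'meta_info_file': 'basicsr/data/meta_info/meta_info_REDS_GT.txt',
    'val_partition': 'REDS4',
    'io_backend': {'type': 'disk'},
    'num_frame': 5,  # must be odd: one center frame plus neighbors
    'gt_size': 256,
    'interval_list': [1],
    'random_reverse': False,
    'use_hflip': True,
    'use_rot': True,
    'scale': 4,
}
dataset = REDSDataset(opt)
sample = dataset[0]
print(sample['lq'].shape, sample['gt'].shape, sample['key'])
# e.g., torch.Size([5, 3, 64, 64]) torch.Size([3, 256, 256]) 001/00000000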

View File

@@ -0,0 +1,68 @@
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class SingleImageDataset(data.Dataset):
"""Read only lq images in the test phase.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
There are two modes:
1. 'meta_info_file': Use meta information file to generate paths.
2. 'folder': Scan folders to generate paths.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_lq (str): Data root path for lq.
meta_info_file (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
"""
def __init__(self, opt):
super(SingleImageDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.lq_folder = opt['dataroot_lq']
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder]
self.io_backend_opt['client_keys'] = ['lq']
self.paths = paths_from_lmdb(self.lq_folder)
elif 'meta_info_file' in self.opt:
with open(self.opt['meta_info_file'], 'r') as fin:
self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
else:
self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load lq image
lq_path = self.paths[index]
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# color space transform
if 'color' in self.opt and self.opt['color'] == 'y':
img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
# BGR to RGB, HWC to CHW, numpy to tensor
img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'lq_path': lq_path}
def __len__(self):
return len(self.paths)
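A minimal test-phase sketch for SingleImageDataset in folder mode; the LQ directory is a hypothetical path.

# Hypothetical inference-time usage: LQ images only, no GT.
opt = {'io_backend': {'type': 'disk'}, 'dataroot_lq': 'testdata/LQ'}
dataset = SingleImageDataset(opt)
item = dataset[0]
print(item['lq'].shape, item['lq_path'])  # CHW float tensor plus its source path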

207
basicsr/data/transforms.py Normal file
View File

@@ -0,0 +1,207 @@
import cv2
import random
import torch
def mod_crop(img, scale):
"""Mod crop images, used during testing.
Args:
img (ndarray): Input image.
scale (int): Scale factor.
Returns:
ndarray: Result image.
"""
img = img.copy()
if img.ndim in (2, 3):
h, w = img.shape[0], img.shape[1]
h_remainder, w_remainder = h % scale, w % scale
img = img[:h - h_remainder, :w - w_remainder, ...]
else:
raise ValueError(f'Wrong img ndim: {img.ndim}.')
return img
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
"""Paired random crop. Support Numpy array and Tensor inputs.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth. Default: None.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
# determine input type: Numpy array or Tensor
input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
if input_type == 'Tensor':
h_lq, w_lq = img_lqs[0].size()[-2:]
h_gt, w_gt = img_gts[0].size()[-2:]
else:
h_lq, w_lq = img_lqs[0].shape[0:2]
h_gt, w_gt = img_gts[0].shape[0:2]
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
if input_type == 'Tensor':
img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
else:
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
if input_type == 'Tensor':
img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
else:
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
We use vertical flip and transpose for rotation implementation.
All the images in the list use the same augmentation.
Args:
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
is an ndarray, it will be transformed to a list.
hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
return_status (bool): Return the status of flip and rotation.
Default: False.
Returns:
list[ndarray] | ndarray: Augmented images and flows. If returned
results only have one element, just return ndarray.
"""
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip: # horizontal
cv2.flip(flow, 1, flow)
flow[:, :, 0] *= -1
if vflip: # vertical
cv2.flip(flow, 0, flow)
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
if flows is not None:
if not isinstance(flows, list):
flows = [flows]
flows = [_augment_flow(flow) for flow in flows]
if len(flows) == 1:
flows = flows[0]
return imgs, flows
else:
if return_status:
return imgs, (hflip, vflip, rot90)
else:
return imgs
def img_rotate(img, angle, center=None, scale=1.0):
"""Rotate image.
Args:
img (ndarray): Image to be rotated.
angle (float): Rotation angle in degrees. Positive values mean
counter-clockwise rotation.
center (tuple[int]): Rotation center. If the center is None,
initialize it as the center of the image. Default: None.
scale (float): Isotropic scale factor. Default: 1.0.
"""
(h, w) = img.shape[:2]
if center is None:
center = (w // 2, h // 2)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
rotated_img = cv2.warpAffine(img, matrix, (w, h))
return rotated_img
def random_crop(im, pch_size):
'''
    Randomly crop a pch_size x pch_size patch from the given image,
    reflect-padding it first if the image is smaller than the patch.
'''
h, w = im.shape[:2]
# padding if necessary
if h < pch_size or w < pch_size:
pad_h = min(max(0, pch_size - h), h)
pad_w = min(max(0, pch_size - w), w)
im = cv2.copyMakeBorder(im, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
h, w = im.shape[:2]
if h == pch_size:
ind_h = 0
elif h > pch_size:
ind_h = random.randint(0, h-pch_size)
else:
raise ValueError('Image height is smaller than the patch size')
if w == pch_size:
ind_w = 0
elif w > pch_size:
ind_w = random.randint(0, w-pch_size)
else:
raise ValueError('Image width is smaller than the patch size')
im_pch = im[ind_h:ind_h+pch_size, ind_w:ind_w+pch_size,]
return im_pch
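A quick sketch exercising the helpers above on dummy arrays; the shapes are illustrative. Crops are numpy views, so they are made contiguous before the in-place cv2 flips inside augment.

# Hypothetical round trip through paired_random_crop and augment.
import numpy as np

gt = np.random.rand(256, 256, 3).astype(np.float32)  # HWC, range [0, 1]
lq = np.random.rand(64, 64, 3).astype(np.float32)    # 4x downscaled partner
gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=128, scale=4)
print(gt_patch.shape, lq_patch.shape)  # (128, 128, 3) (32, 32, 3)
gt_patch = np.ascontiguousarray(gt_patch)
lq_patch = np.ascontiguousarray(lq_patch)
gt_aug, lq_aug = augment([gt_patch, lq_patch], hflip=True, rotation=True)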

View File

@@ -0,0 +1,283 @@
import glob
import torch
from os import path as osp
from torch.utils import data as data
from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
"""Video test dataset.
Supported datasets: Vid4, REDS4, REDSofficial.
More generally, it supports testing dataset with following structures:
::
dataroot
├── subfolder1
├── frame000
├── frame001
├── ...
├── subfolder2
├── frame000
├── frame001
├── ...
├── ...
For testing datasets, there is no need to prepare LMDB files.
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
io_backend (dict): IO backend type and other kwarg.
cache_data (bool): Whether to cache testing datasets.
name (str): Dataset name.
meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
in the dataroot will be used.
num_frame (int): Window size for input frames.
padding (str): Padding mode.
"""
def __init__(self, opt):
super(VideoTestDataset, self).__init__()
self.opt = opt
self.cache_data = opt['cache_data']
self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
logger = get_root_logger()
logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
self.imgs_lq, self.imgs_gt = {}, {}
if 'meta_info_file' in opt:
with open(opt['meta_info_file'], 'r') as fin:
subfolders = [line.split(' ')[0] for line in fin]
subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
else:
subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))
if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
# get frame list for lq and gt
subfolder_name = osp.basename(subfolder_lq)
img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))
max_idx = len(img_paths_lq)
assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
f' and gt folders ({len(img_paths_gt)})')
self.data_info['lq_path'].extend(img_paths_lq)
self.data_info['gt_path'].extend(img_paths_gt)
self.data_info['folder'].extend([subfolder_name] * max_idx)
for i in range(max_idx):
self.data_info['idx'].append(f'{i}/{max_idx}')
border_l = [0] * max_idx
for i in range(self.opt['num_frame'] // 2):
border_l[i] = 1
border_l[max_idx - i - 1] = 1
self.data_info['border'].extend(border_l)
# cache data or save the frame list
if self.cache_data:
logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
else:
self.imgs_lq[subfolder_name] = img_paths_lq
self.imgs_gt[subfolder_name] = img_paths_gt
else:
            raise ValueError(f'Non-supported video test dataset: {opt["name"]}')
def __getitem__(self, index):
folder = self.data_info['folder'][index]
idx, max_idx = self.data_info['idx'][index].split('/')
idx, max_idx = int(idx), int(max_idx)
border = self.data_info['border'][index]
lq_path = self.data_info['lq_path'][index]
select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
if self.cache_data:
imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
img_gt = self.imgs_gt[folder][idx]
else:
img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
imgs_lq = read_img_seq(img_paths_lq)
img_gt = read_img_seq([self.imgs_gt[folder][idx]])
img_gt.squeeze_(0)
return {
'lq': imgs_lq, # (t, c, h, w)
'gt': img_gt, # (c, h, w)
'folder': folder, # folder name
'idx': self.data_info['idx'][index], # e.g., 0/99
'border': border, # 1 for border, 0 for non-border
'lq_path': lq_path # center frame
}
def __len__(self):
return len(self.data_info['gt_path'])
@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
"""Video test dataset for Vimeo90k-Test dataset.
It only keeps the center frame for testing.
For testing datasets, there is no need to prepare LMDB files.
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
io_backend (dict): IO backend type and other kwarg.
cache_data (bool): Whether to cache testing datasets.
name (str): Dataset name.
meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
in the dataroot will be used.
num_frame (int): Window size for input frames.
padding (str): Padding mode.
"""
def __init__(self, opt):
super(VideoTestVimeo90KDataset, self).__init__()
self.opt = opt
self.cache_data = opt['cache_data']
if self.cache_data:
raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
logger = get_root_logger()
logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
with open(opt['meta_info_file'], 'r') as fin:
subfolders = [line.split(' ')[0] for line in fin]
for idx, subfolder in enumerate(subfolders):
gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
self.data_info['gt_path'].append(gt_path)
lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
self.data_info['lq_path'].append(lq_paths)
self.data_info['folder'].append('vimeo90k')
self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
self.data_info['border'].append(0)
def __getitem__(self, index):
lq_path = self.data_info['lq_path'][index]
gt_path = self.data_info['gt_path'][index]
imgs_lq = read_img_seq(lq_path)
img_gt = read_img_seq([gt_path])
img_gt.squeeze_(0)
return {
'lq': imgs_lq, # (t, c, h, w)
'gt': img_gt, # (c, h, w)
'folder': self.data_info['folder'][index], # folder name
'idx': self.data_info['idx'][index], # e.g., 0/843
'border': self.data_info['border'][index], # 0 for non-border
'lq_path': lq_path[self.opt['num_frame'] // 2] # center frame
}
def __len__(self):
return len(self.data_info['gt_path'])
@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
""" Video test dataset for DUF dataset.
Args:
opt (dict): Config for train dataset. Most of keys are the same as VideoTestDataset.
It has the following extra keys:
use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames.
scale (bool): Scale, which will be added automatically.
"""
def __getitem__(self, index):
folder = self.data_info['folder'][index]
idx, max_idx = self.data_info['idx'][index].split('/')
idx, max_idx = int(idx), int(max_idx)
border = self.data_info['border'][index]
lq_path = self.data_info['lq_path'][index]
select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
if self.cache_data:
if self.opt['use_duf_downsampling']:
# read imgs_gt to generate low-resolution frames
imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
else:
imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
img_gt = self.imgs_gt[folder][idx]
else:
if self.opt['use_duf_downsampling']:
img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
# read imgs_gt to generate low-resolution frames
imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
else:
img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
imgs_lq = read_img_seq(img_paths_lq)
img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
img_gt.squeeze_(0)
return {
'lq': imgs_lq, # (t, c, h, w)
'gt': img_gt, # (c, h, w)
'folder': folder, # folder name
'idx': self.data_info['idx'][index], # e.g., 0/99
'border': border, # 1 for border, 0 for non-border
'lq_path': lq_path # center frame
}
@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
"""Video test dataset for recurrent architectures, which takes LR video
frames as input and output corresponding HR video frames.
Args:
opt (dict): Same as VideoTestDataset. Unused opt:
padding (str): Padding mode.
"""
def __init__(self, opt):
super(VideoRecurrentTestDataset, self).__init__(opt)
# Find unique folder strings
self.folders = sorted(list(set(self.data_info['folder'])))
def __getitem__(self, index):
folder = self.folders[index]
if self.cache_data:
imgs_lq = self.imgs_lq[folder]
imgs_gt = self.imgs_gt[folder]
else:
            raise NotImplementedError('Running without cache_data is not implemented.')
return {
'lq': imgs_lq,
'gt': imgs_gt,
'folder': folder,
}
def __len__(self):
return len(self.folders)
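For orientation, a minimal sketch of driving this recurrent test dataset follows. The opt keys mirror the VideoTestDataset configuration; the dataset name and data roots are illustrative assumptions, the module path is assumed from the standard BasicSR layout, and cache_data must be True since the non-cached branch above raises NotImplementedError.

# Minimal usage sketch (illustrative paths; module path assumed).
from basicsr.data.video_test_dataset import VideoRecurrentTestDataset

opt = {
    'name': 'REDS4',                                # assumed dataset name
    'dataroot_gt': 'datasets/REDS4/GT',             # hypothetical location
    'dataroot_lq': 'datasets/REDS4/sharp_bicubic',  # hypothetical location
    'io_backend': {'type': 'disk'},
    'cache_data': True,   # required: the non-cached branch raises NotImplementedError
    'num_frame': -1,      # recurrent models consume whole clips
}
dataset = VideoRecurrentTestDataset(opt)
sample = dataset[0]
print(sample['folder'], sample['lq'].shape, sample['gt'].shape)  # lq/gt: (t, c, h, w)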

199
basicsr/data/vimeo90k_dataset.py Normal file
View File

@@ -0,0 +1,199 @@
import random
import torch
from pathlib import Path
from torch.utils import data as data
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class Vimeo90KDataset(data.Dataset):
"""Vimeo90K dataset for training.
The keys are generated from a meta info txt file.
basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
    Each line contains the following items, separated by white space.
1. clip name;
2. frame number;
3. image shape
Examples:
::
00001/0001 7 (256,448,3)
00001/0002 7 (256,448,3)
- Key examples: "00001/0001"
- GT (gt): Ground-Truth;
- LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
The neighboring frame list for different num_frame:
::
num_frame | frame list
1 | 4
3 | 3,4,5
5 | 2,3,4,5,6
7 | 1,2,3,4,5,6,7
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
meta_info_file (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
num_frame (int): Window size for input frames.
        gt_size (int): Cropped patch size for GT patches.
        random_reverse (bool): Randomly reverse input frames.
        use_hflip (bool): Use horizontal flips.
        use_rot (bool): Use rotation (implemented by vertical flipping and transposing h and w).
        scale (int): Scale factor, which will be added automatically.
"""
def __init__(self, opt):
super(Vimeo90KDataset, self).__init__()
self.opt = opt
self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
with open(opt['meta_info_file'], 'r') as fin:
self.keys = [line.split(' ')[0] for line in fin]
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.is_lmdb = False
if self.io_backend_opt['type'] == 'lmdb':
self.is_lmdb = True
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
# indices of input images
self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
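        # e.g., num_frame=5 -> neighbor_list = [2, 3, 4, 5, 6], centered on the GT frame im4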
# temporal augmentation configs
self.random_reverse = opt['random_reverse']
logger = get_root_logger()
logger.info(f'Random reverse is {self.random_reverse}.')
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# random reverse
if self.random_reverse and random.random() < 0.5:
self.neighbor_list.reverse()
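            # note: the list is reversed in place, so the toggled order persists across calls (each call flips it with probability 0.5)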
scale = self.opt['scale']
gt_size = self.opt['gt_size']
key = self.keys[index]
clip, seq = key.split('/') # key example: 00001/0001
# get the GT frame (im4.png)
if self.is_lmdb:
img_gt_path = f'{key}/im4'
else:
img_gt_path = self.gt_root / clip / seq / 'im4.png'
img_bytes = self.file_client.get(img_gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
# get the neighboring LQ frames
img_lqs = []
for neighbor in self.neighbor_list:
if self.is_lmdb:
img_lq_path = f'{clip}/{seq}/im{neighbor}'
else:
img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
img_bytes = self.file_client.get(img_lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
img_lqs.append(img_lq)
# randomly crop
img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
# augmentation - flip, rotate
img_lqs.append(img_gt)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
img_results = img2tensor(img_results)
img_lqs = torch.stack(img_results[0:-1], dim=0)
img_gt = img_results[-1]
# img_lqs: (t, c, h, w)
# img_gt: (c, h, w)
# key: str
return {'lq': img_lqs, 'gt': img_gt, 'key': key}
def __len__(self):
return len(self.keys)
@DATASET_REGISTRY.register()
class Vimeo90KRecurrentDataset(Vimeo90KDataset):
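    """Vimeo90K dataset for training recurrent networks.

    Unlike Vimeo90KDataset, all 7 frames are returned as both the LQ and the GT
    sequence. Config keys are the same as in Vimeo90KDataset, plus:
        flip_sequence (bool): Whether to temporally flip the sequence and
            concatenate it, extending 7 frames to 14.
    """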
def __init__(self, opt):
super(Vimeo90KRecurrentDataset, self).__init__(opt)
self.flip_sequence = opt['flip_sequence']
self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# random reverse
if self.random_reverse and random.random() < 0.5:
self.neighbor_list.reverse()
scale = self.opt['scale']
gt_size = self.opt['gt_size']
key = self.keys[index]
clip, seq = key.split('/') # key example: 00001/0001
# get the neighboring LQ and GT frames
img_lqs = []
img_gts = []
for neighbor in self.neighbor_list:
if self.is_lmdb:
img_lq_path = f'{clip}/{seq}/im{neighbor}'
img_gt_path = f'{clip}/{seq}/im{neighbor}'
else:
img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
# LQ
img_bytes = self.file_client.get(img_lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# GT
img_bytes = self.file_client.get(img_gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
img_lqs.append(img_lq)
img_gts.append(img_gt)
# randomly crop
img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
# augmentation - flip, rotate
img_lqs.extend(img_gts)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
img_results = img2tensor(img_results)
img_lqs = torch.stack(img_results[:7], dim=0)
img_gts = torch.stack(img_results[7:], dim=0)
if self.flip_sequence: # flip the sequence: 7 frames to 14 frames
img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)
# img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
# key: str
return {'lq': img_lqs, 'gt': img_gts, 'key': key}
def __len__(self):
return len(self.keys)
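For completeness, a minimal training-side sketch of Vimeo90KDataset; the keys follow the class docstring above, the data roots are illustrative, and the module path is assumed from the standard BasicSR layout.

# Minimal usage sketch (illustrative data roots; meta info path as documented above).
from torch.utils.data import DataLoader
from basicsr.data.vimeo90k_dataset import Vimeo90KDataset

opt = {
    'dataroot_gt': 'datasets/vimeo90k/GT',    # hypothetical location
    'dataroot_lq': 'datasets/vimeo90k/BIx4',  # hypothetical location
    'meta_info_file': 'basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt',
    'io_backend': {'type': 'disk'},
    'num_frame': 7,
    'gt_size': 256,
    'random_reverse': True,
    'use_hflip': True,
    'use_rot': True,
    'scale': 4,
}
loader = DataLoader(Vimeo90KDataset(opt), batch_size=4, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch['lq'].shape, batch['gt'].shape)  # (b, t, c, h, w) and (b, c, h, w)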

47
basicsr/utils/__init__.py Normal file
View File

@@ -0,0 +1,47 @@
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
from .options import yaml_load
__all__ = [
# color_util.py
'bgr2ycbcr',
'rgb2ycbcr',
'rgb2ycbcr_pt',
'ycbcr2bgr',
'ycbcr2rgb',
# file_client.py
'FileClient',
# img_util.py
'img2tensor',
'tensor2img',
'imfrombytes',
'imwrite',
'crop_border',
# logger.py
'MessageLogger',
'AvgTimer',
'init_tb_logger',
'init_wandb_logger',
'get_root_logger',
'get_env_info',
# misc.py
'set_random_seed',
'get_time_str',
'mkdir_and_rename',
'make_exp_dirs',
'scandir',
'check_resume',
'sizeof_fmt',
# diffjpeg
'DiffJPEG',
# img_process_util
'USMSharp',
'usm_sharp',
# options
'yaml_load'
]
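All of the re-exports above can be imported directly from basicsr.utils. A small round-trip sketch (the array contents are arbitrary):

import numpy as np
from basicsr.utils import img2tensor, tensor2img

img = np.random.rand(64, 64, 3).astype(np.float32)  # HWC image in [0, 1]
t = img2tensor(img, bgr2rgb=True, float32=True)     # -> CHW float tensor
back = tensor2img(t)                                # -> HWC uint8 image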
