first commit
35
LICENSE
Normal file
@@ -0,0 +1,35 @@
S-Lab License 1.0

Copyright 2024 S-Lab

Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in
   the documentation and/or other materials provided with the
   distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived
   from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
97
README.md
@@ -1,2 +1,95 @@
# InvSR
Arbitrary-steps Image Super-resolution via Diffusion Inversion
# Arbitrary-steps Image Super-resolution via Diffusion Inversion

[Zongsheng Yue](https://zsyoaoa.github.io/), [Kang Liao](https://kangliao929.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)

<!--[Paper](https://arxiv.org/abs/2307.12348) | [Project Page](https://zsyoaoa.github.io/projects/resshift/) | [Demo](https://www.youtube.com/watch?v=8DB-6Xvvl5o)-->

<!--<a href="https://colab.research.google.com/drive/1CL8aJO7a_RA4MetanrCLqQO5H7KWO8KI?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [](https://replicate.com/cjwbw/resshift)  -->

:star: If InvSR is helpful to your research or projects, please help star this repo. Thanks! :hugs:

---
>This study presents a new image super-resolution (SR) technique based on diffusion inversion, aiming at harnessing the rich image priors encapsulated in large pre-trained diffusion models to improve SR performance. We design a *Partial noise Prediction* strategy to construct an intermediate state of the diffusion model, which serves as the starting sampling point. Central to our approach is a deep noise predictor that estimates the optimal noise maps for the forward diffusion process. Once trained, this noise predictor can be used to initialize the sampling process partially along the diffusion trajectory, generating the desired high-resolution result. Compared to existing approaches, our method offers a flexible and efficient sampling mechanism that supports an arbitrary number of sampling steps, ranging from one to five. Even with a single sampling step, our method demonstrates superior or comparable performance to recent state-of-the-art approaches.
><img src="./assets/framework.png" align="middle" width="800">
---
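For intuition, the pipeline above can be sketched in a few lines of PyTorch. This is a hypothetical illustration only: `noise_predictor`, `denoiser`, and the cosine schedule are assumed names, not this repo's actual API.

```
import torch

def invsr_sample(lr_up, noise_predictor, denoiser, timesteps=(400, 200, 0)):
    """Hypothetical sketch of Partial noise Prediction (names are assumptions).

    Instead of starting from pure noise, build an intermediate diffusion state
    from the (upsampled) low-resolution input and a predicted noise map, then
    run only the remaining reverse steps.
    """
    eps = noise_predictor(lr_up)                                  # estimated optimal noise map
    t0 = torch.tensor(float(timesteps[0]))
    alpha_bar = torch.cos(t0 / 1000 * torch.pi / 2) ** 2          # assumed noise schedule
    x = alpha_bar.sqrt() * lr_up + (1 - alpha_bar).sqrt() * eps   # intermediate state
    for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):
        x = denoiser(x, t_cur, t_next)                            # one reverse diffusion step
    return x
```

Choosing a shorter or longer `timesteps` list is what makes the number of sampling steps (one to five) a free parameter at test time.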
## Update
- **2024.12.11**: Create this repo.

## Requirements
* Python 3.10, PyTorch 2.4.0, [xformers](https://github.com/facebookresearch/xformers) 0.0.27.post2
* More details: see [environment.yaml](environment.yaml).

A suitable [conda](https://conda.io/) environment named `invsr` can be created and activated with:

```
conda create -n invsr python=3.10
conda activate invsr
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
pip install -U xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu121
pip install -e ".[torch]"
pip install -r requirements.txt
```
## Applications
### :point_right: Real-world Image Super-resolution
[<img src="assets/real-7.png" height="280"/>](https://imgsli.com/MzI2MTU5) [<img src="assets/real-1.png" height="280"/>](https://imgsli.com/MzI2MTUx) [<img src="assets/real-2.png" height="280"/>](https://imgsli.com/MzI2MTUy) <!--[<img src="assets/real-3.png" height="256"/>](https://imgsli.com/MzI2MTUz)-->
[<img src="assets/real-4.png" height="430"/>](https://imgsli.com/MzI2MTU0) [<img src="assets/real-6.png" height="430"/>](https://imgsli.com/MzI2MTU3) [<img src="assets/real-5.png" height="430"/>](https://imgsli.com/MzI2MTU1)

### :point_right: General Image Enhancement
[<img src="assets/enhance-1.png" height="294"/>](https://imgsli.com/MzI2MTYw) [<img src="assets/enhance-2.png" height="294"/>](https://imgsli.com/MzI2MTYy)
[<img src="assets/enhance-3.png" height="247.5"/>](https://imgsli.com/MzI2MjAx) [<img src="assets/enhance-4.png" height="247.5"/>](https://imgsli.com/MzI2MjAz) [<img src="assets/enhance-5.png" height="247.5"/>](https://imgsli.com/MzI2MjA0)

### :point_right: AIGC Image Enhancement
[<img src="assets/sdxl-1.png" height="324"/>](https://imgsli.com/MzI2MjQy) [<img src="assets/sdxl-2.png" height="324"/>](https://imgsli.com/MzI2MjQ1) [<img src="assets/sdxl-3.png" height="324"/>](https://imgsli.com/MzI2MjQ3)
[<img src="assets/flux-1.png" height="324"/>](https://imgsli.com/MzI2MjQ5) [<img src="assets/flux-2.png" height="324"/>](https://imgsli.com/MzI2MjUw) [<img src="assets/flux-3.png" height="324"/>](https://imgsli.com/MzI2MjUx)

<!--## Online Demo-->
<!--You can try our method through an online demo:-->
<!--```-->
<!--python app.py-->
<!--```-->

## Inference
### :rocket: Fast testing
```
python inference_invsr.py -i [image folder/image path] -o [result folder] --num_steps 1
```
1. This script will automatically download the pre-trained [noise predictor](https://huggingface.co/OAOA/InvSR/tree/main) and [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo/tree/main). If you have pre-downloaded them manually, please specify them via ``--started_ckpt_path`` and ``--sd_path``.
2. You can freely adjust the number of sampling steps via ``--num_steps``.

### :airplane: Reproducing our paper results
+ Synthetic dataset of ImageNet-Test: [Google Drive](https://drive.google.com/file/d/1PRGrujx3OFilgJ7I6nW7ETIR00wlAl2m/view?usp=sharing).

+ Real data for image super-resolution: [RealSRV3](https://github.com/csjcai/RealSR) | [RealSet80](testdata/RealSet80)

+ To reproduce the quantitative results on ImageNet-Test and RealSRV3, please add the color fixing option via ``--color_fix wavelet``.

## Training
### :turtle: Preparing stage
1. Download the fine-tuned LPIPS model from this [link](https://huggingface.co/OAOA/InvSR/resolve/main/vgg16_sdturbo_lpips.pth?download=true) and put it in the folder "weights".
2. Prepare the [config](configs/sd-turbo-sr-ldis.yaml) file (a loading sketch follows this list):
    + SD-Turbo path: configs.sd_pipe.params.cache_dir.
    + Training data path: data.train.params.data_source.
    + Validation data path: data.val.params.dir_path (low-quality image) and data.val.params.extra_dir_path (high-quality image).
    + Batch size: configs.train.batch and configs.train.microbatch (total batch size = microbatch * #GPUs * num_grad_accumulation).
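For convenience, the keys above can also be edited programmatically, e.g. with [OmegaConf](https://github.com/omry/omegaconf). This is a sketch only, under the assumption that the bullet paths are keys inside the YAML file; all concrete paths and values are placeholders.

```
from omegaconf import OmegaConf

cfg = OmegaConf.load('configs/sd-turbo-sr-ldis.yaml')
cfg.sd_pipe.params.cache_dir = './weights'                   # SD-Turbo path
cfg.data.train.params.data_source = '/path/to/train_data'   # training data
cfg.data.val.params.dir_path = '/path/to/val_lq'            # low-quality validation images
cfg.data.val.params.extra_dir_path = '/path/to/val_hq'      # high-quality validation images
cfg.train.batch = 32
cfg.train.microbatch = 4    # total batch size = microbatch * #GPUs * num_grad_accumulation
OmegaConf.save(cfg, 'configs/my-sr.yaml')
```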

### :dolphin: Begin training
```
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 --nnodes=1 main.py --save_dir [Logging Folder]
```

### :whale: Resume from interruption
```
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 --nnodes=1 main.py --save_dir [Logging Folder] --resume save_dir/ckpts/model_xx.pth
```

## License
This project is licensed under the [NTU S-Lab License 1.0](LICENSE). Redistribution and use should follow this license.

## Acknowledgement
This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR) and [diffusers](https://github.com/huggingface/diffusers). Thanks for their awesome work.

### Contact
If you have any questions, please feel free to contact me via `zsyzam@gmail.com`.
BIN assets/.DS_Store (vendored, Normal file)
BIN assets/enhance-1.png (Normal file, 4.0 MiB)
BIN assets/enhance-2.png (Normal file, 763 KiB)
BIN assets/enhance-3.png (Normal file, 1.7 MiB)
BIN assets/enhance-4.png (Normal file, 2.5 MiB)
BIN assets/enhance-5.png (Normal file, 2.6 MiB)
BIN assets/flux-1.png (Normal file, 2.6 MiB)
BIN assets/flux-2.png (Normal file, 1.8 MiB)
BIN assets/flux-3.png (Normal file, 3.4 MiB)
BIN assets/framework.png (Normal file, 419 KiB)
BIN assets/real-1.png (Normal file, 3.2 MiB)
BIN assets/real-2.png (Normal file, 1.9 MiB)
BIN assets/real-3.png (Normal file, 1.9 MiB)
BIN assets/real-4.png (Normal file, 1.9 MiB)
BIN assets/real-5.png (Normal file, 1.4 MiB)
BIN assets/real-6.png (Normal file, 2.1 MiB)
BIN assets/real-7.png (Normal file, 2.6 MiB)
BIN assets/sdxl-1.png (Normal file, 2.0 MiB)
BIN assets/sdxl-2.png (Normal file, 2.4 MiB)
BIN assets/sdxl-3.png (Normal file, 2.9 MiB)
BIN basicsr/.DS_Store (vendored, Normal file)
4
basicsr/__init__.py
Normal file
@@ -0,0 +1,4 @@
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .data import *
from .utils import *
BIN basicsr/data/.DS_Store (vendored, Normal file)
101
basicsr/data/__init__.py
Normal file
@@ -0,0 +1,101 @@
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp

from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY

__all__ = ['build_dataset', 'build_dataloader']

# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]


def build_dataset(dataset_opt):
    """Build dataset from options.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name.
            type (str): Dataset type.
    """
    dataset_opt = deepcopy(dataset_opt)
    dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
    logger = get_root_logger()
    logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
    return dataset


def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train' or 'val'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()
    if phase == 'train':
        if dist:  # distributed training
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:  # non-distributed training
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True)
        if sampler is None:
            dataloader_args['shuffle'] = True
        dataloader_args['worker_init_fn'] = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
    elif phase in ['val', 'test']:  # validation
        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
    dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    else:
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)


def worker_init_fn(worker_id, num_workers, rank, seed):
    # Set the worker seed to num_workers * rank + worker_id + seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
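A hedged usage sketch for the two builders above. The dataset type and any extra option keys are illustrative assumptions; any class registered by a `*_dataset.py` module can be named in `type`.

```
from basicsr.data import build_dataset, build_dataloader

dataset_opt = {
    'name': 'demo',
    'type': 'RealESRGANDataset',   # assumed registered dataset class
    'phase': 'train',
    'batch_size_per_gpu': 4,
    'num_worker_per_gpu': 2,
    # ...plus whatever keys the chosen dataset class requires
}
dataset = build_dataset(dataset_opt)
loader = build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, seed=0)
```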
BIN basicsr/data/__pycache__/__init__.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/__init__.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/data_util.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/data_util.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/degradations.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/degradations.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/ffhq_dataset.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/paired_image_dataset.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/prefetch_dataloader.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/realesrgan_dataset.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/reds_dataset.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/reds_dataset.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/single_image_dataset.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/transforms.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/transforms.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/video_test_dataset.cpython-38.pyc (Normal file)
BIN basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc (Normal file)
BIN basicsr/data/__pycache__/vimeo90k_dataset.cpython-38.pyc (Normal file)
48
basicsr/data/data_sampler.py
Normal file
@@ -0,0 +1,48 @@
import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    Supports enlarging the dataset for iteration-based training, saving
    the time of restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
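A quick, self-contained check of the sampler's behaviour (the toy dataset is illustrative):

```
import torch
from torch.utils.data import TensorDataset
from basicsr.data.data_sampler import EnlargedSampler

dataset = TensorDataset(torch.arange(10))           # 10-sample toy dataset
sampler = EnlargedSampler(dataset, num_replicas=2, rank=0, ratio=3)
sampler.set_epoch(0)                                # deterministic reshuffle per epoch
print(len(sampler))                                 # 15 == ceil(10 * 3 / 2) indices per replica
print(all(0 <= i < 10 for i in sampler))            # True: indices wrap around the dataset
```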
315
basicsr/data/data_util.py
Normal file
@@ -0,0 +1,315 @@
import cv2
import numpy as np
import torch
from os import path as osp
from torch.nn import functional as F

from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir


def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname (bool): Whether to return image names. Default: False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Returned image name list.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)

    if return_imgname:
        imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
        return imgs, imgnames
    else:
        return imgs


def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate an index list for reading `num_frames` frames from a sequence
    of images.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding modes:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    max_frame_num = max_frame_num - 1  # start from 0
    num_pad = num_frames // 2

    indices = []
    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
        if i < 0:
            if padding == 'replicate':
                pad_idx = 0
            elif padding == 'reflection':
                pad_idx = -i
            elif padding == 'reflection_circle':
                pad_idx = crt_idx + num_pad - i
            else:
                pad_idx = num_frames + i
        elif i > max_frame_num:
            if padding == 'replicate':
                pad_idx = max_frame_num
            elif padding == 'reflection':
                pad_idx = max_frame_num * 2 - i
            elif padding == 'reflection_circle':
                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
            else:
                pad_idx = i - num_frames
        else:
            pad_idx = i
        indices.append(pad_idx)
    return indices
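For example, at the first and last frames of a 10-frame sequence, the padding behaves exactly as the docstring describes:

```
from basicsr.data.data_util import generate_frame_indices

print(generate_frame_indices(0, max_frame_num=10, num_frames=5, padding='reflection'))
# [2, 1, 0, 1, 2]
print(generate_frame_indices(9, max_frame_num=10, num_frames=5, padding='replicate'))
# [7, 8, 9, 9, 9]
```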
def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    ::

        lq.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records
    1) image name (with extension),
    2) image shape,
    3) compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
                         f'format. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')
    # ensure that the two meta_info files are the same
    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
        input_lmdb_keys = [line.split('.')[0] for line in fin]
    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
        gt_lmdb_keys = [line.split('.')[0] for line in fin]
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    else:
        paths = []
        for lmdb_key in sorted(input_lmdb_keys):
            paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
        return paths


def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image names and
    image shape (usually for gt), separated by a white space.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        gt_names = [line.strip().split(' ')[0] for line in fin]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        gt_path = osp.join(gt_folder, gt_name)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths


def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different numbers of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
        gt_path = osp.join(gt_folder, gt_path)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths
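For instance, with GT files named `0001.png` and LQ files named `0001x4.png` (folders and the template below are placeholders), the pairing resolves to:

```
from basicsr.data.data_util import paired_paths_from_folder

paths = paired_paths_from_folder(
    folders=['/data/lq', '/data/gt'], keys=['lq', 'gt'], filename_tmpl='{}x4')
# [{'lq_path': '/data/lq/0001x4.png', 'gt_path': '/data/gt/0001.png'}, ...]
```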
def paths_from_folder(folder):
    """Generate paths from folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """

    paths = list(scandir(folder))
    paths = [osp.join(folder, path) for path in paths]
    return paths


def paths_from_lmdb(folder):
    """Generate paths from lmdb.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths


def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel.
    """
    from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters was removed in recent SciPy
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return gaussian_filter(kernel, sigma)


def duf_downsample(x, kernel_size=13, scale=4):
    """Downsampling with Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scales: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = False
    if x.ndim == 4:
        squeeze_flag = True
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    x = x.view(-1, 1, h, w)
    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')

    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x
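A quick shape check of `duf_downsample` (sizes are illustrative): with `kernel_size=13` and `scale=4`, a 64-pixel side shrinks to 16.

```
import torch
from basicsr.data.data_util import duf_downsample

clip = torch.rand(1, 5, 3, 64, 64)                  # (b, t, c, h, w)
small = duf_downsample(clip, kernel_size=13, scale=4)
print(small.shape)                                  # torch.Size([1, 5, 3, 16, 16])
```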
765
basicsr/data/degradations.py
Normal file
@@ -0,0 +1,765 @@
import cv2
import math
import numpy as np
import random
import torch
from scipy import special
from scipy.stats import multivariate_normal
# from torchvision.transforms.functional_tensor import rgb_to_grayscale  # older torchvision
from torchvision.transforms.functional import rgb_to_grayscale

# -------------------------------------------------------------------- #
# --------------------------- blur kernels --------------------------- #
# -------------------------------------------------------------------- #


# --------------------------- util functions --------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).

    Args:
        sig_x (float): Standard deviation along the x axis.
        sig_y (float): Standard deviation along the y axis.
        theta (float): Rotation in radians.

    Returns:
        ndarray: Rotated sigma matrix.
    """
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))


def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.

    Args:
        kernel_size (int): Side length of the kernel.

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
                                                                           1))).reshape(kernel_size, kernel_size, 2)
    return xy, xx, yy


def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarray): un-normalized kernel.
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel


def cdf2(d_matrix, grid):
    """Calculate the CDF of the standard bivariate Gaussian distribution.
    Used in skewed Gaussian distribution.

    Args:
        d_matrix (ndarray): skew matrix.
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        cdf (ndarray): skewed cdf.
    """
    rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    grid = np.dot(grid, d_matrix)
    cdf = rv.cdf(grid)
    return cdf
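Putting the utilities together: an anisotropic Gaussian blur kernel is the bivariate pdf evaluated on the centered grid and normalized (the parameter values are illustrative).

```
import numpy as np
from basicsr.data.degradations import sigma_matrix2, mesh_grid, pdf2

sigma = sigma_matrix2(sig_x=2.0, sig_y=0.8, theta=np.pi / 4)   # rotated covariance
grid, _, _ = mesh_grid(21)
kernel = pdf2(sigma, grid)
kernel /= kernel.sum()                                         # normalize to a valid blur kernel
print(kernel.shape, np.isclose(kernel.sum(), 1.0))             # (21, 21) True
```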
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): Side length of the kernel.
        sig_x (float): Standard deviation along the x axis.
        sig_y (float): Standard deviation along the y axis.
        theta (float): Rotation in radians.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True.

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a bivariate generalized Gaussian kernel.

    ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): Side length of the kernel.
        sig_x (float): Standard deviation along the x axis.
        sig_y (float): Standard deviation along the y axis.
        theta (float): Rotation in radians.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a plateau-like anisotropic kernel.

    1 / (1 + x^(beta))

    Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): Side length of the kernel.
        sig_x (float): Standard deviation along the x axis.
        sig_y (float): Standard deviation along the y axis.
        theta (float): Rotation in radians.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_Gaussian(kernel_size,
                              sigma_x_range,
                              sigma_y_range,
                              rotation_range,
                              noise_range=None,
                              isotropic=True):
    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int): Side length of the kernel.
        sigma_x_range (tuple): e.g., [0.6, 5]
        sigma_y_range (tuple): e.g., [0.6, 5]
        rotation_range (tuple): e.g., [-math.pi, math.pi]
        noise_range (tuple, optional): multiplicative kernel noise,
            e.g., [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray): normalized kernel.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          isotropic=True):
    """Randomly generate bivariate generalized Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int): Side length of the kernel.
        sigma_x_range (tuple): e.g., [0.6, 5]
        sigma_y_range (tuple): e.g., [0.6, 5]
        rotation_range (tuple): e.g., [-math.pi, math.pi]
        beta_range (tuple): e.g., [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            e.g., [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray): normalized kernel.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # assume beta_range[0] < 1 < beta_range[1]
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_plateau(kernel_size,
                             sigma_x_range,
                             sigma_y_range,
                             rotation_range,
                             beta_range,
                             noise_range=None,
                             isotropic=True):
    """Randomly generate bivariate plateau kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int): Side length of the kernel.
        sigma_x_range (tuple): e.g., [0.6, 5]
        sigma_y_range (tuple): e.g., [0.6, 5]
        rotation_range (tuple): e.g., [-math.pi/2, math.pi/2]
        beta_range (tuple): e.g., [1, 4]
        noise_range (tuple, optional): multiplicative kernel noise,
            e.g., [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray): normalized kernel.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # TODO: this may be not proper
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)

    return kernel


def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         betag_range=(0.5, 8),
                         betap_range=(0.5, 8),
                         noise_range=None):
    """Randomly generate mixed kernels.

    Args:
        kernel_list (tuple): a list of kernel type names, supporting
            ['iso', 'aniso', 'generalized_iso', 'generalized_aniso',
            'plateau_iso', 'plateau_aniso']
        kernel_prob (tuple): corresponding kernel probability for each
            kernel type
        kernel_size (int): Side length of the kernel. Default: 21.
        sigma_x_range (tuple): e.g., [0.6, 5]
        sigma_y_range (tuple): e.g., [0.6, 5]
        rotation_range (tuple): e.g., [-math.pi, math.pi]
        betag_range (tuple): shape range for generalized Gaussian kernels, e.g., [0.5, 8]
        betap_range (tuple): shape range for plateau kernels, e.g., [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            e.g., [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray): normalized kernel.
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
    elif kernel_type == 'generalized_iso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=True)
    elif kernel_type == 'generalized_aniso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=False)
    elif kernel_type == 'plateau_iso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
    return kernel
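A hedged sampling sketch for `random_mixed_kernels`; the kernel list and probabilities below are illustrative, not this repo's training settings.

```
import math
from basicsr.data.degradations import random_mixed_kernels

kernel = random_mixed_kernels(
    kernel_list=['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    kernel_prob=[0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    kernel_size=21,
    sigma_x_range=(0.2, 3), sigma_y_range=(0.2, 3),
    rotation_range=(-math.pi, math.pi),
    betag_range=(0.5, 4), betap_range=(1, 2))
print(kernel.shape)   # (21, 21)
```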
# silence the 0/0 produced at the sinc kernel center (handled explicitly below)
np.seterr(divide='ignore', invalid='ignore')


def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
    """2D sinc filter.

    Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter

    Args:
        cutoff (float): Cutoff frequency in radians (pi is max).
        kernel_size (int): Horizontal and vertical size, must be odd.
        pad_to (int): Pad kernel size to desired size, must be odd or zero.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    kernel = np.fromfunction(
        lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
            (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
                (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
    kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
    kernel = kernel / np.sum(kernel)
    if pad_to > kernel_size:
        pad_size = (pad_to - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
    return kernel
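Example: a sinc kernel with a mid-band cutoff, zero-padded to the 21-pixel size commonly used for blur kernels (values are illustrative).

```
import numpy as np
from basicsr.data.degradations import circular_lowpass_kernel

sinc = circular_lowpass_kernel(cutoff=np.pi / 3, kernel_size=13, pad_to=21)
print(sinc.shape, np.isclose(sinc.sum(), 1.0))   # (21, 21) True
```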
# ------------------------------------------------------------- #
# --------------------------- noise --------------------------- #
# ------------------------------------------------------------- #

# ----------------------- Gaussian Noise ----------------------- #


def generate_gaussian_noise(img, sigma=10, gray_noise=False):
    """Generate Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.
        gray_noise (bool): Whether to generate gray noise. Default: False.

    Returns:
        (Numpy array): Returned noise map, shape (h, w, c), float32.
    """
    if gray_noise:
        noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
        noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
    else:
        noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
    return noise


def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
    """Add Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), range [0, 1],
            float32.
    """
    noise = generate_gaussian_noise(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
    """Generate Gaussian noise (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range [0, 1], float32.
        sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Returned noise map, shape (b, c, h, w), float32.
    """
    b, _, h, w = img.size()
    if not isinstance(sigma, (float, int)):
        sigma = sigma.view(img.size(0), 1, 1, 1)
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0

    if cal_gray_noise:
        noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
        noise_gray = noise_gray.view(b, 1, h, w)

    # always calculate color noise
    noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.

    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    return noise


def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
    """Add Gaussian noise (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range [0, 1], float32.
        sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range [0, 1],
            float32.
    """
    noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ----------------------- Random Gaussian Noise ----------------------- #
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    return generate_gaussian_noise(img, sigma, gray_noise)


def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
    sigma = torch.rand(
        img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_gaussian_noise_pt(img, sigma, gray_noise)


def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out
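Usage sketch for the batched random variant: each sample draws its own sigma from `sigma_range`, and roughly a `gray_prob` fraction of the batch receives grayscale noise (values are illustrative).

```
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt

imgs = torch.rand(4, 3, 32, 32)
noisy = random_add_gaussian_noise_pt(imgs, sigma_range=(1, 30), gray_prob=0.4)
print(noisy.shape, float(noisy.min()) >= 0.0, float(noisy.max()) <= 1.0)   # clipped to [0, 1]
```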
|
||||
|
||||
|
||||
# ----------------------- Poisson (Shot) Noise ----------------------- #
|
||||
|
||||
|
||||
def generate_poisson_noise(img, scale=1.0, gray_noise=False):
|
||||
"""Generate poisson noise.
|
||||
|
||||
Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
|
||||
|
||||
Args:
|
||||
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
||||
scale (float): Noise scale. Default: 1.0.
|
||||
gray_noise (bool): Whether generate gray noise. Default: False.
|
||||
|
||||
Returns:
|
||||
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
||||
float32.
|
||||
"""
|
||||
if gray_noise:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
# round and clip image for counting vals correctly
|
||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
||||
vals = len(np.unique(img))
|
||||
vals = 2**np.ceil(np.log2(vals))
|
||||
out = np.float32(np.random.poisson(img * vals) / float(vals))
|
||||
noise = out - img
|
||||
if gray_noise:
|
||||
noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
|
||||
return noise * scale
|
||||
|
||||
|
||||
def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
|
||||
"""Add poisson noise.
|
||||
|
||||
Args:
|
||||
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
||||
scale (float): Noise scale. Default: 1.0.
|
||||
gray_noise (bool): Whether generate gray noise. Default: False.
|
||||
|
||||
Returns:
|
||||
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
||||
float32.
|
||||
"""
|
||||
noise = generate_poisson_noise(img, scale, gray_noise)
|
||||
out = img + noise
|
||||
if clip and rounds:
|
||||
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
||||
elif clip:
|
||||
out = np.clip(out, 0, 1)
|
||||
elif rounds:
|
||||
out = (out * 255.0).round() / 255.
|
||||
return out
|
||||
|
||||
|
||||
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
|
||||
"""Generate a batch of poisson noise (PyTorch version)
|
||||
|
||||
Args:
|
||||
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
|
||||
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
|
||||
Default: 1.0.
|
||||
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
|
||||
0 for False, 1 for True. Default: 0.
|
||||
|
||||
Returns:
|
||||
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
||||
float32.
|
||||
"""
|
||||
b, _, h, w = img.size()
|
||||
if isinstance(gray_noise, (float, int)):
|
||||
cal_gray_noise = gray_noise > 0
|
||||
else:
|
||||
gray_noise = gray_noise.view(b, 1, 1, 1)
|
||||
cal_gray_noise = torch.sum(gray_noise) > 0
|
||||
if cal_gray_noise:
|
||||
img_gray = rgb_to_grayscale(img, num_output_channels=1)
|
||||
# round and clip image for counting vals correctly
|
||||
img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
|
||||
# use for-loop to get the unique values for each sample
|
||||
vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
|
||||
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
|
||||
vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
|
||||
out = torch.poisson(img_gray * vals) / vals
|
||||
noise_gray = out - img_gray
|
||||
noise_gray = noise_gray.expand(b, 3, h, w)
|
||||
|
||||
# always calculate color noise
|
||||
# round and clip image for counting vals correctly
|
||||
img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
|
||||
# use for-loop to get the unique values for each sample
|
||||
vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
|
||||
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
|
||||
vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
|
||||
out = torch.poisson(img * vals) / vals
|
||||
noise = out - img
|
||||
if cal_gray_noise:
|
||||
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
|
||||
if not isinstance(scale, (float, int)):
|
||||
scale = scale.view(b, 1, 1, 1)
|
||||
return noise * scale


def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
    """Add poisson noise to a batch of images (PyTorch version).

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range [0, 1],
            float32.
    """
    noise = generate_poisson_noise_pt(img, scale, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out
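
# --- Illustrative usage sketch (added; not part of the original file) ---
# Per-sample noise scales on a batch, following the docstring shapes above.
#   >>> imgs = torch.rand(4, 3, 64, 64)              # b=4 images in [0, 1]
#   >>> scales = torch.tensor([0.5, 1.0, 2.0, 3.0])  # one scale per sample
#   >>> noisy = add_poisson_noise_pt(imgs, scale=scales, clip=True, gray_noise=0)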


# ----------------------- Random Poisson (Shot) Noise ----------------------- #


def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
    scale = np.random.uniform(scale_range[0], scale_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    return generate_poisson_noise(img, scale, gray_noise)


def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
    scale = torch.rand(
        img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_poisson_noise_pt(img, scale, gray_noise)


def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ------------------------------------------------------------------------ #
# --------------------------- JPEG compression --------------------------- #
# ------------------------------------------------------------------------ #


def add_jpg_compression(img, quality=90):
    """Add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality (float): JPG compression quality. 0 for lowest quality, 100 for
            best quality. Default: 90.

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range [0, 1],
            float32.
    """
    img = np.clip(img, 0, 1)
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
    _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
    img = np.float32(cv2.imdecode(encimg, 1)) / 255.
    return img


def random_add_jpg_compression(img, quality_range=(90, 100)):
    """Randomly add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality_range (tuple[float] | list[float]): JPG compression quality
            range. 0 for lowest quality, 100 for best quality.
            Default: (90, 100).

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range [0, 1],
            float32.
    """
    quality = np.random.uniform(quality_range[0], quality_range[1])
    return add_jpg_compression(img, quality)
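
# --- Illustrative usage sketch (added; not part of the original file) ---
# A lightweight degradation chain on a float32 image in [0, 1]:
#   >>> noisy = random_add_poisson_noise(img, scale_range=(0.05, 2.0), gray_prob=0.4)
#   >>> lq = random_add_jpg_compression(noisy, quality_range=(30, 95))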
80
basicsr/data/ffhq_dataset.py
Normal file
@@ -0,0 +1,80 @@
import random
import time
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # FFHQ has 70000 images in total
            self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path)
            except Exception as e:
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read; randint is inclusive on both ends
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
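
# --- Illustrative usage sketch (added; not part of the original file) ---
# The option keys mirror the class docstring; the dataroot is hypothetical.
#   >>> opt = dict(
#   ...     dataroot_gt='datasets/ffhq', io_backend=dict(type='disk'),
#   ...     mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], use_hflip=True)
#   >>> dataset = FFHQDataset(opt)
#   >>> sample = dataset[0]  # {'gt': CHW float tensor, 'gt_path': str}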
32592
basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
Normal file
4
basicsr/data/meta_info/meta_info_REDS4_test_GT.txt
Normal file
@@ -0,0 +1,4 @@
000 100 (720,1280,3)
011 100 (720,1280,3)
015 100 (720,1280,3)
020 100 (720,1280,3)
270
basicsr/data/meta_info/meta_info_REDS_GT.txt
Normal file
@@ -0,0 +1,270 @@
000 100 (720,1280,3)
001 100 (720,1280,3)
002 100 (720,1280,3)
003 100 (720,1280,3)
004 100 (720,1280,3)
005 100 (720,1280,3)
006 100 (720,1280,3)
007 100 (720,1280,3)
008 100 (720,1280,3)
009 100 (720,1280,3)
010 100 (720,1280,3)
011 100 (720,1280,3)
012 100 (720,1280,3)
013 100 (720,1280,3)
014 100 (720,1280,3)
015 100 (720,1280,3)
016 100 (720,1280,3)
017 100 (720,1280,3)
018 100 (720,1280,3)
019 100 (720,1280,3)
020 100 (720,1280,3)
021 100 (720,1280,3)
022 100 (720,1280,3)
023 100 (720,1280,3)
024 100 (720,1280,3)
025 100 (720,1280,3)
026 100 (720,1280,3)
027 100 (720,1280,3)
028 100 (720,1280,3)
029 100 (720,1280,3)
030 100 (720,1280,3)
031 100 (720,1280,3)
032 100 (720,1280,3)
033 100 (720,1280,3)
034 100 (720,1280,3)
035 100 (720,1280,3)
036 100 (720,1280,3)
037 100 (720,1280,3)
038 100 (720,1280,3)
039 100 (720,1280,3)
040 100 (720,1280,3)
041 100 (720,1280,3)
042 100 (720,1280,3)
043 100 (720,1280,3)
044 100 (720,1280,3)
045 100 (720,1280,3)
046 100 (720,1280,3)
047 100 (720,1280,3)
048 100 (720,1280,3)
049 100 (720,1280,3)
050 100 (720,1280,3)
051 100 (720,1280,3)
052 100 (720,1280,3)
053 100 (720,1280,3)
054 100 (720,1280,3)
055 100 (720,1280,3)
056 100 (720,1280,3)
057 100 (720,1280,3)
058 100 (720,1280,3)
059 100 (720,1280,3)
060 100 (720,1280,3)
061 100 (720,1280,3)
062 100 (720,1280,3)
063 100 (720,1280,3)
064 100 (720,1280,3)
065 100 (720,1280,3)
066 100 (720,1280,3)
067 100 (720,1280,3)
068 100 (720,1280,3)
069 100 (720,1280,3)
070 100 (720,1280,3)
071 100 (720,1280,3)
072 100 (720,1280,3)
073 100 (720,1280,3)
074 100 (720,1280,3)
075 100 (720,1280,3)
076 100 (720,1280,3)
077 100 (720,1280,3)
078 100 (720,1280,3)
079 100 (720,1280,3)
080 100 (720,1280,3)
081 100 (720,1280,3)
082 100 (720,1280,3)
083 100 (720,1280,3)
084 100 (720,1280,3)
085 100 (720,1280,3)
086 100 (720,1280,3)
087 100 (720,1280,3)
088 100 (720,1280,3)
089 100 (720,1280,3)
090 100 (720,1280,3)
091 100 (720,1280,3)
092 100 (720,1280,3)
093 100 (720,1280,3)
094 100 (720,1280,3)
095 100 (720,1280,3)
096 100 (720,1280,3)
097 100 (720,1280,3)
098 100 (720,1280,3)
099 100 (720,1280,3)
100 100 (720,1280,3)
101 100 (720,1280,3)
102 100 (720,1280,3)
103 100 (720,1280,3)
104 100 (720,1280,3)
105 100 (720,1280,3)
106 100 (720,1280,3)
107 100 (720,1280,3)
108 100 (720,1280,3)
109 100 (720,1280,3)
110 100 (720,1280,3)
111 100 (720,1280,3)
112 100 (720,1280,3)
113 100 (720,1280,3)
114 100 (720,1280,3)
115 100 (720,1280,3)
116 100 (720,1280,3)
117 100 (720,1280,3)
118 100 (720,1280,3)
119 100 (720,1280,3)
120 100 (720,1280,3)
121 100 (720,1280,3)
122 100 (720,1280,3)
123 100 (720,1280,3)
124 100 (720,1280,3)
125 100 (720,1280,3)
126 100 (720,1280,3)
127 100 (720,1280,3)
128 100 (720,1280,3)
129 100 (720,1280,3)
130 100 (720,1280,3)
131 100 (720,1280,3)
132 100 (720,1280,3)
133 100 (720,1280,3)
134 100 (720,1280,3)
135 100 (720,1280,3)
136 100 (720,1280,3)
137 100 (720,1280,3)
138 100 (720,1280,3)
139 100 (720,1280,3)
140 100 (720,1280,3)
141 100 (720,1280,3)
142 100 (720,1280,3)
143 100 (720,1280,3)
144 100 (720,1280,3)
145 100 (720,1280,3)
146 100 (720,1280,3)
147 100 (720,1280,3)
148 100 (720,1280,3)
149 100 (720,1280,3)
150 100 (720,1280,3)
151 100 (720,1280,3)
152 100 (720,1280,3)
153 100 (720,1280,3)
154 100 (720,1280,3)
155 100 (720,1280,3)
156 100 (720,1280,3)
157 100 (720,1280,3)
158 100 (720,1280,3)
159 100 (720,1280,3)
160 100 (720,1280,3)
161 100 (720,1280,3)
162 100 (720,1280,3)
163 100 (720,1280,3)
164 100 (720,1280,3)
165 100 (720,1280,3)
166 100 (720,1280,3)
167 100 (720,1280,3)
168 100 (720,1280,3)
169 100 (720,1280,3)
170 100 (720,1280,3)
171 100 (720,1280,3)
172 100 (720,1280,3)
173 100 (720,1280,3)
174 100 (720,1280,3)
175 100 (720,1280,3)
176 100 (720,1280,3)
177 100 (720,1280,3)
178 100 (720,1280,3)
179 100 (720,1280,3)
180 100 (720,1280,3)
181 100 (720,1280,3)
182 100 (720,1280,3)
183 100 (720,1280,3)
184 100 (720,1280,3)
185 100 (720,1280,3)
186 100 (720,1280,3)
187 100 (720,1280,3)
188 100 (720,1280,3)
189 100 (720,1280,3)
190 100 (720,1280,3)
191 100 (720,1280,3)
192 100 (720,1280,3)
193 100 (720,1280,3)
194 100 (720,1280,3)
195 100 (720,1280,3)
196 100 (720,1280,3)
197 100 (720,1280,3)
198 100 (720,1280,3)
199 100 (720,1280,3)
200 100 (720,1280,3)
201 100 (720,1280,3)
202 100 (720,1280,3)
203 100 (720,1280,3)
204 100 (720,1280,3)
205 100 (720,1280,3)
206 100 (720,1280,3)
207 100 (720,1280,3)
208 100 (720,1280,3)
209 100 (720,1280,3)
210 100 (720,1280,3)
211 100 (720,1280,3)
212 100 (720,1280,3)
213 100 (720,1280,3)
214 100 (720,1280,3)
215 100 (720,1280,3)
216 100 (720,1280,3)
217 100 (720,1280,3)
218 100 (720,1280,3)
219 100 (720,1280,3)
220 100 (720,1280,3)
221 100 (720,1280,3)
222 100 (720,1280,3)
223 100 (720,1280,3)
224 100 (720,1280,3)
225 100 (720,1280,3)
226 100 (720,1280,3)
227 100 (720,1280,3)
228 100 (720,1280,3)
229 100 (720,1280,3)
230 100 (720,1280,3)
231 100 (720,1280,3)
232 100 (720,1280,3)
233 100 (720,1280,3)
234 100 (720,1280,3)
235 100 (720,1280,3)
236 100 (720,1280,3)
237 100 (720,1280,3)
238 100 (720,1280,3)
239 100 (720,1280,3)
240 100 (720,1280,3)
241 100 (720,1280,3)
242 100 (720,1280,3)
243 100 (720,1280,3)
244 100 (720,1280,3)
245 100 (720,1280,3)
246 100 (720,1280,3)
247 100 (720,1280,3)
248 100 (720,1280,3)
249 100 (720,1280,3)
250 100 (720,1280,3)
251 100 (720,1280,3)
252 100 (720,1280,3)
253 100 (720,1280,3)
254 100 (720,1280,3)
255 100 (720,1280,3)
256 100 (720,1280,3)
257 100 (720,1280,3)
258 100 (720,1280,3)
259 100 (720,1280,3)
260 100 (720,1280,3)
261 100 (720,1280,3)
262 100 (720,1280,3)
263 100 (720,1280,3)
264 100 (720,1280,3)
265 100 (720,1280,3)
266 100 (720,1280,3)
267 100 (720,1280,3)
268 100 (720,1280,3)
269 100 (720,1280,3)
@@ -0,0 +1,4 @@
240 100 (720,1280,3)
241 100 (720,1280,3)
246 100 (720,1280,3)
257 100 (720,1280,3)
@@ -0,0 +1,30 @@
240 100 (720,1280,3)
241 100 (720,1280,3)
242 100 (720,1280,3)
243 100 (720,1280,3)
244 100 (720,1280,3)
245 100 (720,1280,3)
246 100 (720,1280,3)
247 100 (720,1280,3)
248 100 (720,1280,3)
249 100 (720,1280,3)
250 100 (720,1280,3)
251 100 (720,1280,3)
252 100 (720,1280,3)
253 100 (720,1280,3)
254 100 (720,1280,3)
255 100 (720,1280,3)
256 100 (720,1280,3)
257 100 (720,1280,3)
258 100 (720,1280,3)
259 100 (720,1280,3)
260 100 (720,1280,3)
261 100 (720,1280,3)
262 100 (720,1280,3)
263 100 (720,1280,3)
264 100 (720,1280,3)
265 100 (720,1280,3)
266 100 (720,1280,3)
267 100 (720,1280,3)
268 100 (720,1280,3)
269 100 (720,1280,3)
7824
basicsr/data/meta_info/meta_info_Vimeo90K_test_GT.txt
Normal file
1225
basicsr/data/meta_info/meta_info_Vimeo90K_test_fast_GT.txt
Normal file
4977
basicsr/data/meta_info/meta_info_Vimeo90K_test_medium_GT.txt
Normal file
1613
basicsr/data/meta_info/meta_info_Vimeo90K_test_slow_GT.txt
Normal file
64612
basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
Normal file
106
basicsr/data/paired_image_dataset.py
Normal file
@@ -0,0 +1,106 @@
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:

    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
    2. **meta_info_file**: Use meta information file to generate paths. \
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. **folder**: Scan folders to generate paths. The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        if 'filename_tmpl' in opt:
            self.filename_tmpl = opt['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
            self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
                                                          self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # color space transform
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
            img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]

        # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
        # TODO: It is better to update the datasets, rather than force to crop
        if self.opt['phase'] != 'train':
            img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
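
# --- Illustrative usage sketch (added; not part of the original file) ---
# Folder mode (mode 3 in the docstring); all paths are hypothetical.
#   >>> opt = dict(
#   ...     dataroot_gt='datasets/DIV2K/GT', dataroot_lq='datasets/DIV2K/LQ_x4',
#   ...     io_backend=dict(type='disk'), scale=4, phase='train',
#   ...     gt_size=256, use_hflip=True, use_rot=True)
#   >>> dataset = PairedImageDataset(opt)
#   >>> sample = dataset[0]  # {'lq': ..., 'gt': ..., 'lq_path': ..., 'gt_path': ...}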
122
basicsr/data/prefetch_dataloader.py
Normal file
@@ -0,0 +1,122 @@
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
        Need to test on single gpu and ddp (multi-gpu). There is a known issue in
        ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
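
# --- Illustrative usage sketch (added; not part of the original file) ---
# Typical training loop with the CUDA prefetcher; `loader` is any DataLoader
# and `opt` must carry 'num_gpu'.
#   >>> prefetcher = CUDAPrefetcher(loader, opt)
#   >>> batch = prefetcher.next()
#   >>> while batch is not None:
#   ...     ...  # forward/backward on batch['lq'], batch['gt']
#   ...     batch = prefetcher.next()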
384
basicsr/data/realesrgan_dataset.py
Normal file
@@ -0,0 +1,384 @@
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from pathlib import Path

import albumentations

import torch.nn.functional as F
from torch.utils import data as data

from basicsr.utils import DiffJPEG
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from basicsr.utils.img_process_util import filter2D
from basicsr.data.transforms import paired_random_crop, random_crop
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt

from utils import util_image


def readline_txt(txt_file):
    txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file
    out = []
    for txt_file_current in txt_file:
        with open(txt_file_current, 'r') as ff:
            out.extend([x[:-1] for x in ff.readlines()])

    return out


@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANDataset(data.Dataset):
    """Dataset used for Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images, and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUs for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
        Please see more options in the codes.
    """

    def __init__(self, opt, mode='training'):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        # file client (lmdb io backend)
        self.image_paths = []
        self.text_paths = []
        self.moment_paths = []
        if opt.get('data_source', None) is not None:
            for ii in range(len(opt['data_source'])):
                configs = opt['data_source'].get(f'source{ii+1}')
                root_path = Path(configs.root_path)
                im_folder = root_path / configs.image_path
                im_ext = configs.im_ext
                image_stems = sorted([x.stem for x in im_folder.glob(f"*.{im_ext}")])
                if configs.get('length', None) is not None:
                    assert configs.length < len(image_stems)
                    image_stems = image_stems[:configs.length]

                if configs.get("text_path", None) is not None:
                    text_folder = root_path / configs.text_path
                    text_stems = [x.stem for x in text_folder.glob("*.txt")]
                    image_stems = sorted(list(set(image_stems).intersection(set(text_stems))))
                    self.text_paths.extend([str(text_folder / f"{x}.txt") for x in image_stems])
                else:
                    self.text_paths.extend([None, ] * len(image_stems))

                self.image_paths.extend([str(im_folder / f"{x}.{im_ext}") for x in image_stems])

                if configs.get("moment_path", None) is not None:
                    moment_folder = root_path / configs.moment_path
                    self.moment_paths.extend([str(moment_folder / f"{x}.npy") for x in image_stems])
                else:
                    self.moment_paths.extend([None, ] * len(image_stems))

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range1 = [x for x in range(3, opt['blur_kernel_size'], 2)]  # odd kernel sizes in [3, blur_kernel_size)
        self.kernel_range2 = [x for x in range(3, opt['blur_kernel_size2'], 2)]  # odd kernel sizes in [3, blur_kernel_size2)
        # TODO: kernel range is now hard-coded, should be in the configure file
        # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor = torch.zeros(opt['blur_kernel_size2'], opt['blur_kernel_size2']).float()
        self.pulse_tensor[opt['blur_kernel_size2']//2, opt['blur_kernel_size2']//2] = 1

        self.mode = mode

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.image_paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
                img_gt = imfrombytes(img_bytes, float32=True)
            except Exception:
                # change another file to read; randint is inclusive on both ends
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.image_paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break  # successfully loaded; mirrors the retry loop in ffhq_dataset.py
            finally:
                retry -= 1
        if self.mode == 'testing':
            if not hasattr(self, 'test_aug'):
                self.test_aug = albumentations.Compose([
                    albumentations.SmallestMaxSize(
                        max_size=self.opt['gt_size'],
                        interpolation=cv2.INTER_AREA,
                    ),
                    albumentations.CenterCrop(self.opt['gt_size'], self.opt['gt_size']),
                ])
            img_gt = self.test_aug(image=img_gt)['image']
        elif self.mode == 'training':
            # -------------------- Do augmentation for training: flip, rotation -------------------- #
            if self.opt['use_hflip'] or self.opt['use_rot']:
                img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

            h, w = img_gt.shape[0:2]
            gt_size = self.opt['gt_size']

            # resize or pad
            if not self.opt['random_crop']:
                if not min(h, w) == gt_size:
                    if not hasattr(self, 'smallest_resizer'):
                        self.smallest_resizer = util_image.SmallestMaxSize(
                            max_size=gt_size, pass_resize=False,
                        )
                    img_gt = self.smallest_resizer(img_gt)

                # center crop
                if not hasattr(self, 'center_cropper'):
                    self.center_cropper = albumentations.CenterCrop(gt_size, gt_size)
                img_gt = self.center_cropper(image=img_gt)['image']
            else:
                img_gt = random_crop(img_gt, self.opt['gt_size'])
        else:
            raise ValueError(f'Unexpected value {self.mode} for mode parameter')

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range1)
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (self.blur_kernel_size - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range2)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (self.blur_kernel_size2 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range2)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=self.blur_kernel_size2)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        if self.text_paths[index] is None or self.opt['random_crop']:
            prompt = ""
        else:
            with open(self.text_paths[index], 'r') as ff:
                prompt = ff.read()
            if self.opt.max_token_length is not None:
                prompt = prompt[:self.opt.max_token_length]

        return_d = {
            'gt': img_gt,
            'gt_path': gt_path,
            'txt': prompt,
            'kernel1': kernel,
            'kernel2': kernel2,
            'sinc_kernel': sinc_kernel,
        }
        if self.moment_paths[index] is not None and (not self.opt['random_crop']):
            return_d['gt_moment'] = np.load(self.moment_paths[index])

        return return_d

    def __len__(self):
        return len(self.image_paths)

    def degrade_fun(self, conf_degradation, im_gt, kernel1, kernel2, sinc_kernel):
        if not hasattr(self, 'jpeger'):
            self.jpeger = DiffJPEG(differentiable=False)  # simulate JPEG compression artifacts

        ori_h, ori_w = im_gt.size()[2:4]
        sf = conf_degradation.sf

        # ----------------------- The first degradation process ----------------------- #
        # blur
        out = filter2D(im_gt, kernel1)
        # random resize
        updown_type = random.choices(
            ['up', 'down', 'keep'],
            conf_degradation['resize_prob'],
        )[0]
        if updown_type == 'up':
            scale = random.uniform(1, conf_degradation['resize_range'][1])
        elif updown_type == 'down':
            scale = random.uniform(conf_degradation['resize_range'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        # add noise
        gray_noise_prob = conf_degradation['gray_noise_prob']
        if random.random() < conf_degradation['gaussian_noise_prob']:
            out = random_add_gaussian_noise_pt(
                out,
                sigma_range=conf_degradation['noise_range'],
                clip=True,
                rounds=False,
                gray_prob=gray_noise_prob,
            )
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=conf_degradation['poisson_scale_range'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)
        # JPEG compression
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range'])
        out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
        out = self.jpeger(out, quality=jpeg_p)

        # ----------------------- The second degradation process ----------------------- #
        # blur
        if random.random() < conf_degradation['second_order_prob']:
            if random.random() < conf_degradation['second_blur_prob']:
                out = filter2D(out, kernel2)
            # random resize
            updown_type = random.choices(
                ['up', 'down', 'keep'],
                conf_degradation['resize_prob2'],
            )[0]
            if updown_type == 'up':
                scale = random.uniform(1, conf_degradation['resize_range2'][1])
            elif updown_type == 'down':
                scale = random.uniform(conf_degradation['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out,
                size=(int(ori_h / sf * scale), int(ori_w / sf * scale)),
                mode=mode,
            )
            # add noise
            gray_noise_prob = conf_degradation['gray_noise_prob2']
            if random.random() < conf_degradation['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out,
                    sigma_range=conf_degradation['noise_range2'],
                    clip=True,
                    rounds=False,
                    gray_prob=gray_noise_prob,
                )
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=conf_degradation['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False,
                )

        # JPEG compression + the final sinc filter
        # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
        # as one operation.
        # We consider two orders:
        # 1. [resize back + sinc filter] + JPEG compression
        # 2. JPEG compression + [resize back + sinc filter]
        # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
        if random.random() < 0.5:
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out,
                size=(ori_h // sf, ori_w // sf),
                mode=mode,
            )
            out = filter2D(out, sinc_kernel)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
        else:
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*conf_degradation['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out,
                size=(ori_h // sf, ori_w // sf),
                mode=mode,
            )
            out = filter2D(out, sinc_kernel)

        # clamp and round
        im_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

        return {'lq': im_lq.contiguous(), 'gt': im_gt}
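
# --- Illustrative usage sketch (added; not part of the original file) ---
# degrade_fun turns a dataset sample into an LQ/GT training pair; `conf` is a
# hypothetical OmegaConf-style object carrying the keys read above
# (sf, resize_prob, noise_range, jpeg_range, second_order_prob, ...).
#   >>> sample = dataset[0]
#   >>> pair = dataset.degrade_fun(
#   ...     conf, im_gt=sample['gt'].unsqueeze(0),  # (1, c, h, w)
#   ...     kernel1=sample['kernel1'], kernel2=sample['kernel2'],
#   ...     sinc_kernel=sample['sinc_kernel'])
#   >>> im_lq, im_gt = pair['lq'], pair['gt']  # lq is sf-times smaller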
106
basicsr/data/realesrgan_paired_dataset.py
Normal file
@@ -0,0 +1,106 @@
import os
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:

    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
    2. **meta_info**: Use meta information file to generate paths. \
        If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
    3. **folder**: Scan folders to generate paths. The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # mean and std for normalizing the input images
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip() for line in fin]
            self.paths = []
            for path in paths:
                gt_path, lq_path = path.split(', ')
                gt_path = os.path.join(self.gt_folder, gt_path)
                lq_path = os.path.join(self.lq_folder, lq_path)
                self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
        else:
            # disk backend
            # it will scan the whole folder to get meta info
            # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
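
# --- Illustrative meta_info format (added; file names are hypothetical) ---
# Each line holds a GT path and an LQ path, relative to the two dataroots and
# separated by ', ' (matching path.split(', ') above), e.g.:
#     0001.png, 0001.png
#     0002.png, 0002.png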
352
basicsr/data/reds_dataset.py
Normal file
@@ -0,0 +1,352 @@
import numpy as np
import random
import torch
from pathlib import Path
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.flow_util import dequantize_flow
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class REDSDataset(data.Dataset):
    """REDS dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
        assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
                             f"Supported ones are ['official', 'REDS4'].")
        self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000
        center_frame_idx = int(frame_name)

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = center_frame_idx - self.num_half_frames * interval
        end_frame_idx = center_frame_idx + self.num_half_frames * interval
        # each clip has 100 frames starting from 0 to 99
        while (start_frame_idx < 0) or (end_frame_idx > 99):
            center_frame_idx = random.randint(0, 99)
            start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
            end_frame_idx = center_frame_idx + self.num_half_frames * interval
        frame_name = f'{center_frame_idx:08d}'
        neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the GT frame (as the center frame)
        if self.is_lmdb:
            img_gt_path = f'{clip_name}/{frame_name}'
        else:
            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # get flows
        if self.flow_root is not None:
            img_flows = []
            # read previous flows
            for i in range(self.num_half_frames, 0, -1):
                if self.is_lmdb:
                    flow_path = f'{clip_name}/{frame_name}_p{i}'
                else:
                    flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
                img_bytes = self.file_client.get(flow_path, 'flow')
                cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
                dx, dy = np.split(cat_flow, 2, axis=0)
                flow = dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.
                img_flows.append(flow)
            # read next flows
            for i in range(1, self.num_half_frames + 1):
                if self.is_lmdb:
                    flow_path = f'{clip_name}/{frame_name}_n{i}'
                else:
                    flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
                img_bytes = self.file_client.get(flow_path, 'flow')
                cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
                dx, dy = np.split(cat_flow, 2, axis=0)
                flow = dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.
                img_flows.append(flow)

            # for random crop, here, img_flows and img_lqs have the same
            # spatial size
            img_lqs.extend(img_flows)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
        if self.flow_root is not None:
            img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        if self.flow_root is not None:
            img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
        else:
            img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        if self.flow_root is not None:
            img_flows = img2tensor(img_flows)
            # add the zero center flow
            img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
            img_flows = torch.stack(img_flows, dim=0)

        # img_lqs: (t, c, h, w)
        # img_flows: (t, 2, h, w)
        # img_gt: (c, h, w)
        # key: str
        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        else:
            return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
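
# --- Worked example of the temporal sampling above (added for illustration) ---
# With num_frame=5 (num_half_frames=2), interval=2 and center_frame_idx=50:
#     start = 50 - 2*2 = 46, end = 50 + 2*2 = 54
#     neighbor_list = [46, 48, 50, 52, 54]  # LQ window; the GT is frame 50
# Centers whose window would leave frames 0..99 are resampled by the
# while-loop before the list is built.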


@DATASET_REGISTRY.register()
class REDSRecurrentDataset(data.Dataset):
    """REDS dataset for training recurrent networks.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSRecurrentDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.num_frame = opt['num_frame']

        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
                             f"Supported ones are ['official', 'REDS4'].")
        if opt['test_mode']:
            self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
        else:
            self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if hasattr(self, 'flow_root') and self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt.get('interval_list', [1])
        self.random_reverse = opt.get('random_reverse', False)
        interval_str = ','.join(str(x) for x in self.interval_list)
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = int(frame_name)
        if start_frame_idx > 100 - self.num_frame * interval:
            start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
        end_frame_idx = start_frame_idx + self.num_frame * interval

        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
                img_gt_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
                img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'

            # get LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

            # get GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)
            img_gts.append(img_gt)

        # randomly crop
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
        img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
|
||||
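A minimal usage sketch (an editorial addition, not part of the commit): wiring REDSRecurrentDataset to a loader. The dataset paths are placeholders; the remaining keys mirror the docstring above.

from torch.utils.data import DataLoader

opt = {
    'dataroot_gt': 'datasets/REDS/train_sharp',              # placeholder path
    'dataroot_lq': 'datasets/REDS/train_sharp_bicubic/X4',   # placeholder path
    'meta_info_file': 'basicsr/data/meta_info/meta_info_REDS_GT.txt',
    'val_partition': 'REDS4', 'test_mode': False,
    'io_backend': {'type': 'disk'},
    'num_frame': 15, 'gt_size': 256, 'scale': 4,
    'interval_list': [1], 'random_reverse': False,
    'use_hflip': True, 'use_rot': True,
}
dataset = REDSRecurrentDataset(opt)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))  # batch['lq']: (b, t, c, 64, 64); batch['gt']: (b, t, c, 256, 256)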
68
basicsr/data/single_image_dataset.py
Normal file
@@ -0,0 +1,68 @@
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc.).

    There are two modes:
    1. 'meta_info_file': Use meta information file to generate paths.
    2. 'folder': Scan folders to generate paths.

    Args:
        opt (dict): Config for the dataset. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # color space transform
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
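A short usage sketch (editorial addition): in 'folder' mode the dataset needs only an LQ root and an io_backend; mean/std and color are optional.

opt = {
    'dataroot_lq': 'testdata/LQ',      # placeholder folder of LQ images
    'io_backend': {'type': 'disk'},
}
dataset = SingleImageDataset(opt)
item = dataset[0]
# item['lq']: RGB float tensor (c, h, w); item['lq_path']: path of the source image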
207
basicsr/data/transforms.py
Normal file
@@ -0,0 +1,207 @@
import cv2
import random
import torch


def mod_crop(img, scale):
    """Mod crop images, used during testing.

    Args:
        img (ndarray): Input image.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.
    """
    img = img.copy()
    if img.ndim in (2, 3):
        h, w = img.shape[0], img.shape[1]
        h_remainder, w_remainder = h % scale, w % scale
        img = img[:h - h_remainder, :w - w_remainder, ...]
    else:
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    return img


def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Paired random crop. Support Numpy array and Tensor inputs.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth. Default: None.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
        only have one element, just return ndarray.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # determine input type: Numpy array or Tensor
    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

    if input_type == 'Tensor':
        h_lq, w_lq = img_lqs[0].size()[-2:]
        h_gt, w_gt = img_gts[0].size()[-2:]
    else:
        h_lq, w_lq = img_lqs[0].shape[0:2]
        h_gt, w_gt = img_gts[0].shape[0:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    if input_type == 'Tensor':
        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
    else:
        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    if input_type == 'Tensor':
        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
    else:
        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
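# Editorial sketch (not part of the original file): paired_random_crop keeps the
# two patches aligned -- the LQ corner (top, left) maps to (top * scale, left * scale)
# in the GT. The shapes below are illustrative.
import numpy as np

_gt = np.random.rand(128, 128, 3).astype(np.float32)  # GT frame
_lq = np.random.rand(32, 32, 3).astype(np.float32)    # LQ frame at scale 4
_gt_patch, _lq_patch = paired_random_crop(_gt, _lq, gt_patch_size=64, scale=4)
assert _gt_patch.shape[:2] == (64, 64) and _lq_patch.shape[:2] == (16, 16)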
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
        results only have one element, just return ndarray.
    """
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs
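# Editorial sketch: with return_status=True, augment reports which of its three
# coin flips fired, so the same transform can be replayed on auxiliary data.
import numpy as np

_img = np.random.rand(8, 12, 3).astype(np.float32)
_out, (_hf, _vf, _rot90) = augment(_img, hflip=True, rotation=True, return_status=True)
assert _out.shape == ((12, 8, 3) if _rot90 else (8, 12, 3))  # rot90 swaps h and w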
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.
    """
    (h, w) = img.shape[:2]

    if center is None:
        center = (w // 2, h // 2)

    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    rotated_img = cv2.warpAffine(img, matrix, (w, h))
    return rotated_img


def random_crop(im, pch_size):
    """Randomly crop a patch from the given image."""
    h, w = im.shape[:2]
    # padding if necessary
    if h < pch_size or w < pch_size:
        pad_h = min(max(0, pch_size - h), h)
        pad_w = min(max(0, pch_size - w), w)
        im = cv2.copyMakeBorder(im, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)

    h, w = im.shape[:2]
    if h == pch_size:
        ind_h = 0
    elif h > pch_size:
        ind_h = random.randint(0, h - pch_size)
    else:
        raise ValueError('Image height is smaller than the patch size')
    if w == pch_size:
        ind_w = 0
    elif w > pch_size:
        ind_w = random.randint(0, w - pch_size)
    else:
        raise ValueError('Image width is smaller than the patch size')

    im_pch = im[ind_h:ind_h + pch_size, ind_w:ind_w + pch_size]

    return im_pch
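A quick check of the padding branch in random_crop (editorial sketch): an image smaller than the patch is reflect-padded on the bottom/right before the crop indices are drawn.

import numpy as np

im = np.random.rand(100, 80, 3).astype(np.float32)
patch = random_crop(im, 128)  # the 100x80 input is reflect-padded up to 128x128
assert patch.shape[:2] == (128, 128)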
283
basicsr/data/video_test_dataset.py
Normal file
@@ -0,0 +1,283 @@
import glob
import torch
from os import path as osp
from torch.utils import data as data

from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
    """Video test dataset.

    Supported datasets: Vid4, REDS4, REDSofficial.
    More generally, it supports any testing dataset with the following structure:

    ::

        dataroot
        ├── subfolder1
            ├── frame000
            ├── frame001
            ├── ...
        ├── subfolder2
            ├── frame000
            ├── frame001
            ├── ...
        ├── ...

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for the test dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test folders. If not provided,
                all the folders in the dataroot will be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        self.imgs_lq, self.imgs_gt = {}, {}
        if 'meta_info_file' in opt:
            with open(opt['meta_info_file'], 'r') as fin:
                subfolders = [line.split(' ')[0] for line in fin]
                subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
                subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
            for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
                # get frame list for lq and gt
                subfolder_name = osp.basename(subfolder_lq)
                img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
                img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))

                max_idx = len(img_paths_lq)
                assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
                                                      f' and gt folders ({len(img_paths_gt)})')

                self.data_info['lq_path'].extend(img_paths_lq)
                self.data_info['gt_path'].extend(img_paths_gt)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append(f'{i}/{max_idx}')
                border_l = [0] * max_idx
                for i in range(self.opt['num_frame'] // 2):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)

                # cache data or save the frame list
                if self.cache_data:
                    logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                    self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                    self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
                else:
                    self.imgs_lq[subfolder_name] = img_paths_lq
                    self.imgs_gt[subfolder_name] = img_paths_gt
        else:
            raise ValueError(f'Non-supported video test dataset: {opt["name"]}')

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])
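# Editorial sketch: a minimal test-time config for VideoTestDataset. The paths
# are placeholders, and the listed padding modes ('replicate', 'reflection',
# 'reflection_circle', 'circle') are assumed from the upstream BasicSR helper
# generate_frame_indices.
_opt = {
    'name': 'REDS4', 'cache_data': False,
    'dataroot_gt': 'datasets/REDS4/GT',             # placeholder path
    'dataroot_lq': 'datasets/REDS4/sharp_bicubic',  # placeholder path
    'io_backend': {'type': 'disk'},
    'num_frame': 5, 'padding': 'reflection_circle',
}
# VideoTestDataset(_opt)[0] -> {'lq': (t, c, h, w), 'gt': (c, h, w), 'folder', 'idx', 'border', 'lq_path'}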
@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
    """Video test dataset for Vimeo90k-Test dataset.

    It only keeps the center frame for testing.
    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for the test dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test folders. If not provided,
                all the folders in the dataroot will be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestVimeo90KDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        if self.cache_data:
            raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestVimeo90KDataset - {opt["name"]}')
        with open(opt['meta_info_file'], 'r') as fin:
            subfolders = [line.split(' ')[0] for line in fin]
        for idx, subfolder in enumerate(subfolders):
            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
            self.data_info['gt_path'].append(gt_path)
            lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
            self.data_info['lq_path'].append(lq_paths)
            self.data_info['folder'].append('vimeo90k')
            self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
            self.data_info['border'].append(0)

    def __getitem__(self, index):
        lq_path = self.data_info['lq_path'][index]
        gt_path = self.data_info['gt_path'][index]
        imgs_lq = read_img_seq(lq_path)
        img_gt = read_img_seq([gt_path])
        img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': self.data_info['folder'][index],  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/843
            'border': self.data_info['border'][index],  # 0 for non-border
            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])
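# Editorial note: the neighbor_list expression centers the window on im4, the
# frame that serves as GT. Worked through for the supported window sizes:
#   num_frame=7 -> offset (9 - 7) // 2 = 1 -> [1, 2, 3, 4, 5, 6, 7]
#   num_frame=5 -> offset 2                -> [2, 3, 4, 5, 6]
#   num_frame=3 -> offset 3                -> [3, 4, 5]
#   num_frame=1 -> offset 4                -> [4]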
@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
    """Video test dataset for DUF dataset.

    Args:
        opt (dict): Config for the test dataset. Most of the keys are the same as VideoTestDataset.
            It has the following extra keys:
            use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames.
            scale (bool): Scale, which will be added automatically.
    """

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            if self.opt['use_duf_downsampling']:
                # read imgs_gt to generate low-resolution frames
                imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            if self.opt['use_duf_downsampling']:
                img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
                # read imgs_gt to generate low-resolution frames
                imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
                imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }


@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
    """Video test dataset for recurrent architectures, which takes LR video
    frames as input and outputs corresponding HR video frames.

    Args:
        opt (dict): Same as VideoTestDataset. Unused opt:
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoRecurrentTestDataset, self).__init__(opt)
        # Find unique folder strings
        self.folders = sorted(list(set(self.data_info['folder'])))

    def __getitem__(self, index):
        folder = self.folders[index]

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder]
            imgs_gt = self.imgs_gt[folder]
        else:
            raise NotImplementedError('Without cache_data is not implemented.')

        return {
            'lq': imgs_lq,
            'gt': imgs_gt,
            'folder': folder,
        }

    def __len__(self):
        return len(self.folders)
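# Editorial sketch: how border frames are handled by generate_frame_indices
# (imported above from basicsr.data.data_util; the exact output is assumed from
# the upstream BasicSR implementation). At the first frame of a 100-frame clip
# with a 5-frame window:
#   generate_frame_indices(0, 100, 5, padding='reflection')
#   # -> [2, 1, 0, 1, 2]: out-of-range neighbors are mirrored back into the clip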
199
basicsr/data/vimeo90k_dataset.py
Normal file
@@ -0,0 +1,199 @@
import random
import torch
from pathlib import Path
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file:
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains the following items, separated by a white space.

    1. clip name;
    2. frame number;
    3. image shape

    Examples:

    ::

        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    - Key examples: "00001/0001"
    - GT (gt): Ground-Truth;
    - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:

    ::

        num_frame | frame list
                1 | 4
                3 | 3,4,5
                5 | 2,3,4,5,6
                7 | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])

        with open(opt['meta_info_file'], 'r') as fin:
            self.keys = [line.split(' ')[0] for line in fin]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images
        self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        logger = get_root_logger()
        logger.info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png)
        if self.is_lmdb:
            img_gt_path = f'{key}/im4'
        else:
            img_gt_path = self.gt_root / clip / seq / 'im4.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w)
        # img_gt: (c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
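# Editorial sketch: a minimal training config for Vimeo90KDataset. The paths are
# placeholders; the keys follow the docstring above.
_opt = {
    'dataroot_gt': 'datasets/vimeo90k/GT',    # placeholder path
    'dataroot_lq': 'datasets/vimeo90k/BIx4',  # placeholder path
    'meta_info_file': 'basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt',
    'io_backend': {'type': 'disk'},
    'num_frame': 7, 'gt_size': 256, 'scale': 4,
    'random_reverse': False, 'use_hflip': True, 'use_rot': True,
}
# Vimeo90KDataset(_opt)[0] -> {'lq': (7, c, h, w), 'gt': (c, h, w), 'key': '00001/0001'}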
@DATASET_REGISTRY.register()
class Vimeo90KRecurrentDataset(Vimeo90KDataset):

    def __init__(self, opt):
        super(Vimeo90KRecurrentDataset, self).__init__(opt)

        self.flip_sequence = opt['flip_sequence']
        self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
                img_gt_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
                img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
            # LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            # GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)

            img_lqs.append(img_lq)
            img_gts.append(img_gt)

        # randomly crop
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[:7], dim=0)
        img_gts = torch.stack(img_results[7:], dim=0)

        if self.flip_sequence:  # flip the sequence: 7 frames to 14 frames
            img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
            img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
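The flip_sequence option above turns each 7-frame clip into a 14-frame palindrome, giving recurrent models longer sequences at no data cost. A tiny sketch of the tensor operation:

import torch

seq = torch.arange(7).float().view(7, 1, 1, 1)  # stand-in for a (t, c, h, w) clip
doubled = torch.cat([seq, seq.flip(0)], dim=0)  # frame order 0..6 then 6..0
assert doubled.shape[0] == 14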
47
basicsr/utils/__init__.py
Normal file
@@ -0,0 +1,47 @@
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
from .options import yaml_load

__all__ = [
    # color_util.py
    'bgr2ycbcr',
    'rgb2ycbcr',
    'rgb2ycbcr_pt',
    'ycbcr2bgr',
    'ycbcr2rgb',
    # file_client.py
    'FileClient',
    # img_util.py
    'img2tensor',
    'tensor2img',
    'imfrombytes',
    'imwrite',
    'crop_border',
    # logger.py
    'MessageLogger',
    'AvgTimer',
    'init_tb_logger',
    'init_wandb_logger',
    'get_root_logger',
    'get_env_info',
    # misc.py
    'set_random_seed',
    'get_time_str',
    'mkdir_and_rename',
    'make_exp_dirs',
    'scandir',
    'check_resume',
    'sizeof_fmt',
    # diffjpeg
    'DiffJPEG',
    # img_process_util
    'USMSharp',
    'usm_sharp',
    # options
    'yaml_load'
]
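Since every helper above is re-exported at package level, the datasets in this commit import straight from basicsr.utils, e.g.:

from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor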