mirror of https://github.com/aljazceru/InvSR.git (synced 2025-12-17 06:14:22 +01:00)
first commit
215  datapipe/datasets.py  Normal file
@@ -0,0 +1,215 @@
import random
import json
import numpy as np
from pathlib import Path
from typing import Iterable
from omegaconf import ListConfig

import cv2
import torch
from functools import partial
import torchvision as thv
from torch.utils.data import Dataset

from utils import util_sisr
from utils import util_image
from utils import util_common

from basicsr.data.transforms import augment
from basicsr.data.realesrgan_dataset import RealESRGANDataset

def get_transforms(transform_type, kwargs):
    '''
    Accepted options in kwargs:
        mean: scalar or sequence, for normalization
        std: scalar or sequence, for normalization
        crop_size: int or sequence, random or center cropping
        scale, out_shape: for Bicubic
        min_max: tuple or list with length 2, for clipping
    '''
    if transform_type == 'default':
        transform = thv.transforms.Compose([
            thv.transforms.ToTensor(),
            thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)),
        ])
    elif transform_type == 'resize_ccrop_norm':
        transform = thv.transforms.Compose([
            util_image.SmallestMaxSize(
                max_size=kwargs.get('size'),
                interpolation=kwargs.get('interpolation'),
            ),
            thv.transforms.ToTensor(),
            thv.transforms.CenterCrop(size=kwargs.get('size', None)),
            thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)),
        ])
    elif transform_type == 'ccrop_norm':
        transform = thv.transforms.Compose([
            thv.transforms.ToTensor(),
            thv.transforms.CenterCrop(size=kwargs.get('size', None)),
            thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)),
        ])
    elif transform_type == 'rcrop_aug_norm':
        transform = thv.transforms.Compose([
            util_image.RandomCrop(pch_size=kwargs.get('pch_size', 256)),
            util_image.SpatialAug(
                only_hflip=kwargs.get('only_hflip', False),
                only_vflip=kwargs.get('only_vflip', False),
                only_hvflip=kwargs.get('only_hvflip', False),
            ),
            util_image.ToTensor(max_value=kwargs.get('max_value')),  # (ndarray, hwc) --> (Tensor, chw)
            thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)),
        ])
    elif transform_type == 'aug_norm':
        transform = thv.transforms.Compose([
            util_image.SpatialAug(
                only_hflip=kwargs.get('only_hflip', False),
                only_vflip=kwargs.get('only_vflip', False),
                only_hvflip=kwargs.get('only_hvflip', False),
            ),
            util_image.ToTensor(),  # hwc --> chw
            thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5)),
        ])
    else:
        raise ValueError(f'Unexpected transform_type {transform_type}')
    return transform

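# --- Illustrative usage (not part of the original file) ---
# A minimal sketch of building a training-time transform. The kwargs shown are
# assumptions based on the options handled above, and `max_value=255.0` assumes
# util_image.ToTensor rescales uint8-range inputs to [0, 1] before Normalize.
#
#   train_tf = get_transforms('rcrop_aug_norm', {
#       'pch_size': 256,        # random crop size
#       'only_hflip': True,     # restrict SpatialAug to horizontal flips
#       'max_value': 255.0,
#       'mean': 0.5, 'std': 0.5,
#   })
#   patch = train_tf(im_hwc)    # (h, w, c) ndarray -> (c, 256, 256) tensor in ~[-1, 1]
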
def create_dataset(dataset_config):
    if dataset_config['type'] == 'base':
        dataset = BaseData(**dataset_config['params'])
    elif dataset_config['type'] == 'base_meta':
        dataset = BaseDataMetaCond(**dataset_config['params'])
    elif dataset_config['type'] == 'realesrgan':
        dataset = RealESRGANDataset(dataset_config['params'])
    else:
        raise NotImplementedError(f"{dataset_config['type']}")

    return dataset

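# --- Illustrative usage (not part of the original file) ---
# A hypothetical config for `create_dataset`; the paths are placeholders and the
# params mirror BaseData's signature below.
#
#   config = {
#       'type': 'base',
#       'params': {
#           'dir_path': '/path/to/images',          # placeholder path
#           'transform_type': 'rcrop_aug_norm',
#           'transform_kwargs': {'pch_size': 256, 'max_value': 255.0, 'mean': 0.5, 'std': 0.5},
#           'need_path': True,
#       },
#   }
#   dataset = create_dataset(config)
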
class BaseData(Dataset):
    def __init__(
        self,
        dir_path,
        txt_path=None,
        transform_type='default',
        transform_kwargs={'mean': 0.0, 'std': 1.0},
        extra_dir_path=None,
        extra_transform_type=None,
        extra_transform_kwargs=None,
        length=None,
        need_path=False,
        im_exts=['png', 'jpg', 'jpeg', 'JPEG', 'bmp'],
        recursive=False,
    ):
        super().__init__()

        # Collect image paths from a folder scan, a text-file listing, or both
        file_paths_all = []
        if dir_path is not None:
            file_paths_all.extend(util_common.scan_files_from_folder(dir_path, im_exts, recursive))
        if txt_path is not None:
            file_paths_all.extend(util_common.readline_txt(txt_path))

        # If `length` is given, work on a random subset; keep the full list for resampling
        self.file_paths = file_paths_all if length is None else random.sample(file_paths_all, length)
        self.file_paths_all = file_paths_all

        self.length = length
        self.need_path = need_path
        self.transform = get_transforms(transform_type, transform_kwargs)

        self.extra_dir_path = extra_dir_path
        if extra_dir_path is not None:
            assert extra_transform_type is not None
            self.extra_transform = get_transforms(extra_transform_type, extra_transform_kwargs)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        im_path_base = self.file_paths[index]
        im_base = util_image.imread(im_path_base, chn='rgb', dtype='float32')

        im_target = self.transform(im_base)
        out = {'image': im_target, 'lq': im_target}

        # Optionally load a paired image (same file name) from `extra_dir_path` as ground truth
        if self.extra_dir_path is not None:
            im_path_extra = Path(self.extra_dir_path) / Path(im_path_base).name
            im_extra = util_image.imread(im_path_extra, chn='rgb', dtype='float32')
            im_extra = self.extra_transform(im_extra)
            out['gt'] = im_extra

        if self.need_path:
            out['path'] = im_path_base

        return out

    def reset_dataset(self):
        # Draw a fresh random subset of size `length` (only meaningful when `length` is set)
        self.file_paths = random.sample(self.file_paths_all, self.length)

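# --- Illustrative usage (not part of the original file) ---
# A sketch of wrapping BaseData in a DataLoader; directory names are placeholders.
# With `extra_dir_path` set, each image in `dir_path` is paired with a ground-truth
# image of the same file name under `extra_dir_path`.
#
#   from torch.utils.data import DataLoader
#   ds = BaseData(
#       dir_path='/path/to/lq',                 # placeholder path
#       extra_dir_path='/path/to/gt',           # placeholder path
#       extra_transform_type='default',
#       extra_transform_kwargs={'mean': 0.5, 'std': 0.5},
#       need_path=True,
#   )
#   loader = DataLoader(ds, batch_size=16, shuffle=True, num_workers=4)
#   batch = next(iter(loader))                  # keys: 'image', 'lq', 'gt', 'path'
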
class BaseDataMetaCond(Dataset):
    def __init__(
        self,
        meta_dir,
        transform_type='default',
        transform_kwargs={'mean': 0.5, 'std': 0.5},
        length=None,
        need_path=False,
        cond_key='canny',
        cond_transform_type='default',
        cond_transform_kwargs={'mean': 0.5, 'std': 0.5},
    ):
        super().__init__()
        if not isinstance(meta_dir, ListConfig):
            meta_dir = [meta_dir,]
        # Collect the meta json paths; the files themselves are loaded lazily in __getitem__
        meta_list = []
        for current_dir in meta_dir:
            meta_list.extend(sorted([str(x) for x in Path(current_dir).glob("*.json")]))
        self.meta_list = meta_list if length is None else meta_list[:length]

        self.cond_key = cond_key
        self.length = length
        self.need_path = need_path
        self.transform = get_transforms(transform_type, transform_kwargs)
        self.cond_transform = get_transforms(cond_transform_type, cond_transform_kwargs)

    def __len__(self):
        return len(self.meta_list)

    def __getitem__(self, index):
        json_path = self.meta_list[index]
        with open(json_path, 'r') as json_file:
            meta_info = json.load(json_file)

        # images
        im_path = meta_info['source']
        im_source = util_image.imread(im_path, chn='rgb', dtype='uint8')
        im_source = self.transform(im_source)
        out = {'image': im_source,}
        if self.need_path:
            out['path'] = im_path

        # latent
        if 'latent' in meta_info:
            latent_path = meta_info['latent']
            out['latent'] = np.load(latent_path)

        # prompt
        out['txt'] = meta_info['prompt']

        # condition
        cond_key = self.cond_key
        cond_path = meta_info[cond_key]
        if cond_key == 'canny':
            cond = util_image.imread(cond_path, chn='gray', dtype='uint8')[:, :, None]
        elif cond_key == 'seg':
            cond = util_image.imread(cond_path, chn='rgb', dtype='uint8')
        else:
            raise ValueError(f"Unexpected cond key: {cond_key}")
        cond = self.cond_transform(cond)
        out['cond'] = cond

        return out
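
# --- Illustrative meta file (not part of the original file) ---
# Based on the keys read in __getitem__ above, each *.json under `meta_dir` is
# expected to look roughly like the following (all paths are placeholders;
# 'latent' is optional, and the condition key must match `cond_key`):
#
#   {
#       "source": "/data/images/0001.png",
#       "prompt": "a photo of ...",
#       "latent": "/data/latents/0001.npy",
#       "canny":  "/data/canny/0001.png"
#   }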