# Mirror of https://github.com/aljazceru/InvSR.git
# (synced 2025-12-17 06:14:22 +01:00; 216 lines, 7.8 KiB, Python)
import random
|
|
import json
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
from omegaconf import ListConfig
|
|
|
|
import cv2
|
|
import torch
|
|
from functools import partial
|
|
import torchvision as thv
|
|
from torch.utils.data import Dataset
|
|
|
|
from utils import util_sisr
|
|
from utils import util_image
|
|
from utils import util_common
|
|
|
|
from basicsr.data.transforms import augment
|
|
from basicsr.data.realesrgan_dataset import RealESRGANDataset
|
|
|
|
def get_transforms(transform_type, kwargs):
    '''
    Build a ``thv.transforms.Compose`` pipeline selected by ``transform_type``.

    Supported ``transform_type`` values and the ``kwargs`` keys each one reads:
        'default':           ToTensor -> Normalize(mean, std)
        'resize_ccrop_norm': util_image.SmallestMaxSize(size, interpolation)
                             -> ToTensor -> CenterCrop(size) -> Normalize(mean, std)
        'ccrop_norm':        ToTensor -> CenterCrop(size) -> Normalize(mean, std)
        'rcrop_aug_norm':    util_image.RandomCrop(pch_size)
                             -> util_image.SpatialAug(only_hflip, only_vflip, only_hvflip)
                             -> util_image.ToTensor(max_value) -> Normalize(mean, std)
        'aug_norm':          util_image.SpatialAug(only_hflip, only_vflip, only_hvflip)
                             -> util_image.ToTensor() -> Normalize(mean, std)

    Defaults for absent keys: mean=0.5, std=0.5, pch_size=256, the three flip
    flags False; size, interpolation and max_value default to None.

    Args:
        transform_type: name of the pipeline (one of the values above).
        kwargs: mapping of options; ``None`` is treated as an empty mapping.

    Returns:
        The composed transform.

    Raises:
        ValueError: if ``transform_type`` is not one of the supported names.
    '''
    if kwargs is None:
        # Allow callers to pass an unset kwargs (e.g. an ``extra_transform_kwargs=None``).
        kwargs = {}

    # Helpers are evaluated lazily so that each pipeline builds its own instances,
    # in the same order as they appear in the Compose list.
    def _normalize():
        return thv.transforms.Normalize(mean=kwargs.get('mean', 0.5), std=kwargs.get('std', 0.5))

    def _spatial_aug():
        return util_image.SpatialAug(
            only_hflip=kwargs.get('only_hflip', False),
            only_vflip=kwargs.get('only_vflip', False),
            only_hvflip=kwargs.get('only_hvflip', False),
        )

    if transform_type == 'default':
        steps = [
            thv.transforms.ToTensor(),
            _normalize(),
        ]
    elif transform_type == 'resize_ccrop_norm':
        steps = [
            util_image.SmallestMaxSize(
                max_size=kwargs.get('size'),
                interpolation=kwargs.get('interpolation'),
            ),
            thv.transforms.ToTensor(),
            thv.transforms.CenterCrop(size=kwargs.get('size', None)),
            _normalize(),
        ]
    elif transform_type == 'ccrop_norm':
        steps = [
            thv.transforms.ToTensor(),
            thv.transforms.CenterCrop(size=kwargs.get('size', None)),
            _normalize(),
        ]
    elif transform_type == 'rcrop_aug_norm':
        steps = [
            util_image.RandomCrop(pch_size=kwargs.get('pch_size', 256)),
            _spatial_aug(),
            util_image.ToTensor(max_value=kwargs.get('max_value')),  # (ndarray, hwc) --> (Tensor, chw)
            _normalize(),
        ]
    elif transform_type == 'aug_norm':
        steps = [
            _spatial_aug(),
            util_image.ToTensor(),  # hwc --> chw
            _normalize(),
        ]
    else:
        # The original referenced an undefined name here, raising NameError instead.
        raise ValueError(f'Unexpected transform_type {transform_type!r}')

    return thv.transforms.Compose(steps)
|
|
|
|
def create_dataset(dataset_config):
    '''
    Instantiate the dataset described by ``dataset_config``.

    ``dataset_config['type']`` selects the class, and ``dataset_config['params']``
    supplies its arguments:
        'base'       -> BaseData(**params)
        'base_meta'  -> BaseDataMetaCond(**params)
        'realesrgan' -> RealESRGANDataset(params)   (passed as a single mapping)

    Raises:
        NotImplementedError: for any other type; the message is the type itself.
    '''
    kind = dataset_config['type']

    if kind == 'base':
        return BaseData(**dataset_config['params'])
    if kind == 'base_meta':
        return BaseDataMetaCond(**dataset_config['params'])
    if kind == 'realesrgan':
        # RealESRGANDataset takes the whole options mapping positionally.
        return RealESRGANDataset(dataset_config['params'])

    raise NotImplementedError(f"{kind}")
|
|
|
|
class BaseData(Dataset):
    '''
    Image dataset built from a folder scan and/or a text file listing image paths.

    Each item is a dict with:
        'image', 'lq': the transformed image (both keys hold the same object);
        'gt':          only when ``extra_dir_path`` is set — the file with the same
                       basename under ``extra_dir_path``, passed through the extra
                       transform;
        'path':        only when ``need_path`` is True — the source image path.

    When ``length`` is given, a random subset of ``length`` paths is drawn without
    replacement; ``reset_dataset`` re-draws it.
    '''

    def __init__(
        self,
        dir_path,
        txt_path=None,
        transform_type='default',
        transform_kwargs=None,
        extra_dir_path=None,
        extra_transform_type=None,
        extra_transform_kwargs=None,
        length=None,
        need_path=False,
        im_exts=None,
        recursive=False,
    ):
        '''
        Args:
            dir_path: folder scanned with ``util_common.scan_files_from_folder`` (skipped if None).
            txt_path: text file read with ``util_common.readline_txt`` (skipped if None).
            transform_type / transform_kwargs: passed to ``get_transforms`` for the main
                image; kwargs default to ``{'mean': 0.0, 'std': 1.0}``.
            extra_dir_path: optional folder holding paired images, matched by basename.
            extra_transform_type / extra_transform_kwargs: transform for the paired image;
                the type is required when ``extra_dir_path`` is given.
            length: if not None, size of the random subset of paths to use.
            need_path: whether items include their source path.
            im_exts: extensions accepted by the folder scan
                (default ``['png', 'jpg', 'jpeg', 'JPEG', 'bmp']``).
            recursive: whether the folder scan recurses into subfolders.

        Raises:
            ValueError: if ``extra_dir_path`` is given without ``extra_transform_type``.
        '''
        super().__init__()
        # None sentinels replace the former mutable default arguments.
        if transform_kwargs is None:
            transform_kwargs = {'mean': 0.0, 'std': 1.0}
        if im_exts is None:
            im_exts = ['png', 'jpg', 'jpeg', 'JPEG', 'bmp']

        file_paths_all = []
        if dir_path is not None:
            file_paths_all.extend(util_common.scan_files_from_folder(dir_path, im_exts, recursive))
        if txt_path is not None:
            file_paths_all.extend(util_common.readline_txt(txt_path))

        self.file_paths_all = file_paths_all
        self.length = length
        self.file_paths = self._sample_paths()

        self.need_path = need_path
        self.transform = get_transforms(transform_type, transform_kwargs)

        self.extra_dir_path = extra_dir_path
        if extra_dir_path is not None:
            if extra_transform_type is None:
                raise ValueError('extra_transform_type is required when extra_dir_path is given')
            self.extra_transform = get_transforms(
                extra_transform_type,
                {} if extra_transform_kwargs is None else extra_transform_kwargs,
            )

    def _sample_paths(self):
        '''Return all paths if ``length`` is None, else a random subset of ``length`` paths.'''
        if self.length is None:
            return self.file_paths_all
        return random.sample(self.file_paths_all, self.length)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        im_path_base = self.file_paths[index]
        im_base = util_image.imread(im_path_base, chn='rgb', dtype='float32')

        im_target = self.transform(im_base)
        out = {'image': im_target, 'lq': im_target}

        if self.extra_dir_path is not None:
            # The paired image shares the basename of the main image.
            im_path_extra = Path(self.extra_dir_path) / Path(im_path_base).name
            im_extra = util_image.imread(im_path_extra, chn='rgb', dtype='float32')
            out['gt'] = self.extra_transform(im_extra)

        if self.need_path:
            out['path'] = im_path_base

        return out

    def reset_dataset(self):
        '''Re-draw the active subset of paths; keeps every path when ``length`` is None.'''
        # Previously this called random.sample(..., None) and raised TypeError.
        self.file_paths = self._sample_paths()
|
|
|
|
class BaseDataMetaCond(Dataset):
    '''
    Dataset driven by per-sample JSON metadata files plus a conditioning image.

    Every ``*.json`` file directly inside each directory of ``meta_dir`` (sorted
    within each directory) is one sample. Keys read from each JSON file:
        'source':   image path, read as uint8 RGB and transformed -> ``out['image']``
        'latent':   optional path loaded with ``np.load``          -> ``out['latent']``
        'prompt':   text                                           -> ``out['txt']``
        <cond_key>: conditioning image path, transformed           -> ``out['cond']``
    ``out['path']`` holds the 'source' path when ``need_path`` is True.
    '''

    def __init__(
        self,
        meta_dir,
        transform_type='default',
        transform_kwargs=None,
        length=None,
        need_path=False,
        cond_key='canny',
        cond_transform_type='default',
        cond_transform_kwargs=None,
    ):
        '''
        Args:
            meta_dir: one directory (str / Path) or a sequence of them
                (list, tuple or omegaconf ListConfig).
            transform_type / transform_kwargs: ``get_transforms`` arguments for the
                source image; kwargs default to ``{'mean': 0.5, 'std': 0.5}``.
            length: if not None, keep only the first ``length`` JSON files.
            need_path: whether items include the source image path.
            cond_key: metadata key of the conditioning image; 'canny' or 'seg'.
            cond_transform_type / cond_transform_kwargs: ``get_transforms`` arguments
                for the conditioning image; kwargs default to ``{'mean': 0.5, 'std': 0.5}``.
        '''
        super().__init__()
        # None sentinels replace the former mutable default dicts.
        if transform_kwargs is None:
            transform_kwargs = {'mean': 0.5, 'std': 0.5}
        if cond_transform_kwargs is None:
            cond_transform_kwargs = {'mean': 0.5, 'std': 0.5}

        meta_list = []
        for current_dir in self._normalize_meta_dirs(meta_dir):
            # Only paths are stored; each JSON is parsed lazily in __getitem__.
            meta_list.extend(sorted(str(x) for x in Path(current_dir).glob("*.json")))
        self.meta_list = meta_list if length is None else meta_list[:length]

        self.cond_key = cond_key
        self.length = length
        self.need_path = need_path
        self.transform = get_transforms(transform_type, transform_kwargs)
        self.cond_transform = get_transforms(cond_transform_type, cond_transform_kwargs)
        # Backward-compatible alias for the original misspelled attribute name.
        self.cond_trasform = self.cond_transform

    @staticmethod
    def _normalize_meta_dirs(meta_dir):
        '''Return ``meta_dir`` as a list of directories (a single path is wrapped).'''
        if isinstance(meta_dir, (str, Path)):
            return [meta_dir]
        # Plain lists/tuples used to be wrapped as one element and crash in Path().
        if isinstance(meta_dir, (list, tuple)):
            return list(meta_dir)
        if isinstance(meta_dir, ListConfig):
            return list(meta_dir)
        return [meta_dir]

    def __len__(self):
        return len(self.meta_list)

    def __getitem__(self, index):
        json_path = self.meta_list[index]
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(json_path, 'r', encoding='utf-8') as json_file:
            meta_info = json.load(json_file)

        # image
        im_path = meta_info['source']
        im_source = util_image.imread(im_path, chn='rgb', dtype='uint8')
        out = {'image': self.transform(im_source)}
        if self.need_path:
            out['path'] = im_path

        # latent (optional)
        if 'latent' in meta_info:
            out['latent'] = np.load(meta_info['latent'])

        # prompt
        out['txt'] = meta_info['prompt']

        # condition
        out['cond'] = self._load_cond(meta_info)

        return out

    def _load_cond(self, meta_info):
        '''
        Read the conditioning image named by ``meta_info[self.cond_key]`` and transform it.

        'canny' is read as uint8 grayscale and gets a trailing axis via ``[:, :, None]``;
        'seg' is read as uint8 RGB.

        Raises:
            ValueError: if ``self.cond_key`` is neither 'canny' nor 'seg'.
        '''
        cond_key = self.cond_key
        # Validate the key before indexing so a bad key gives a clear ValueError.
        if cond_key == 'canny':
            cond = util_image.imread(meta_info[cond_key], chn='gray', dtype='uint8')[:, :, None]
        elif cond_key == 'seg':
            cond = util_image.imread(meta_info[cond_key], chn='rgb', dtype='uint8')
        else:
            raise ValueError(f"Unexpected cond key: {cond_key}")
        return self.cond_transform(cond)
|