first commit

This commit is contained in:
zsyOAOA
2024-12-11 18:46:36 +08:00
parent 9e65255d34
commit 27f2eb7dc3
847 changed files with 377076 additions and 2 deletions

BIN
utils/.DS_Store vendored Normal file

Binary file not shown.

5
utils/__init__.py Normal file

@@ -0,0 +1,5 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2022-01-18 11:40:23

428
utils/resize.py Normal file

@@ -0,0 +1,428 @@
"""
A standalone PyTorch implementation for fast and efficient bicubic resampling.
The resulting values are identical to those of the MATLAB function imresize('bicubic').
## Author: Sanghyun Son
## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary)
## Version: 1.2.0
## Last update: July 9th, 2020 (KST)
Dependency: torch
Example::
>>> import torch
>>> from utils import resize
>>> x = torch.arange(16).float().view(1, 1, 4, 4)
>>> y = resize.imresize(x, sizes=(3, 3))
>>> print(y)
tensor([[[[ 0.7506, 2.1004, 3.4503],
[ 6.1505, 7.5000, 8.8499],
[11.5497, 12.8996, 14.2494]]]])
"""
import math
import typing
import torch
from torch.nn import functional as F
__all__ = ['imresize']
_I = typing.Optional[int]
_D = typing.Optional[torch.dtype]
def nearest_contribution(x: torch.Tensor) -> torch.Tensor:
range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5))
cont = range_around_0.to(dtype=x.dtype)
return cont
def linear_contribution(x: torch.Tensor) -> torch.Tensor:
ax = x.abs()
range_01 = ax.le(1)
cont = (1 - ax) * range_01.to(dtype=x.dtype)
return cont
def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
ax = x.abs()
ax2 = ax * ax
ax3 = ax * ax2
range_01 = ax.le(1)
range_12 = torch.logical_and(ax.gt(1), ax.le(2))
cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
cont_01 = cont_01 * range_01.to(dtype=x.dtype)
cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
cont_12 = cont_12 * range_12.to(dtype=x.dtype)
cont = cont_01 + cont_12
return cont
def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
range_3sigma = (x.abs() <= 3 * sigma + 1)
# Normalization will be done after
cont = torch.exp(-x.pow(2) / (2 * sigma**2))
cont = cont * range_3sigma.to(dtype=x.dtype)
return cont
def discrete_kernel(kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor:
'''
For downsampling with integer scale only.
'''
downsampling_factor = int(1 / scale)
if kernel == 'cubic':
kernel_size_orig = 4
else:
raise ValueError('Only the cubic kernel is supported for discrete_kernel!')
if antialiasing:
kernel_size = kernel_size_orig * downsampling_factor
else:
kernel_size = kernel_size_orig
if downsampling_factor % 2 == 0:
a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))
else:
kernel_size -= 1
a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))
with torch.no_grad():
r = torch.linspace(-a, a, steps=kernel_size)
k = cubic_contribution(r).view(-1, 1)
k = torch.matmul(k, k.t())
k /= k.sum()
return k
def reflect_padding(x: torch.Tensor, dim: int, pad_pre: int, pad_post: int) -> torch.Tensor:
'''
Apply reflect padding to the given Tensor.
Note that it is slightly different from the PyTorch functional.pad,
where boundary elements are used only once.
Instead, we follow the MATLAB implementation
which uses boundary elements twice.
For example,
[a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
while our implementation yields [a, a, b, c, d, d].
'''
b, c, h, w = x.size()
if dim == 2 or dim == -2:
padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)
padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)
for p in range(pad_pre):
padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])
for p in range(pad_post):
padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])
else:
padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)
padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)
for p in range(pad_pre):
padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])
for p in range(pad_post):
padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])
return padding_buffer
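# Example of the MATLAB-style symmetric boundary implemented above, assuming
# x = torch.arange(1., 5.).view(1, 1, 1, 4), i.e., [1., 2., 3., 4.]:
# reflect_padding(x, dim=-1, pad_pre=1, pad_post=1) -> [1., 1., 2., 3., 4., 4.]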
def padding(x: torch.Tensor,
dim: int,
pad_pre: int,
pad_post: int,
padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
if padding_type is None:
return x
elif padding_type == 'reflect':
x_pad = reflect_padding(x, dim, pad_pre, pad_post)
else:
raise ValueError('{} padding is not supported!'.format(padding_type))
return x_pad
def get_padding(base: torch.Tensor, kernel_size: int, x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
base = base.long()
r_min = base.min()
r_max = base.max() + kernel_size - 1
if r_min <= 0:
pad_pre = -r_min
pad_pre = pad_pre.item()
base += pad_pre
else:
pad_pre = 0
if r_max >= x_size:
pad_post = r_max - x_size + 1
pad_post = pad_post.item()
else:
pad_post = 0
return pad_pre, pad_post, base
def get_weight(dist: torch.Tensor,
kernel_size: int,
kernel: str = 'cubic',
sigma: float = 2.0,
antialiasing_factor: float = 1) -> torch.Tensor:
buffer_pos = dist.new_zeros(kernel_size, len(dist))
for idx, buffer_sub in enumerate(buffer_pos):
buffer_sub.copy_(dist - idx)
# Expand (downsampling) / Shrink (upsampling) the receptive field.
buffer_pos *= antialiasing_factor
if kernel == 'cubic':
weight = cubic_contribution(buffer_pos)
elif kernel == 'gaussian':
weight = gaussian_contribution(buffer_pos, sigma=sigma)
else:
raise ValueError('{} kernel is not supported!'.format(kernel))
weight /= weight.sum(dim=0, keepdim=True)
return weight
def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
# Resize height
if dim == 2 or dim == -2:
k = (kernel_size, 1)
h_out = x.size(-2) - kernel_size + 1
w_out = x.size(-1)
# Resize width
else:
k = (1, kernel_size)
h_out = x.size(-2)
w_out = x.size(-1) - kernel_size + 1
unfold = F.unfold(x, k)
unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
return unfold
def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:
if x.dim() == 4:
b, c, h, w = x.size()
elif x.dim() == 3:
c, h, w = x.size()
b = None
elif x.dim() == 2:
h, w = x.size()
b = c = None
else:
raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))
x = x.view(-1, 1, h, w)
return x, b, c, h, w
def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
rh = x.size(-2)
rw = x.size(-1)
# Back to the original dimension
if b is not None:
x = x.view(b, c, rh, rw) # 4-dim
else:
if c is not None:
x = x.view(c, rh, rw) # 3-dim
else:
x = x.view(rh, rw) # 2-dim
return x
def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
if x.dtype != torch.float32 and x.dtype != torch.float64:
dtype = x.dtype
x = x.float()
else:
dtype = None
return x, dtype
def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
if dtype is not None:
if not dtype.is_floating_point:
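# Straight-through rounding: the forward pass uses x.round() while gradients flow through x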
x = x - x.detach() + x.round()
# To prevent over/underflow when converting types
if dtype is torch.uint8:
x = x.clamp(0, 255)
x = x.to(dtype=dtype)
return x
def resize_1d(x: torch.Tensor,
dim: int,
size: int,
scale: float,
kernel: str = 'cubic',
sigma: float = 2.0,
padding_type: str = 'reflect',
antialiasing: bool = True) -> torch.Tensor:
'''
Args:
x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
dim (int): The dimension to be resized (-2 for height, -1 for width).
scale (float): The resizing scale (output size / input size).
size (int): The output size along dim.
Return:
torch.Tensor: The tensor resized along dim.
'''
# Identity case
if scale == 1:
return x
# Default bicubic kernel with antialiasing (only when downsampling)
if kernel == 'cubic':
kernel_size = 4
else:
kernel_size = math.floor(6 * sigma)
if antialiasing and (scale < 1):
antialiasing_factor = scale
kernel_size = math.ceil(kernel_size / antialiasing_factor)
else:
antialiasing_factor = 1
# We allow margin to both sizes
kernel_size += 2
# Weights only depend on the shape of input and output,
# so we do not calculate gradients here.
with torch.no_grad():
pos = torch.linspace(
0,
size - 1,
steps=size,
dtype=x.dtype,
device=x.device,
)
pos = (pos + 0.5) / scale - 0.5
base = pos.floor() - (kernel_size // 2) + 1
dist = pos - base
weight = get_weight(
dist,
kernel_size,
kernel=kernel,
sigma=sigma,
antialiasing_factor=antialiasing_factor,
)
pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))
# To backpropagate through x
x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
unfold = reshape_tensor(x_pad, dim, kernel_size)
# Subsampling first
if dim == 2 or dim == -2:
sample = unfold[..., base, :]
weight = weight.view(1, kernel_size, sample.size(2), 1)
else:
sample = unfold[..., base]
weight = weight.view(1, kernel_size, 1, sample.size(3))
# Apply the kernel
x = sample * weight
x = x.sum(dim=1, keepdim=True)
return x
def downsampling_2d(x: torch.Tensor, k: torch.Tensor, scale: int, padding_type: str = 'reflect') -> torch.Tensor:
c = x.size(1)
k_h = k.size(-2)
k_w = k.size(-1)
k = k.to(dtype=x.dtype, device=x.device)
k = k.view(1, 1, k_h, k_w)
k = k.repeat(c, c, 1, 1)
e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
e = e.view(c, c, 1, 1)
k = k * e
pad_h = (k_h - scale) // 2
pad_w = (k_w - scale) // 2
x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
y = F.conv2d(x, k, padding=0, stride=scale)
return y
def imresize(x: torch.Tensor,
scale: typing.Optional[float] = None,
sizes: typing.Optional[typing.Tuple[int, int]] = None,
kernel: typing.Union[str, torch.Tensor] = 'cubic',
sigma: float = 2,
rotation_degree: float = 0,
padding_type: str = 'reflect',
antialiasing: bool = True) -> torch.Tensor:
"""
Args:
x (torch.Tensor):
scale (float):
sizes (tuple(int, int)):
kernel (str, default='cubic'):
sigma (float, default=2):
rotation_degree (float, default=0):
padding_type (str, default='reflect'):
antialiasing (bool, default=True):
Return:
torch.Tensor:
"""
if scale is None and sizes is None:
raise ValueError('One of scale or sizes must be specified!')
if scale is not None and sizes is not None:
raise ValueError('Please specify scale or sizes to avoid conflict!')
x, b, c, h, w = reshape_input(x)
if sizes is None and scale is not None:
'''
# Check if we can apply the convolution algorithm
scale_inv = 1 / scale
if isinstance(kernel, str) and scale_inv.is_integer():
kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
raise ValueError(
'An integer downsampling factor '
'should be used with a predefined kernel!'
)
'''
# Determine output size
sizes = (math.ceil(h * scale), math.ceil(w * scale))
scales = (scale, scale)
if scale is None and sizes is not None:
scales = (sizes[0] / h, sizes[1] / w)
x, dtype = cast_input(x)
if isinstance(kernel, str) and sizes is not None:
# Core resizing module
x = resize_1d(
x,
-2,
size=sizes[0],
scale=scales[0],
kernel=kernel,
sigma=sigma,
padding_type=padding_type,
antialiasing=antialiasing)
x = resize_1d(
x,
-1,
size=sizes[1],
scale=scales[1],
kernel=kernel,
sigma=sigma,
padding_type=padding_type,
antialiasing=antialiasing)
elif isinstance(kernel, torch.Tensor) and scale is not None:
x = downsampling_2d(x, kernel, scale=int(1 / scale))
x = reshape_output(x, b, c)
x = cast_output(x, dtype)
return x
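
For reference, a minimal usage sketch (assuming the file is importable as utils.resize; the input tensor is arbitrary):

import torch
from utils.resize import imresize

x = torch.rand(1, 3, 64, 48)             # (B, C, H, W), float32 in [0, 1]
y_half = imresize(x, scale=0.5)          # -> (1, 3, 32, 24), antialiased bicubic
y_sized = imresize(x, sizes=(128, 96))   # upscale to an explicit (H, W)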

136
utils/util_color_fix.py Normal file

@@ -0,0 +1,136 @@
'''
# --------------------------------------------------------------------------------
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
# --------------------------------------------------------------------------------
'''
import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import ToTensor, ToPILImage
from .util_image import rgb2ycbcrTorch, ycbcr2rgbTorch
def adain_color_fix(target: Image, source: Image):
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
source_tensor = to_tensor(source).unsqueeze(0)
# Apply adaptive instance normalization
result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
# Convert tensor back to image
to_image = ToPILImage()
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
return result_image
def wavelet_color_fix(target: Image, source: Image):
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
source_tensor = to_tensor(source).unsqueeze(0)
# Apply wavelet reconstruction
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
# Convert tensor back to image
to_image = ToPILImage()
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
return result_image
def calc_mean_std(feat: Tensor, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
"""Adaptive instance normalization.
Adjust the reference features to have the similar color and illuminations
as those in the degradate features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degradate features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
def wavelet_blur(image: Tensor, radius: int):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq
def ycbcr_color_replace(content_feat:Tensor, style_feat:Tensor):
"""
Apply YCbCr decomposition, so that the content keeps its luminance but takes the style's color.
"""
content_y = rgb2ycbcrTorch(content_feat, only_y=True)
style_ycbcr = rgb2ycbcrTorch(style_feat, only_y=False)
target_ycbcr = torch.cat([content_y, style_ycbcr[:, 1:,]], dim=1)
target_rgb = ycbcr2rgbTorch(target_ycbcr)
return target_rgb
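
For reference, a small usage sketch for the color-fix helpers above (the file paths are hypothetical; both images must share the same size):

from PIL import Image
from utils.util_color_fix import wavelet_color_fix

restored = Image.open('restored.png').convert('RGB')   # hypothetical paths
source = Image.open('input_lq.png').convert('RGB')
fixed = wavelet_color_fix(restored, source)            # returns a PIL Image
fixed.save('restored_colorfix.png')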

155
utils/util_common.py Normal file

@@ -0,0 +1,155 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2022-02-06 10:34:59
import os
import random
import shutil
import argparse
import requests
import importlib
from pathlib import Path
from PIL import Image
def mkdir(dir_path, delete=False, parents=True):
if not isinstance(dir_path, Path):
dir_path = Path(dir_path)
if delete:
if dir_path.exists():
shutil.rmtree(str(dir_path))
if not dir_path.exists():
dir_path.mkdir(parents=parents)
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def instantiate_from_config(config):
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
def get_filenames(dir_path, exts=['png', 'jpg'], recursive=True):
'''
Get the file paths in the given folder.
param exts: list, e.g., ['png',]
return: list
'''
if not isinstance(dir_path, Path):
dir_path = Path(dir_path)
file_paths = []
for current_ext in exts:
if recursive:
file_paths.extend([str(x) for x in dir_path.glob('**/*.'+current_ext)])
else:
file_paths.extend([str(x) for x in dir_path.glob('*.'+current_ext)])
return file_paths
def readline_txt(txt_file):
txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file
out = []
for txt_file_current in txt_file:
with open(txt_file_current, 'r') as ff:
out.extend([x.rstrip('\n') for x in ff.readlines()])
return out
def scan_files_from_folder(dir_paths, exts, recursive=True):
'''
Scanning images from the given folders.
Input:
dir_paths: str or list.
exts: list
'''
exts = [exts, ] if isinstance(exts, str) else exts
dir_paths = [dir_paths, ] if isinstance(dir_paths, str) else dir_paths
file_paths = []
for current_dir in dir_paths:
current_dir = Path(current_dir) if not isinstance(current_dir, Path) else current_dir
for current_ext in exts:
if recursive:
search_flag = f"**/*.{current_ext}"
else:
search_flag = f"*.{current_ext}"
file_paths.extend(sorted([str(x) for x in Path(current_dir).glob(search_flag)]))
return file_paths
def write_path_to_txt(
dir_folder,
txt_path,
search_key,
num_files=None,
write_only_name=False,
write_only_stem=False,
shuffle=False,
):
'''
Scan the files in the given folder and write their paths into a txt file.
Input:
dir_folder: path of the target folder
txt_path: path to save the txt file
search_key: e.g., '*.png'
num_files: int or None, keep only the first num_files entries
write_only_name: bool, only record the file names (including extension)
write_only_stem: bool, only record the file names (excluding extension)
shuffle: bool, shuffle the recorded paths
'''
txt_path = Path(txt_path) if not isinstance(txt_path, Path) else txt_path
dir_folder = Path(dir_folder) if not isinstance(dir_folder, Path) else dir_folder
if txt_path.exists():
txt_path.unlink()
if write_only_name:
path_list = sorted([str(x.name) for x in dir_folder.glob(search_key)])
elif write_only_stem:
path_list = sorted([str(x.stem) for x in dir_folder.glob(search_key)])
else:
path_list = sorted([str(x) for x in dir_folder.glob(search_key)])
if shuffle:
random.shuffle(path_list)
if num_files is not None:
path_list = path_list[:num_files]
with open(txt_path, mode='w') as ff:
for line in path_list:
ff.write(line+'\n')
def download_image_from_url(url, dir="./"):
# Download a file from a given URI, including minimal checks
# Download
f = str(Path(dir) / os.path.basename(url)) # filename
try:
with open(f, "wb") as file:
file.write(requests.get(url, timeout=10).content)
except Exception:
print(f'Failed to download {url}, skip!')
# Rename (remove wildcard characters)
src = f # original name
for c in ["%20", "%", "*", "~", "(", ")"]:
f = f.replace(c, "_")
f = f[: f.index("?")] if "?" in f else f # new name
if src != f:
os.rename(src, f) # rename
# Add suffix (if missing)
if Path(f).suffix == "":
src = f # original name
try:
f += f".{Image.open(f).format.lower()}"
os.rename(src, f) # rename
except Exception:
Path(f).unlink()
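
For reference, a quick sketch of the config-driven instantiation pattern behind instantiate_from_config (the target class is an arbitrary illustrative choice):

from utils.util_common import instantiate_from_config

config = {
    'target': 'torch.nn.Conv2d',   # any importable "module.Class" path
    'params': {'in_channels': 3, 'out_channels': 16, 'kernel_size': 3},
}
conv = instantiate_from_config(config)   # equivalent to torch.nn.Conv2d(3, 16, kernel_size=3)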

102
utils/util_ema.py Normal file

@@ -0,0 +1,102 @@
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_updates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_updates
else torch.tensor(-1, dtype=torch.int))
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
assert key not in self.m_name2s_name
def copy_to(self, model):
"""
Copy the EMA state (i.e., the shadow buffers) into the target model.
Input:
model: the target model
"""
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert key not in self.m_name2s_name
def store(self, parameters):
"""
Save the parameters of the targeted model into the temporary pool for restoring later.
Args:
parameters: parameters of the targeted model.
Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters from the temporary pool (stored with the `store` method).
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
def resume(self, ckpt, num_updates):
"""
Resume from the targeted checkpoint, i.e., copying the checkpoints to ema buffers
Input:
model: targerted model
"""
self.register_buffer('num_updates', torch.tensor(num_updates, dtype=torch.int))
shadow_params = dict(self.named_buffers())
for key, value in ckpt.items():
try:
shadow_params[self.m_name2s_name[key]].data.copy_(value.data)
except KeyError:
if key.startswith('module') and key not in shadow_params:
key = key[7:]
shadow_params[self.m_name2s_name[key]].data.copy_(value.data)
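
For reference, a typical EMA training-loop sketch (the model and loop body are placeholders):

import torch
from torch import nn
from utils.util_ema import LitEma

model = nn.Linear(8, 8)              # placeholder model
ema = LitEma(model, decay=0.9999)

for _ in range(100):                 # placeholder training loop
    # ... forward, backward, optimizer.step() ...
    ema(model)                       # update the shadow weights

ema.store(model.parameters())        # stash the raw weights
ema.copy_to(model)                   # evaluate with EMA weights
# ... run validation ...
ema.restore(model.parameters())      # switch back to the raw weights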

1149
utils/util_image.py Normal file

File diff suppressed because it is too large

98
utils/util_net.py Normal file

@@ -0,0 +1,98 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2021-11-24 20:29:36
import math
import torch
from pathlib import Path
from copy import deepcopy
from collections import OrderedDict
import torch.nn.functional as F
def calculate_parameters(net):
out = 0
for param in net.parameters():
out += param.numel()
return out
def pad_input(x, mod):
h, w = x.shape[-2:]
bottom = int(math.ceil(h / mod) * mod - h)
right = int(math.ceil(w / mod) * mod - w)
x_pad = F.pad(x, pad=(0, right, 0, bottom), mode='reflect')
return x_pad
def forward_chop(net, x, net_kwargs=None, scale=1, shave=10, min_size=160000):
n_GPUs = 1
b, c, h, w = x.size()
h_half, w_half = h // 2, w // 2
h_size, w_size = h_half + shave, w_half + shave
lr_list = [
x[:, :, 0:h_size, 0:w_size],
x[:, :, 0:h_size, (w - w_size):w],
x[:, :, (h - h_size):h, 0:w_size],
x[:, :, (h - h_size):h, (w - w_size):w]]
if w_size * h_size < min_size:
sr_list = []
for i in range(0, 4, n_GPUs):
lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
if net_kwargs is None:
sr_batch = net(lr_batch)
else:
sr_batch = net(lr_batch, **net_kwargs)
sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
else:
sr_list = [
forward_chop(net, patch, net_kwargs=net_kwargs, scale=scale, shave=shave, min_size=min_size) \
for patch in lr_list
]
h, w = scale * h, scale * w
h_half, w_half = scale * h_half, scale * w_half
h_size, w_size = scale * h_size, scale * w_size
shave *= scale
output = x.new(b, c, h, w)
output[:, :, 0:h_half, 0:w_half] \
= sr_list[0][:, :, 0:h_half, 0:w_half]
output[:, :, 0:h_half, w_half:w] \
= sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
output[:, :, h_half:h, 0:w_half] \
= sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
output[:, :, h_half:h, w_half:w] \
= sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
return output
def measure_time(net, inputs, num_forward=100):
'''
Measure the average running time (in seconds) of out = net(*inputs).
'''
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
with torch.set_grad_enabled(False):
for _ in range(num_forward):
out = net(*inputs)
end.record()
torch.cuda.synchronize()
return start.elapsed_time(end) / 1000 / num_forward
def reload_model(model, ckpt):
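# Handle checkpoints saved from nn.DataParallel ('module.' prefix) or torch.compile ('_orig_mod.' prefix)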
module_flag = list(ckpt.keys())[0].startswith('module.')
compile_flag = '_orig_mod' in list(ckpt.keys())[0]
for source_key, source_value in model.state_dict().items():
target_key = source_key
if compile_flag and (not '_orig_mod.' in source_key):
target_key = '_orig_mod.' + target_key
if module_flag and (not source_key.startswith('module')):
target_key = 'module.' + target_key
assert target_key in ckpt
source_value.copy_(ckpt[target_key])
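
For reference, a sketch of tiled inference with the helpers above (the conv layer is a stand-in for a real restoration network):

import torch
from torch import nn
from utils.util_net import pad_input, forward_chop

net = nn.Conv2d(3, 3, 3, padding=1)    # stand-in for a real network
x = torch.rand(1, 3, 500, 500)
x = pad_input(x, mod=16)               # reflect-pad H and W up to multiples of 16
with torch.no_grad():
    y = forward_chop(net, x, scale=1)  # process four overlapping tiles and stitch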

12
utils/util_ops.py Normal file

@@ -0,0 +1,12 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2024-08-15 16:25:07
def append_dims(x, target_dims:int):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
return x[(...,) + (None,) * dims_to_append]
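
For instance, broadcasting per-sample scalars against a batch of images:

import torch
from utils.util_ops import append_dims

t = torch.rand(8)                  # one scalar per sample
x = torch.rand(8, 3, 32, 32)
out = x * append_dims(t, x.ndim)   # t becomes (8, 1, 1, 1) and broadcasts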

22
utils/util_opts.py Normal file

@@ -0,0 +1,22 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2021-11-24 15:07:43
import argparse
def update_args(args_json, args_parser):
for arg in vars(args_parser):
args_json[arg] = getattr(args_parser, arg)
def str2bool(v):
"""
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
"""
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("boolean value expected")
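
For reference, a small sketch of why a converter is needed (argparse would otherwise hand over raw strings, and bool('false') is True):

import argparse
from utils.util_opts import str2bool

parser = argparse.ArgumentParser()
parser.add_argument('--antialiasing', type=str2bool, default=True)
args = parser.parse_args(['--antialiasing', 'no'])
print(args.antialiasing)   # False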

28
utils/util_sisr.py Normal file

@@ -0,0 +1,28 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2021-12-07 21:37:58
import cv2
import numpy as np
def modcrop(im, sf):
h, w = im.shape[:2]
h -= (h % sf)
w -= (w % sf)
return im[:h, :w]
#-----------------------------------------Transform--------------------------------------------
class Bicubic:
def __init__(self, scale=None, out_shape=None, matlab_mode=True):
self.scale = scale
self.out_shape = out_shape
# Note: matlab_mode is stored but currently unused; cv2.INTER_CUBIC does not
# exactly reproduce MATLAB's imresize (see utils/resize.py for a match).
self.matlab_mode = matlab_mode
def __call__(self, im):
out = cv2.resize(
im,
dsize=self.out_shape,
fx=self.scale,
fy=self.scale,
interpolation=cv2.INTER_CUBIC,
)
return out
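
For reference, a minimal usage sketch for the two helpers above (the input array is arbitrary):

import numpy as np
from utils.util_sisr import modcrop, Bicubic

im = np.random.rand(257, 511, 3).astype(np.float32)
im = modcrop(im, sf=4)          # crop to (256, 508, 3), divisible by 4
lr = Bicubic(scale=0.25)(im)    # cv2 bicubic downsampling to (64, 127, 3)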