mirror of
https://github.com/aljazceru/InvSR.git
synced 2026-02-20 06:04:19 +01:00
first commit
This commit is contained in:
BIN
utils/.DS_Store
vendored
Normal file
BIN
utils/.DS_Store
vendored
Normal file
Binary file not shown.
5
utils/__init__.py
Normal file
5
utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# Power by Zongsheng Yue 2022-01-18 11:40:23
|
||||
|
||||
|
||||
428
utils/resize.py
Normal file
428
utils/resize.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
A standalone PyTorch implementation for fast and efficient bicubic resampling.
|
||||
The resulting values are the same to MATLAB function imresize('bicubic').
|
||||
## Author: Sanghyun Son
|
||||
## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary)
|
||||
## Version: 1.2.0
|
||||
## Last update: July 9th, 2020 (KST)
|
||||
Dependency: torch
|
||||
Example::
|
||||
>>> import torch
|
||||
>>> import core
|
||||
>>> x = torch.arange(16).float().view(1, 1, 4, 4)
|
||||
>>> y = core.imresize(x, sizes=(3, 3))
|
||||
>>> print(y)
|
||||
tensor([[[[ 0.7506, 2.1004, 3.4503],
|
||||
[ 6.1505, 7.5000, 8.8499],
|
||||
[11.5497, 12.8996, 14.2494]]]])
|
||||
"""
|
||||
|
||||
import math
|
||||
import typing
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
__all__ = ['imresize']
|
||||
|
||||
_I = typing.Optional[int]
|
||||
_D = typing.Optional[torch.dtype]
|
||||
|
||||
|
||||
def nearest_contribution(x: torch.Tensor) -> torch.Tensor:
|
||||
range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5))
|
||||
cont = range_around_0.to(dtype=x.dtype)
|
||||
return cont
|
||||
|
||||
|
||||
def linear_contribution(x: torch.Tensor) -> torch.Tensor:
|
||||
ax = x.abs()
|
||||
range_01 = ax.le(1)
|
||||
cont = (1 - ax) * range_01.to(dtype=x.dtype)
|
||||
return cont
|
||||
|
||||
|
||||
def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
|
||||
ax = x.abs()
|
||||
ax2 = ax * ax
|
||||
ax3 = ax * ax2
|
||||
|
||||
range_01 = ax.le(1)
|
||||
range_12 = torch.logical_and(ax.gt(1), ax.le(2))
|
||||
|
||||
cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
|
||||
cont_01 = cont_01 * range_01.to(dtype=x.dtype)
|
||||
|
||||
cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
|
||||
cont_12 = cont_12 * range_12.to(dtype=x.dtype)
|
||||
|
||||
cont = cont_01 + cont_12
|
||||
return cont
|
||||
|
||||
|
||||
def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
|
||||
range_3sigma = (x.abs() <= 3 * sigma + 1)
|
||||
# Normalization will be done after
|
||||
cont = torch.exp(-x.pow(2) / (2 * sigma**2))
|
||||
cont = cont * range_3sigma.to(dtype=x.dtype)
|
||||
return cont
|
||||
|
||||
|
||||
def discrete_kernel(kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor:
|
||||
'''
|
||||
For downsampling with integer scale only.
|
||||
'''
|
||||
downsampling_factor = int(1 / scale)
|
||||
if kernel == 'cubic':
|
||||
kernel_size_orig = 4
|
||||
else:
|
||||
raise ValueError('Pass!')
|
||||
|
||||
if antialiasing:
|
||||
kernel_size = kernel_size_orig * downsampling_factor
|
||||
else:
|
||||
kernel_size = kernel_size_orig
|
||||
|
||||
if downsampling_factor % 2 == 0:
|
||||
a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))
|
||||
else:
|
||||
kernel_size -= 1
|
||||
a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))
|
||||
|
||||
with torch.no_grad():
|
||||
r = torch.linspace(-a, a, steps=kernel_size)
|
||||
k = cubic_contribution(r).view(-1, 1)
|
||||
k = torch.matmul(k, k.t())
|
||||
k /= k.sum()
|
||||
|
||||
return k
|
||||
|
||||
|
||||
def reflect_padding(x: torch.Tensor, dim: int, pad_pre: int, pad_post: int) -> torch.Tensor:
|
||||
'''
|
||||
Apply reflect padding to the given Tensor.
|
||||
Note that it is slightly different from the PyTorch functional.pad,
|
||||
where boundary elements are used only once.
|
||||
Instead, we follow the MATLAB implementation
|
||||
which uses boundary elements twice.
|
||||
For example,
|
||||
[a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
|
||||
while our implementation yields [a, a, b, c, d, d].
|
||||
'''
|
||||
b, c, h, w = x.size()
|
||||
if dim == 2 or dim == -2:
|
||||
padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)
|
||||
padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)
|
||||
for p in range(pad_pre):
|
||||
padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])
|
||||
for p in range(pad_post):
|
||||
padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])
|
||||
else:
|
||||
padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)
|
||||
padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)
|
||||
for p in range(pad_pre):
|
||||
padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])
|
||||
for p in range(pad_post):
|
||||
padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])
|
||||
|
||||
return padding_buffer
|
||||
|
||||
|
||||
def padding(x: torch.Tensor,
|
||||
dim: int,
|
||||
pad_pre: int,
|
||||
pad_post: int,
|
||||
padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
|
||||
if padding_type is None:
|
||||
return x
|
||||
elif padding_type == 'reflect':
|
||||
x_pad = reflect_padding(x, dim, pad_pre, pad_post)
|
||||
else:
|
||||
raise ValueError('{} padding is not supported!'.format(padding_type))
|
||||
|
||||
return x_pad
|
||||
|
||||
|
||||
def get_padding(base: torch.Tensor, kernel_size: int, x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
|
||||
base = base.long()
|
||||
r_min = base.min()
|
||||
r_max = base.max() + kernel_size - 1
|
||||
|
||||
if r_min <= 0:
|
||||
pad_pre = -r_min
|
||||
pad_pre = pad_pre.item()
|
||||
base += pad_pre
|
||||
else:
|
||||
pad_pre = 0
|
||||
|
||||
if r_max >= x_size:
|
||||
pad_post = r_max - x_size + 1
|
||||
pad_post = pad_post.item()
|
||||
else:
|
||||
pad_post = 0
|
||||
|
||||
return pad_pre, pad_post, base
|
||||
|
||||
|
||||
def get_weight(dist: torch.Tensor,
|
||||
kernel_size: int,
|
||||
kernel: str = 'cubic',
|
||||
sigma: float = 2.0,
|
||||
antialiasing_factor: float = 1) -> torch.Tensor:
|
||||
buffer_pos = dist.new_zeros(kernel_size, len(dist))
|
||||
for idx, buffer_sub in enumerate(buffer_pos):
|
||||
buffer_sub.copy_(dist - idx)
|
||||
|
||||
# Expand (downsampling) / Shrink (upsampling) the receptive field.
|
||||
buffer_pos *= antialiasing_factor
|
||||
if kernel == 'cubic':
|
||||
weight = cubic_contribution(buffer_pos)
|
||||
elif kernel == 'gaussian':
|
||||
weight = gaussian_contribution(buffer_pos, sigma=sigma)
|
||||
else:
|
||||
raise ValueError('{} kernel is not supported!'.format(kernel))
|
||||
|
||||
weight /= weight.sum(dim=0, keepdim=True)
|
||||
return weight
|
||||
|
||||
|
||||
def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
|
||||
# Resize height
|
||||
if dim == 2 or dim == -2:
|
||||
k = (kernel_size, 1)
|
||||
h_out = x.size(-2) - kernel_size + 1
|
||||
w_out = x.size(-1)
|
||||
# Resize width
|
||||
else:
|
||||
k = (1, kernel_size)
|
||||
h_out = x.size(-2)
|
||||
w_out = x.size(-1) - kernel_size + 1
|
||||
|
||||
unfold = F.unfold(x, k)
|
||||
unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
|
||||
return unfold
|
||||
|
||||
|
||||
def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:
|
||||
if x.dim() == 4:
|
||||
b, c, h, w = x.size()
|
||||
elif x.dim() == 3:
|
||||
c, h, w = x.size()
|
||||
b = None
|
||||
elif x.dim() == 2:
|
||||
h, w = x.size()
|
||||
b = c = None
|
||||
else:
|
||||
raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))
|
||||
|
||||
x = x.view(-1, 1, h, w)
|
||||
return x, b, c, h, w
|
||||
|
||||
|
||||
def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
|
||||
rh = x.size(-2)
|
||||
rw = x.size(-1)
|
||||
# Back to the original dimension
|
||||
if b is not None:
|
||||
x = x.view(b, c, rh, rw) # 4-dim
|
||||
else:
|
||||
if c is not None:
|
||||
x = x.view(c, rh, rw) # 3-dim
|
||||
else:
|
||||
x = x.view(rh, rw) # 2-dim
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
|
||||
if x.dtype != torch.float32 or x.dtype != torch.float64:
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
return x, dtype
|
||||
|
||||
|
||||
def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
|
||||
if dtype is not None:
|
||||
if not dtype.is_floating_point:
|
||||
x = x - x.detach() + x.round()
|
||||
# To prevent over/underflow when converting types
|
||||
if dtype is torch.uint8:
|
||||
x = x.clamp(0, 255)
|
||||
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def resize_1d(x: torch.Tensor,
|
||||
dim: int,
|
||||
size: int,
|
||||
scale: float,
|
||||
kernel: str = 'cubic',
|
||||
sigma: float = 2.0,
|
||||
padding_type: str = 'reflect',
|
||||
antialiasing: bool = True) -> torch.Tensor:
|
||||
'''
|
||||
Args:
|
||||
x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
|
||||
dim (int):
|
||||
scale (float):
|
||||
size (int):
|
||||
Return:
|
||||
'''
|
||||
# Identity case
|
||||
if scale == 1:
|
||||
return x
|
||||
|
||||
# Default bicubic kernel with antialiasing (only when downsampling)
|
||||
if kernel == 'cubic':
|
||||
kernel_size = 4
|
||||
else:
|
||||
kernel_size = math.floor(6 * sigma)
|
||||
|
||||
if antialiasing and (scale < 1):
|
||||
antialiasing_factor = scale
|
||||
kernel_size = math.ceil(kernel_size / antialiasing_factor)
|
||||
else:
|
||||
antialiasing_factor = 1
|
||||
|
||||
# We allow margin to both sizes
|
||||
kernel_size += 2
|
||||
|
||||
# Weights only depend on the shape of input and output,
|
||||
# so we do not calculate gradients here.
|
||||
with torch.no_grad():
|
||||
pos = torch.linspace(
|
||||
0,
|
||||
size - 1,
|
||||
steps=size,
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
pos = (pos + 0.5) / scale - 0.5
|
||||
base = pos.floor() - (kernel_size // 2) + 1
|
||||
dist = pos - base
|
||||
weight = get_weight(
|
||||
dist,
|
||||
kernel_size,
|
||||
kernel=kernel,
|
||||
sigma=sigma,
|
||||
antialiasing_factor=antialiasing_factor,
|
||||
)
|
||||
pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))
|
||||
|
||||
# To backpropagate through x
|
||||
x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
|
||||
unfold = reshape_tensor(x_pad, dim, kernel_size)
|
||||
# Subsampling first
|
||||
if dim == 2 or dim == -2:
|
||||
sample = unfold[..., base, :]
|
||||
weight = weight.view(1, kernel_size, sample.size(2), 1)
|
||||
else:
|
||||
sample = unfold[..., base]
|
||||
weight = weight.view(1, kernel_size, 1, sample.size(3))
|
||||
|
||||
# Apply the kernel
|
||||
x = sample * weight
|
||||
x = x.sum(dim=1, keepdim=True)
|
||||
return x
|
||||
|
||||
|
||||
def downsampling_2d(x: torch.Tensor, k: torch.Tensor, scale: int, padding_type: str = 'reflect') -> torch.Tensor:
|
||||
c = x.size(1)
|
||||
k_h = k.size(-2)
|
||||
k_w = k.size(-1)
|
||||
|
||||
k = k.to(dtype=x.dtype, device=x.device)
|
||||
k = k.view(1, 1, k_h, k_w)
|
||||
k = k.repeat(c, c, 1, 1)
|
||||
e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
|
||||
e = e.view(c, c, 1, 1)
|
||||
k = k * e
|
||||
|
||||
pad_h = (k_h - scale) // 2
|
||||
pad_w = (k_w - scale) // 2
|
||||
x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
|
||||
x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
|
||||
y = F.conv2d(x, k, padding=0, stride=scale)
|
||||
return y
|
||||
|
||||
|
||||
def imresize(x: torch.Tensor,
|
||||
scale: typing.Optional[float] = None,
|
||||
sizes: typing.Optional[typing.Tuple[int, int]] = None,
|
||||
kernel: typing.Union[str, torch.Tensor] = 'cubic',
|
||||
sigma: float = 2,
|
||||
rotation_degree: float = 0,
|
||||
padding_type: str = 'reflect',
|
||||
antialiasing: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor):
|
||||
scale (float):
|
||||
sizes (tuple(int, int)):
|
||||
kernel (str, default='cubic'):
|
||||
sigma (float, default=2):
|
||||
rotation_degree (float, default=0):
|
||||
padding_type (str, default='reflect'):
|
||||
antialiasing (bool, default=True):
|
||||
Return:
|
||||
torch.Tensor:
|
||||
"""
|
||||
if scale is None and sizes is None:
|
||||
raise ValueError('One of scale or sizes must be specified!')
|
||||
if scale is not None and sizes is not None:
|
||||
raise ValueError('Please specify scale or sizes to avoid conflict!')
|
||||
|
||||
x, b, c, h, w = reshape_input(x)
|
||||
|
||||
if sizes is None and scale is not None:
|
||||
'''
|
||||
# Check if we can apply the convolution algorithm
|
||||
scale_inv = 1 / scale
|
||||
if isinstance(kernel, str) and scale_inv.is_integer():
|
||||
kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
|
||||
elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
|
||||
raise ValueError(
|
||||
'An integer downsampling factor '
|
||||
'should be used with a predefined kernel!'
|
||||
)
|
||||
'''
|
||||
# Determine output size
|
||||
sizes = (math.ceil(h * scale), math.ceil(w * scale))
|
||||
scales = (scale, scale)
|
||||
|
||||
if scale is None and sizes is not None:
|
||||
scales = (sizes[0] / h, sizes[1] / w)
|
||||
|
||||
x, dtype = cast_input(x)
|
||||
|
||||
if isinstance(kernel, str) and sizes is not None:
|
||||
# Core resizing module
|
||||
x = resize_1d(
|
||||
x,
|
||||
-2,
|
||||
size=sizes[0],
|
||||
scale=scales[0],
|
||||
kernel=kernel,
|
||||
sigma=sigma,
|
||||
padding_type=padding_type,
|
||||
antialiasing=antialiasing)
|
||||
x = resize_1d(
|
||||
x,
|
||||
-1,
|
||||
size=sizes[1],
|
||||
scale=scales[1],
|
||||
kernel=kernel,
|
||||
sigma=sigma,
|
||||
padding_type=padding_type,
|
||||
antialiasing=antialiasing)
|
||||
elif isinstance(kernel, torch.Tensor) and scale is not None:
|
||||
x = downsampling_2d(x, kernel, scale=int(1 / scale))
|
||||
|
||||
x = reshape_output(x, b, c)
|
||||
x = cast_output(x, dtype)
|
||||
return x
|
||||
136
utils/util_color_fix.py
Normal file
136
utils/util_color_fix.py
Normal file
@@ -0,0 +1,136 @@
|
||||
'''
|
||||
# --------------------------------------------------------------------------------
|
||||
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
|
||||
# --------------------------------------------------------------------------------
|
||||
'''
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torchvision.transforms import ToTensor, ToPILImage
|
||||
|
||||
from .util_image import rgb2ycbcrTorch, ycbcr2rgbTorch
|
||||
|
||||
def adain_color_fix(target: Image, source: Image):
|
||||
# Convert images to tensors
|
||||
to_tensor = ToTensor()
|
||||
target_tensor = to_tensor(target).unsqueeze(0)
|
||||
source_tensor = to_tensor(source).unsqueeze(0)
|
||||
|
||||
# Apply adaptive instance normalization
|
||||
result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
|
||||
|
||||
# Convert tensor back to image
|
||||
to_image = ToPILImage()
|
||||
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
||||
|
||||
return result_image
|
||||
|
||||
def wavelet_color_fix(target: Image, source: Image):
|
||||
# Convert images to tensors
|
||||
to_tensor = ToTensor()
|
||||
target_tensor = to_tensor(target).unsqueeze(0)
|
||||
source_tensor = to_tensor(source).unsqueeze(0)
|
||||
|
||||
# Apply wavelet reconstruction
|
||||
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
|
||||
|
||||
# Convert tensor back to image
|
||||
to_image = ToPILImage()
|
||||
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
||||
|
||||
return result_image
|
||||
|
||||
def calc_mean_std(feat: Tensor, eps=1e-5):
|
||||
"""Calculate mean and std for adaptive_instance_normalization.
|
||||
Args:
|
||||
feat (Tensor): 4D tensor.
|
||||
eps (float): A small value added to the variance to avoid
|
||||
divide-by-zero. Default: 1e-5.
|
||||
"""
|
||||
size = feat.size()
|
||||
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
||||
b, c = size[:2]
|
||||
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
|
||||
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
|
||||
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
|
||||
return feat_mean, feat_std
|
||||
|
||||
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
|
||||
"""Adaptive instance normalization.
|
||||
Adjust the reference features to have the similar color and illuminations
|
||||
as those in the degradate features.
|
||||
Args:
|
||||
content_feat (Tensor): The reference feature.
|
||||
style_feat (Tensor): The degradate features.
|
||||
"""
|
||||
size = content_feat.size()
|
||||
style_mean, style_std = calc_mean_std(style_feat)
|
||||
content_mean, content_std = calc_mean_std(content_feat)
|
||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
||||
|
||||
def wavelet_blur(image: Tensor, radius: int):
|
||||
"""
|
||||
Apply wavelet blur to the input tensor.
|
||||
"""
|
||||
# input shape: (1, 3, H, W)
|
||||
# convolution kernel
|
||||
kernel_vals = [
|
||||
[0.0625, 0.125, 0.0625],
|
||||
[0.125, 0.25, 0.125],
|
||||
[0.0625, 0.125, 0.0625],
|
||||
]
|
||||
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
||||
# add channel dimensions to the kernel to make it a 4D tensor
|
||||
kernel = kernel[None, None]
|
||||
# repeat the kernel across all input channels
|
||||
kernel = kernel.repeat(3, 1, 1, 1)
|
||||
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
||||
# apply convolution
|
||||
output = F.conv2d(image, kernel, groups=3, dilation=radius)
|
||||
return output
|
||||
|
||||
def wavelet_decomposition(image: Tensor, levels=5):
|
||||
"""
|
||||
Apply wavelet decomposition to the input tensor.
|
||||
This function only returns the low frequency & the high frequency.
|
||||
"""
|
||||
high_freq = torch.zeros_like(image)
|
||||
for i in range(levels):
|
||||
radius = 2 ** i
|
||||
low_freq = wavelet_blur(image, radius)
|
||||
high_freq += (image - low_freq)
|
||||
image = low_freq
|
||||
|
||||
return high_freq, low_freq
|
||||
|
||||
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
|
||||
"""
|
||||
Apply wavelet decomposition, so that the content will have the same color as the style.
|
||||
"""
|
||||
# calculate the wavelet decomposition of the content feature
|
||||
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
||||
del content_low_freq
|
||||
# calculate the wavelet decomposition of the style feature
|
||||
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
||||
del style_high_freq
|
||||
# reconstruct the content feature with the style's high frequency
|
||||
return content_high_freq + style_low_freq
|
||||
|
||||
def ycbcr_color_replace(content_feat:Tensor, style_feat:Tensor):
|
||||
"""
|
||||
Apply ycbcr decomposition, so that the content will have the same color as the style.
|
||||
"""
|
||||
content_y = rgb2ycbcrTorch(content_feat, only_y=True)
|
||||
style_ycbcr = rgb2ycbcrTorch(style_feat, only_y=False)
|
||||
|
||||
target_ycbcr = torch.cat([content_y, style_ycbcr[:, 1:,]], dim=1)
|
||||
|
||||
target_rgb = ycbcr2rgbTorch(target_ycbcr)
|
||||
|
||||
return target_rgb
|
||||
|
||||
|
||||
155
utils/util_common.py
Normal file
155
utils/util_common.py
Normal file
@@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# Power by Zongsheng Yue 2022-02-06 10:34:59
|
||||
|
||||
import os
|
||||
import random
|
||||
import requests
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
def mkdir(dir_path, delete=False, parents=True):
|
||||
import shutil
|
||||
if not isinstance(dir_path, Path):
|
||||
dir_path = Path(dir_path)
|
||||
if delete:
|
||||
if dir_path.exists():
|
||||
shutil.rmtree(str(dir_path))
|
||||
if not dir_path.exists():
|
||||
dir_path.mkdir(parents=parents)
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||
|
||||
def str2bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
def get_filenames(dir_path, exts=['png', 'jpg'], recursive=True):
|
||||
'''
|
||||
Get the file paths in the given folder.
|
||||
param exts: list, e.g., ['png',]
|
||||
return: list
|
||||
'''
|
||||
if not isinstance(dir_path, Path):
|
||||
dir_path = Path(dir_path)
|
||||
|
||||
file_paths = []
|
||||
for current_ext in exts:
|
||||
if recursive:
|
||||
file_paths.extend([str(x) for x in dir_path.glob('**/*.'+current_ext)])
|
||||
else:
|
||||
file_paths.extend([str(x) for x in dir_path.glob('*.'+current_ext)])
|
||||
|
||||
return file_paths
|
||||
|
||||
def readline_txt(txt_file):
|
||||
txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file
|
||||
out = []
|
||||
for txt_file_current in txt_file:
|
||||
with open(txt_file_current, 'r') as ff:
|
||||
out.extend([x[:-1] for x in ff.readlines()])
|
||||
|
||||
return out
|
||||
|
||||
def scan_files_from_folder(dir_paths, exts, recursive=True):
|
||||
'''
|
||||
Scaning images from given folder.
|
||||
Input:
|
||||
dir_pathas: str or list.
|
||||
exts: list
|
||||
'''
|
||||
exts = [exts, ] if isinstance(exts, str) else exts
|
||||
dir_paths = [dir_paths, ] if isinstance(dir_paths, str) else dir_paths
|
||||
|
||||
file_paths = []
|
||||
for current_dir in dir_paths:
|
||||
current_dir = Path(current_dir) if not isinstance(current_dir, Path) else current_dir
|
||||
for current_ext in exts:
|
||||
if recursive:
|
||||
search_flag = f"**/*.{current_ext}"
|
||||
else:
|
||||
search_flag = f"*.{current_ext}"
|
||||
file_paths.extend(sorted([str(x) for x in Path(current_dir).glob(search_flag)]))
|
||||
|
||||
return file_paths
|
||||
|
||||
def write_path_to_txt(
|
||||
dir_folder,
|
||||
txt_path,
|
||||
search_key,
|
||||
num_files=None,
|
||||
write_only_name=False,
|
||||
write_only_stem=False,
|
||||
shuffle=False,
|
||||
):
|
||||
'''
|
||||
Scaning the files in the given folder and write them into a txt file
|
||||
Input:
|
||||
dir_folder: path of the target folder
|
||||
txt_path: path to save the txt file
|
||||
search_key: e.g., '*.png'
|
||||
write_only_name: bool, only record the file names (including extension),
|
||||
write_only_stem: bool, only record the file names (not including extension),
|
||||
'''
|
||||
txt_path = Path(txt_path) if not isinstance(txt_path, Path) else txt_path
|
||||
dir_folder = Path(dir_folder) if not isinstance(dir_folder, Path) else dir_folder
|
||||
if txt_path.exists():
|
||||
txt_path.unlink()
|
||||
if write_only_name:
|
||||
path_list = sorted([str(x.name) for x in dir_folder.glob(search_key)])
|
||||
elif write_only_stem:
|
||||
path_list = sorted([str(x.stem) for x in dir_folder.glob(search_key)])
|
||||
else:
|
||||
path_list = sorted([str(x) for x in dir_folder.glob(search_key)])
|
||||
if shuffle:
|
||||
random.shuffle(path_list)
|
||||
if num_files is not None:
|
||||
path_list = path_list[:num_files]
|
||||
with open(txt_path, mode='w') as ff:
|
||||
for line in path_list:
|
||||
ff.write(line+'\n')
|
||||
|
||||
def download_image_from_url(url, dir="./"):
|
||||
# Download a file from a given URI, including minimal checks
|
||||
|
||||
# Download
|
||||
f = str(Path(dir) / os.path.basename(url)) # filename
|
||||
try:
|
||||
with open(f, "wb") as file:
|
||||
file.write(requests.get(url, timeout=10).content)
|
||||
except:
|
||||
print(f'Skip the url: {f}!')
|
||||
|
||||
# Rename (remove wildcard characters)
|
||||
src = f # original name
|
||||
for c in ["%20", "%", "*", "~", "(", ")"]:
|
||||
f = f.replace(c, "_")
|
||||
f = f[: f.index("?")] if "?" in f else f # new name
|
||||
if src != f:
|
||||
os.rename(src, f) # rename
|
||||
|
||||
# Add suffix (if missing)
|
||||
if Path(f).suffix == "":
|
||||
src = f # original name
|
||||
try:
|
||||
f += f".{Image.open(f).format.lower()}"
|
||||
os.rename(src, f) # rename
|
||||
except:
|
||||
Path(f).unlink()
|
||||
102
utils/util_ema.py
Normal file
102
utils/util_ema.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LitEma(nn.Module):
|
||||
def __init__(self, model, decay=0.9999, use_num_upates=True):
|
||||
super().__init__()
|
||||
if decay < 0.0 or decay > 1.0:
|
||||
raise ValueError('Decay must be between 0 and 1')
|
||||
|
||||
self.m_name2s_name = {}
|
||||
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
|
||||
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
|
||||
else torch.tensor(-1, dtype=torch.int))
|
||||
|
||||
for name, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
# remove as '.'-character is not allowed in buffers
|
||||
s_name = name.replace('.', '')
|
||||
self.m_name2s_name.update({name: s_name})
|
||||
self.register_buffer(s_name, p.clone().detach().data)
|
||||
|
||||
self.collected_params = []
|
||||
|
||||
def reset_num_updates(self):
|
||||
del self.num_updates
|
||||
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
|
||||
|
||||
def forward(self, model):
|
||||
decay = self.decay
|
||||
|
||||
if self.num_updates >= 0:
|
||||
self.num_updates += 1
|
||||
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
|
||||
|
||||
one_minus_decay = 1.0 - decay
|
||||
|
||||
with torch.no_grad():
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
sname = self.m_name2s_name[key]
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
"""
|
||||
Copying the ema state (i.e., buffers) to the targeted model
|
||||
Input:
|
||||
model: targeted model
|
||||
"""
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the parameters of the targeted model into the temporary pool for restoring later.
|
||||
Args:
|
||||
parameters: parameters of the targeted model.
|
||||
Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters from the temporaty pool (stored with the `store` method).
|
||||
Useful to validate the model with EMA parameters without affecting the
|
||||
original optimization process. Store the parameters before the
|
||||
`copy_to` method. After validation (or model saving), use this to
|
||||
restore the former parameters.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
||||
|
||||
def resume(self, ckpt, num_updates):
|
||||
"""
|
||||
Resume from the targeted checkpoint, i.e., copying the checkpoints to ema buffers
|
||||
Input:
|
||||
model: targerted model
|
||||
"""
|
||||
self.register_buffer('num_updates', torch.tensor(num_updates, dtype=torch.int))
|
||||
|
||||
shadow_params = dict(self.named_buffers())
|
||||
for key, value in ckpt.items():
|
||||
try:
|
||||
shadow_params[self.m_name2s_name[key]].data.copy_(value.data)
|
||||
except:
|
||||
if key.startswith('module') and key not in shadow_params:
|
||||
key = key[7:]
|
||||
shadow_params[self.m_name2s_name[key]].data.copy_(value.data)
|
||||
1149
utils/util_image.py
Normal file
1149
utils/util_image.py
Normal file
File diff suppressed because it is too large
Load Diff
98
utils/util_net.py
Normal file
98
utils/util_net.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# Power by Zongsheng Yue 2021-11-24 20:29:36
|
||||
|
||||
import math
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
from collections import OrderedDict
|
||||
import torch.nn.functional as F
|
||||
|
||||
def calculate_parameters(net):
|
||||
out = 0
|
||||
for param in net.parameters():
|
||||
out += param.numel()
|
||||
return out
|
||||
|
||||
def pad_input(x, mod):
|
||||
h, w = x.shape[-2:]
|
||||
bottom = int(math.ceil(h/mod)*mod -h)
|
||||
right = int(math.ceil(w/mod)*mod - w)
|
||||
x_pad = F.pad(x, pad=(0, right, 0, bottom), mode='reflect')
|
||||
return x_pad
|
||||
|
||||
def forward_chop(net, x, net_kwargs=None, scale=1, shave=10, min_size=160000):
|
||||
n_GPUs = 1
|
||||
b, c, h, w = x.size()
|
||||
h_half, w_half = h // 2, w // 2
|
||||
h_size, w_size = h_half + shave, w_half + shave
|
||||
lr_list = [
|
||||
x[:, :, 0:h_size, 0:w_size],
|
||||
x[:, :, 0:h_size, (w - w_size):w],
|
||||
x[:, :, (h - h_size):h, 0:w_size],
|
||||
x[:, :, (h - h_size):h, (w - w_size):w]]
|
||||
|
||||
if w_size * h_size < min_size:
|
||||
sr_list = []
|
||||
for i in range(0, 4, n_GPUs):
|
||||
lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
|
||||
if net_kwargs is None:
|
||||
sr_batch = net(lr_batch)
|
||||
else:
|
||||
sr_batch = net(lr_batch, **net_kwargs)
|
||||
sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
|
||||
else:
|
||||
sr_list = [
|
||||
forward_chop(patch, shave=shave, min_size=min_size) \
|
||||
for patch in lr_list
|
||||
]
|
||||
|
||||
h, w = scale * h, scale * w
|
||||
h_half, w_half = scale * h_half, scale * w_half
|
||||
h_size, w_size = scale * h_size, scale * w_size
|
||||
shave *= scale
|
||||
|
||||
output = x.new(b, c, h, w)
|
||||
output[:, :, 0:h_half, 0:w_half] \
|
||||
= sr_list[0][:, :, 0:h_half, 0:w_half]
|
||||
output[:, :, 0:h_half, w_half:w] \
|
||||
= sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
|
||||
output[:, :, h_half:h, 0:w_half] \
|
||||
= sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
|
||||
output[:, :, h_half:h, w_half:w] \
|
||||
= sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
|
||||
|
||||
return output
|
||||
|
||||
def measure_time(net, inputs, num_forward=100):
|
||||
'''
|
||||
Measuring the average runing time (seconds) for pytorch.
|
||||
out = net(*inputs)
|
||||
'''
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start.record()
|
||||
with torch.set_grad_enabled(False):
|
||||
for _ in range(num_forward):
|
||||
out = net(*inputs)
|
||||
end.record()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return start.elapsed_time(end) / 1000
|
||||
|
||||
def reload_model(model, ckpt):
|
||||
module_flag = list(ckpt.keys())[0].startswith('module.')
|
||||
compile_flag = '_orig_mod' in list(ckpt.keys())[0]
|
||||
|
||||
for source_key, source_value in model.state_dict().items():
|
||||
target_key = source_key
|
||||
if compile_flag and (not '_orig_mod.' in source_key):
|
||||
target_key = '_orig_mod.' + target_key
|
||||
if module_flag and (not source_key.startswith('module')):
|
||||
target_key = 'module.' + target_key
|
||||
|
||||
assert target_key in ckpt
|
||||
source_value.copy_(ckpt[target_key])
|
||||
12
utils/util_ops.py
Normal file
12
utils/util_ops.py
Normal file
@@ -0,0 +1,12 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# Power by Zongsheng Yue 2024-08-15 16:25:07
|
||||
|
||||
def append_dims(x, target_dims:int):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
22
utils/util_opts.py
Normal file
22
utils/util_opts.py
Normal file
@@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# Power by Zongsheng Yue 2021-11-24 15:07:43
|
||||
|
||||
import argparse
|
||||
|
||||
def update_args(args_json, args_parser):
|
||||
for arg in vars(args_parser):
|
||||
args_json[arg] = getattr(args_parser, arg)
|
||||
|
||||
def str2bool(v):
|
||||
"""
|
||||
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("boolean value expected")
|
||||
28
utils/util_sisr.py
Normal file
28
utils/util_sisr.py
Normal file
@@ -0,0 +1,28 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# Power by Zongsheng Yue 2021-12-07 21:37:58
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
def modcrop(im, sf):
|
||||
h, w = im.shape[:2]
|
||||
h -= (h % sf)
|
||||
w -= (w % sf)
|
||||
return im[:h, :w,]
|
||||
|
||||
#-----------------------------------------Transform--------------------------------------------
|
||||
class Bicubic:
|
||||
def __init__(self, scale=None, out_shape=None, matlab_mode=True):
|
||||
self.scale = scale
|
||||
self.out_shape = out_shape
|
||||
|
||||
def __call__(self, im):
|
||||
out = cv2.resize(
|
||||
im,
|
||||
dsize=self.out_shape,
|
||||
fx=self.scale,
|
||||
fy=self.scale,
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
return out
|
||||
Reference in New Issue
Block a user