mirror of
https://github.com/aljazceru/InvSR.git
synced 2026-02-23 15:44:30 +01:00
first commit
This commit is contained in:
47
basicsr/utils/__init__.py
Normal file
47
basicsr/utils/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
|
||||
from .diffjpeg import DiffJPEG
|
||||
from .file_client import FileClient
|
||||
from .img_process_util import USMSharp, usm_sharp
|
||||
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
|
||||
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
|
||||
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
|
||||
from .options import yaml_load
|
||||
|
||||
# Public names re-exported by this package, grouped by the submodule they are imported from.
__all__ = [
    # color_util.py
    'bgr2ycbcr', 'rgb2ycbcr', 'rgb2ycbcr_pt', 'ycbcr2bgr', 'ycbcr2rgb',
    # file_client.py
    'FileClient',
    # img_util.py
    'img2tensor', 'tensor2img', 'imfrombytes', 'imwrite', 'crop_border',
    # logger.py
    'MessageLogger', 'AvgTimer', 'init_tb_logger', 'init_wandb_logger', 'get_root_logger', 'get_env_info',
    # misc.py
    'set_random_seed', 'get_time_str', 'mkdir_and_rename', 'make_exp_dirs', 'scandir', 'check_resume',
    'sizeof_fmt',
    # diffjpeg.py
    'DiffJPEG',
    # img_process_util.py
    'USMSharp', 'usm_sharp',
    # options.py
    'yaml_load',
]
|
||||
BIN
basicsr/utils/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/color_util.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/color_util.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/color_util.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/color_util.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/diffjpeg.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/diffjpeg.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/diffjpeg.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/diffjpeg.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/dist_util.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/dist_util.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/dist_util.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/dist_util.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/file_client.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/file_client.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/file_client.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/file_client.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/flow_util.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/flow_util.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/flow_util.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/flow_util.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/img_process_util.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/img_process_util.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/img_process_util.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/img_process_util.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/img_util.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/img_util.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/img_util.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/img_util.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/logger.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/logger.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/logger.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/logger.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/misc.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/misc.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/misc.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/misc.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/options.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/options.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/options.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/options.cpython-38.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/registry.cpython-310.pyc
Normal file
BIN
basicsr/utils/__pycache__/registry.cpython-310.pyc
Normal file
Binary file not shown.
BIN
basicsr/utils/__pycache__/registry.cpython-38.pyc
Normal file
BIN
basicsr/utils/__pycache__/registry.cpython-38.pyc
Normal file
Binary file not shown.
208
basicsr/utils/color_util.py
Normal file
208
basicsr/utils/color_util.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def rgb2ycbcr(img, y_only=False):
    """Convert an RGB image to YCbCr with the ITU-R BT.601 coefficients.

    The result matches Matlab's `rgb2ycbcr`. It is not the same as
    `cv2.cvtColor` (`RGB <-> YCrCb`), which implements the JPEG conversion.
    References:
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion

    Args:
        img (ndarray): Input image, channels last, in RGB order. Either
            np.uint8 with range [0, 255] or np.float32 with range [0, 1].
        y_only (bool): If True, return only the Y channel. Default: False.

    Returns:
        ndarray: The YCbCr image (or Y channel) with the same dtype and
            value range as the input.
    """
    src_dtype = img.dtype
    scaled = _convert_input_type_range(img)
    if y_only:
        converted = np.dot(scaled, [65.481, 128.553, 24.966]) + 16.0
    else:
        weights = [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]
        converted = np.matmul(scaled, weights) + [16, 128, 128]
    return _convert_output_type_range(converted, src_dtype)
|
||||
|
||||
|
||||
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr with the ITU-R BT.601 coefficients.

    BGR counterpart of `rgb2ycbcr`: the same coefficients are applied with
    the channel rows reversed. It is not the same as `cv2.cvtColor`
    (`BGR <-> YCrCb`), which implements the JPEG conversion.
    References:
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion

    Args:
        img (ndarray): Input image, channels last, in BGR order. Either
            np.uint8 with range [0, 255] or np.float32 with range [0, 1].
        y_only (bool): If True, return only the Y channel. Default: False.

    Returns:
        ndarray: The YCbCr image (or Y channel) with the same dtype and
            value range as the input.
    """
    src_dtype = img.dtype
    scaled = _convert_input_type_range(img)
    if y_only:
        converted = np.dot(scaled, [24.966, 128.553, 65.481]) + 16.0
    else:
        weights = [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]
        converted = np.matmul(scaled, weights) + [16, 128, 128]
    return _convert_output_type_range(converted, src_dtype)
|
||||
|
||||
|
||||
def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB with the ITU-R BT.601 coefficients.

    The result matches Matlab's `ycbcr2rgb`. It is not the same as
    `cv2.cvtColor` (`YCrCb <-> RGB`), which implements the JPEG conversion.
    References:
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion

    Args:
        img (ndarray): Input YCbCr image, channels last. Either np.uint8
            with range [0, 255] or np.float32 with range [0, 1].

    Returns:
        ndarray: The RGB image with the same dtype and value range as the
            input.
    """
    src_dtype = img.dtype
    # The inverse matrix below is expressed for inputs in the [0, 255] range.
    scaled = _convert_input_type_range(img) * 255
    weights = [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
               [0.00625893, -0.00318811, 0]]
    offsets = [-222.921, 135.576, -276.836]
    converted = np.matmul(scaled, weights) * 255.0 + offsets
    return _convert_output_type_range(converted, src_dtype)
|
||||
|
||||
|
||||
def ycbcr2bgr(img):
    """Convert a YCbCr image to BGR with the ITU-R BT.601 coefficients.

    BGR counterpart of `ycbcr2rgb`: the same inverse coefficients are used
    with the output columns reversed. It is not the same as `cv2.cvtColor`
    (`YCrCb <-> BGR`), which implements the JPEG conversion.
    References:
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion

    Args:
        img (ndarray): Input YCbCr image, channels last. Either np.uint8
            with range [0, 255] or np.float32 with range [0, 1].

    Returns:
        ndarray: The BGR image with the same dtype and value range as the
            input.
    """
    src_dtype = img.dtype
    # The inverse matrix below is expressed for inputs in the [0, 255] range.
    scaled = _convert_input_type_range(img) * 255
    weights = [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
               [0, -0.00318811, 0.00625893]]
    offsets = [-276.836, 135.576, -222.921]
    converted = np.matmul(scaled, weights) * 255.0 + offsets
    return _convert_output_type_range(converted, src_dtype)
|
||||
|
||||
|
||||
def _convert_input_type_range(img):
|
||||
"""Convert the type and range of the input image.
|
||||
|
||||
It converts the input image to np.float32 type and range of [0, 1].
|
||||
It is mainly used for pre-processing the input image in colorspace
|
||||
conversion functions such as rgb2ycbcr and ycbcr2rgb.
|
||||
|
||||
Args:
|
||||
img (ndarray): The input image. It accepts:
|
||||
1. np.uint8 type with range [0, 255];
|
||||
2. np.float32 type with range [0, 1].
|
||||
|
||||
Returns:
|
||||
(ndarray): The converted image with type of np.float32 and range of
|
||||
[0, 1].
|
||||
"""
|
||||
img_type = img.dtype
|
||||
img = img.astype(np.float32)
|
||||
if img_type == np.float32:
|
||||
pass
|
||||
elif img_type == np.uint8:
|
||||
img /= 255.
|
||||
else:
|
||||
raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
|
||||
return img
|
||||
|
||||
|
||||
def _convert_output_type_range(img, dst_type):
|
||||
"""Convert the type and range of the image according to dst_type.
|
||||
|
||||
It converts the image to desired type and range. If `dst_type` is np.uint8,
|
||||
images will be converted to np.uint8 type with range [0, 255]. If
|
||||
`dst_type` is np.float32, it converts the image to np.float32 type with
|
||||
range [0, 1].
|
||||
It is mainly used for post-processing images in colorspace conversion
|
||||
functions such as rgb2ycbcr and ycbcr2rgb.
|
||||
|
||||
Args:
|
||||
img (ndarray): The image to be converted with np.float32 type and
|
||||
range [0, 255].
|
||||
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
|
||||
converts the image to np.uint8 type with range [0, 255]. If
|
||||
dst_type is np.float32, it converts the image to np.float32 type
|
||||
with range [0, 1].
|
||||
|
||||
Returns:
|
||||
(ndarray): The converted image with desired type and range.
|
||||
"""
|
||||
if dst_type not in (np.uint8, np.float32):
|
||||
raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
|
||||
if dst_type == np.uint8:
|
||||
img = img.round()
|
||||
else:
|
||||
img /= 255.
|
||||
return img.astype(dst_type)
|
||||
|
||||
|
||||
def rgb2ycbcr_pt(img, y_only=False):
    """PyTorch version of the RGB -> YCbCr conversion (ITU-R BT.601).

    See https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): Images with shape (n, 3, h, w), float, range [0, 1],
            RGB channel order.
        y_only (bool): If True, return only the Y channel. Default: False.

    Returns:
        (Tensor): Converted images with shape (n, 3, h, w), or (n, 1, h, w)
            when ``y_only`` is set; float, range [0, 1].
    """
    # matmul acts on the last dimension, so move channels there first.
    channels_last = img.permute(0, 2, 3, 1)
    if y_only:
        luma_weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        converted = torch.matmul(channels_last, luma_weight).permute(0, 3, 1, 2) + 16.0
        return converted / 255.

    full_weight = torch.tensor(
        [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
    offset = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
    converted = torch.matmul(channels_last, full_weight).permute(0, 3, 1, 2) + offset
    return converted / 255.
|
||||
515
basicsr/utils/diffjpeg.py
Normal file
515
basicsr/utils/diffjpeg.py
Normal file
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
Modified from https://github.com/mlomnitz/DiffJPEG
|
||||
|
||||
For images not divisible by 8
|
||||
https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
|
||||
"""
|
||||
import itertools
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
# ------------------------ utils ------------------------#
|
||||
# Quantization table for the Y (luma) channel, stored transposed and shared by
# YQuantize / YDequantize.
y_table = nn.Parameter(
    torch.from_numpy(
        np.array([
            [16, 11, 10, 16, 24, 40, 51, 61],
            [12, 12, 14, 19, 26, 58, 60, 55],
            [14, 13, 16, 24, 40, 57, 69, 56],
            [14, 17, 22, 29, 51, 87, 80, 62],
            [18, 22, 37, 56, 68, 109, 103, 77],
            [24, 35, 55, 64, 81, 104, 113, 92],
            [49, 64, 78, 87, 103, 121, 120, 101],
            [72, 92, 95, 98, 112, 100, 103, 99],
        ], dtype=np.float32).T))

# Quantization table for the Cb/Cr (chroma) channels: 99 everywhere except the
# top-left 4x4 corner. Shared by CQuantize / CDequantize.
c_table = np.full((8, 8), 99, dtype=np.float32)
c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
c_table = nn.Parameter(torch.from_numpy(c_table))
|
||||
|
||||
|
||||
def diff_round(x):
    """Round ``x`` while keeping a non-zero gradient.

    Returns the rounded value plus the cube of the rounding residual, so the
    result stays close to ``torch.round(x)`` but still depends smoothly on
    ``x``.

    Args:
        x (Tensor): Values to round.

    Returns:
        Tensor: ``round(x) + (x - round(x)) ** 3``.
    """
    nearest = torch.round(x)
    residual = x - nearest
    return nearest + residual**3
|
||||
|
||||
|
||||
def quality_to_factor(quality):
    """Map a JPEG quality setting to the quantization-table scale factor.

    Args:
        quality(float): Quality for jpeg compression. Values below 50 use
            ``5000 / quality``; all others use ``200 - 2 * quality``.

    Returns:
        float: Compression factor (the value above divided by 100).
    """
    scaled = 5000. / quality if quality < 50 else 200. - quality * 2
    return scaled / 100.
|
||||
|
||||
|
||||
# ------------------------ compression ------------------------#
|
||||
class RGB2YCbCrJpeg(nn.Module):
    """Colour transform from RGB to YCbCr (first stage of the JPEG pipeline).

    The 3x3 matrix and the (0, 128, 128) offset are stored as parameters.
    """

    def __init__(self):
        super().__init__()
        coeffs = np.array(
            [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], dtype=np.float32)
        self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
        # Transposed so that a channels-last pixel can be right-multiplied.
        self.matrix = nn.Parameter(torch.from_numpy(coeffs.T))

    def forward(self, image):
        """
        Args:
            image(Tensor): batch x 3 x height x width

        Returns:
            Tensor: batch x height x width x 3
        """
        channels_last = image.permute(0, 2, 3, 1)
        converted = torch.tensordot(channels_last, self.matrix, dims=1) + self.shift
        return converted.view(channels_last.shape)
|
||||
|
||||
|
||||
class ChromaSubsampling(nn.Module):
    """Halve the resolution of the Cb and Cr channels with 2x2 average pooling.

    The Y channel is passed through at full resolution.
    """

    def __init__(self):
        super().__init__()

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width x 3

        Returns:
            y(tensor): batch x height x width
            cb(tensor): batch x height/2 x width/2
            cr(tensor): batch x height/2 x width/2
        """
        channels_first = image.permute(0, 3, 1, 2).clone()
        pooled = []
        for channel in (1, 2):
            plane = channels_first[:, channel, :, :].unsqueeze(1)
            plane = F.avg_pool2d(plane, kernel_size=2, stride=(2, 2), count_include_pad=False)
            # Drop the singleton channel axis again: (b, 1, h/2, w/2) -> (b, h/2, w/2).
            pooled.append(plane.permute(0, 2, 3, 1).squeeze(3))
        cb, cr = pooled
        return image[:, :, :, 0], cb, cr
|
||||
|
||||
|
||||
class BlockSplitting(nn.Module):
    """Cut a single-channel image into non-overlapping k x k blocks (k = 8)."""

    def __init__(self):
        super().__init__()
        self.k = 8

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x (height * width / 64) x 8 x 8, blocks listed in
                row-major order.
        """
        batch_size, height, _ = image.shape[:3]
        size = self.k
        # (b, h/k, k, w/k, k) -> (b, h/k, w/k, k, k): gather each block's rows together.
        grid = image.view(batch_size, height // size, size, -1, size).permute(0, 1, 3, 2, 4)
        return grid.contiguous().view(batch_size, -1, size, size)
|
||||
|
||||
|
||||
class DCT8x8(nn.Module):
    """Discrete cosine transform applied independently to every 8x8 block.

    ``tensor`` holds the cosine basis indexed as [x, y, u, v] and ``scale``
    holds the per-coefficient normalisation (0.25 * alpha_u * alpha_v with
    alpha_0 = 1 / sqrt(2)); both are stored as parameters.
    """

    def __init__(self):
        super(DCT8x8, self).__init__()
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
        self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())

    def forward(self, image):
        """
        Args:
            image(tensor): ... x 8 x 8 blocks of pixel values (the JPEG
                pipeline in this file passes batch x n_blocks x 8 x 8).

        Returns:
            Tensor: DCT coefficients with the same shape as ``image``.
        """
        # Centre the samples around zero before the transform.
        image = image - 128
        # Contracting the two trailing 8x8 axes against the [x, y, u, v] basis
        # already yields the input shape, so no reshape is needed afterwards.
        return self.scale * torch.tensordot(image, self.tensor, dims=2)
|
||||
|
||||
|
||||
class YQuantize(nn.Module):
    """JPEG quantization of Y-channel DCT coefficients.

    Args:
        rounding(function): rounding function applied after the division
    """

    def __init__(self, rounding):
        super().__init__()
        self.rounding = rounding
        self.y_table = y_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x n_blocks x 8 x 8 coefficients
            factor(int | float | Tensor): scale applied to the table; a
                tensor supplies one value per batch element

        Returns:
            Tensor: quantized coefficients, same shape as ``image``
        """
        if isinstance(factor, (int, float)):
            divisor = self.y_table * factor
        else:
            # One factor per sample: broadcast it over that sample's blocks.
            b = factor.size(0)
            divisor = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
        return self.rounding(image.float() / divisor)
|
||||
|
||||
|
||||
class CQuantize(nn.Module):
    """JPEG quantization of Cb/Cr-channel DCT coefficients.

    Args:
        rounding(function): rounding function applied after the division
    """

    def __init__(self, rounding):
        super().__init__()
        self.rounding = rounding
        self.c_table = c_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x n_blocks x 8 x 8 coefficients
            factor(int | float | Tensor): scale applied to the table; a
                tensor supplies one value per batch element

        Returns:
            Tensor: quantized coefficients, same shape as ``image``
        """
        if isinstance(factor, (int, float)):
            divisor = self.c_table * factor
        else:
            # One factor per sample: broadcast it over that sample's blocks.
            b = factor.size(0)
            divisor = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
        return self.rounding(image.float() / divisor)
|
||||
|
||||
|
||||
class CompressJpeg(nn.Module):
    """Forward half of the JPEG pipeline: colour transform, chroma
    subsampling, block splitting, DCT and quantization.

    Args:
        rounding(function): rounding function used by the quantizers
    """

    def __init__(self, rounding=torch.round):
        super().__init__()
        self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
        self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
        self.c_quantize = CQuantize(rounding=rounding)
        self.y_quantize = YQuantize(rounding=rounding)

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x 3 x height x width, RGB in [0, 1]
            factor(int | float | Tensor): quantization-table scale

        Returns:
            tuple(tensor): quantized (y, cb, cr) coefficients, each of shape
                batch x n_blocks x 8 x 8.
        """
        # The colour transform expects 8-bit style values, hence the * 255.
        y, cb, cr = self.l1(image * 255)
        y = self.y_quantize(self.l2(y), factor=factor)
        cb = self.c_quantize(self.l2(cb), factor=factor)
        cr = self.c_quantize(self.l2(cr), factor=factor)
        return y, cb, cr
|
||||
|
||||
|
||||
# ------------------------ decompression ------------------------#
|
||||
|
||||
|
||||
class YDequantize(nn.Module):
    """Undo Y-channel quantization by multiplying with the scaled table."""

    def __init__(self):
        super().__init__()
        self.y_table = y_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x n_blocks x 8 x 8 quantized coefficients
            factor(int | float | Tensor): scale applied to the table; a
                tensor supplies one value per batch element

        Returns:
            Tensor: dequantized coefficients, same shape as ``image``
        """
        if isinstance(factor, (int, float)):
            return image * (self.y_table * factor)
        # One factor per sample: broadcast it over that sample's blocks.
        b = factor.size(0)
        scaled_table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
        return image * scaled_table
|
||||
|
||||
|
||||
class CDequantize(nn.Module):
    """Undo Cb/Cr-channel quantization by multiplying with the scaled table."""

    def __init__(self):
        super().__init__()
        self.c_table = c_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x n_blocks x 8 x 8 quantized coefficients
            factor(int | float | Tensor): scale applied to the table; a
                tensor supplies one value per batch element

        Returns:
            Tensor: dequantized coefficients, same shape as ``image``
        """
        if isinstance(factor, (int, float)):
            return image * (self.c_table * factor)
        # One factor per sample: broadcast it over that sample's blocks.
        b = factor.size(0)
        scaled_table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
        return image * scaled_table
|
||||
|
||||
|
||||
class iDCT8x8(nn.Module):
    """Inverse discrete cosine transform applied to every 8x8 block.

    ``alpha`` holds the per-coefficient normalisation (alpha_u * alpha_v with
    alpha_0 = 1 / sqrt(2)) and ``tensor`` the cosine basis indexed as
    [x, y, u, v]; both are stored as parameters.
    """

    def __init__(self):
        super(iDCT8x8, self).__init__()
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())

    def forward(self, image):
        """
        Args:
            image(tensor): ... x 8 x 8 blocks of DCT coefficients (the JPEG
                pipeline in this file passes batch x n_blocks x 8 x 8).

        Returns:
            Tensor: reconstructed pixel blocks with the same shape as
                ``image``.
        """
        image = image * self.alpha
        # Contracting the two trailing 8x8 axes against the basis already
        # yields the input shape; + 128 undoes the centring done by the DCT.
        return 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
|
||||
|
||||
|
||||
class BlockMerging(nn.Module):
    """Reassemble row-major 8x8 blocks into a full single-channel image."""

    def __init__(self):
        super().__init__()

    def forward(self, patches, height, width):
        """
        Args:
            patches(tensor): batch x (height * width / 64) x 8 x 8
            height(int): height of the merged image
            width(int): width of the merged image

        Returns:
            Tensor: batch x height x width
        """
        block = 8
        n = patches.shape[0]
        grid = patches.view(n, height // block, width // block, block, block)
        # (b, h/8, w/8, 8, 8) -> (b, h/8, 8, w/8, 8): interleave block rows with image rows.
        return grid.permute(0, 1, 3, 2, 4).contiguous().view(n, height, width)
|
||||
|
||||
|
||||
class ChromaUpsampling(nn.Module):
    """Bring the Cb/Cr planes back to full resolution and stack them with Y."""

    def __init__(self):
        super().__init__()

    def forward(self, y, cb, cr):
        """
        Args:
            y(tensor): y channel image, batch x height x width
            cb(tensor): cb channel, batch x height/2 x width/2
            cr(tensor): cr channel, batch x height/2 x width/2

        Returns:
            Tensor: batch x height x width x 3
        """

        def upsample(plane, k=2):
            # Tile every sample k x k times; the final view lays the copies
            # out so that each source value fills a k x k square.
            rows, cols = plane.shape[1:3]
            tiled = plane.unsqueeze(-1).repeat(1, 1, k, k)
            return tiled.view(-1, rows * k, cols * k)

        planes = (y, upsample(cb), upsample(cr))
        return torch.cat([plane.unsqueeze(3) for plane in planes], dim=3)
|
||||
|
||||
|
||||
class YCbCr2RGBJpeg(nn.Module):
    """Colour transform from YCbCr back to RGB (last stage of JPEG decoding).

    The 3x3 matrix and the (0, -128, -128) offset are stored as parameters.
    """

    def __init__(self):
        super().__init__()
        coeffs = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32)
        self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
        # Transposed so that a channels-last pixel can be right-multiplied.
        self.matrix = nn.Parameter(torch.from_numpy(coeffs.T))

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width x 3

        Returns:
            Tensor: batch x 3 x height x width
        """
        # Remove the chroma offset first, then mix the channels.
        centred = image + self.shift
        rgb = torch.tensordot(centred, self.matrix, dims=1)
        return rgb.view(image.shape).permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
class DeCompressJpeg(nn.Module):
    """Inverse half of the JPEG pipeline: dequantization, inverse DCT, block
    merging, chroma upsampling and colour transform.

    Args:
        rounding(function): accepted for symmetry with CompressJpeg; it is
            not used by the decompression steps.
    """

    def __init__(self, rounding=torch.round):
        super().__init__()
        self.c_dequantize = CDequantize()
        self.y_dequantize = YDequantize()
        self.idct = iDCT8x8()
        self.merging = BlockMerging()
        self.chroma = ChromaUpsampling()
        self.colors = YCbCr2RGBJpeg()

    def forward(self, y, cb, cr, imgh, imgw, factor=1):
        """
        Args:
            y(tensor): quantized Y coefficients, batch x n_blocks x 8 x 8
            cb(tensor): quantized Cb coefficients, batch x n_blocks x 8 x 8
            cr(tensor): quantized Cr coefficients, batch x n_blocks x 8 x 8
            imgh(int): height of the (padded) image
            imgw(int): width of the (padded) image
            factor(float): quantization-table scale used during compression

        Returns:
            Tensor: batch x 3 x height x width, RGB in [0, 1]
        """
        # Chroma planes were stored at half resolution.
        half_h, half_w = int(imgh / 2), int(imgw / 2)

        y_plane = self.merging(self.idct(self.y_dequantize(y, factor=factor)), imgh, imgw)
        cb_plane = self.merging(self.idct(self.c_dequantize(cb, factor=factor)), half_h, half_w)
        cr_plane = self.merging(self.idct(self.c_dequantize(cr, factor=factor)), half_h, half_w)

        image = self.colors(self.chroma(y_plane, cb_plane, cr_plane))

        # Clamp to the 8-bit range before scaling back to [0, 1].
        upper = 255 * torch.ones_like(image)
        lower = torch.zeros_like(image)
        image = torch.min(upper, torch.max(lower, image))
        return image / 255
|
||||
|
||||
|
||||
# ------------------------ main DiffJPEG ------------------------ #
|
||||
|
||||
|
||||
class DiffJPEG(nn.Module):
    """Batched JPEG compression + decompression implemented in PyTorch.

    The result is slightly different from cv2's JPEG codec.

    Args:
        differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round
    """

    def __init__(self, differentiable=True):
        super(DiffJPEG, self).__init__()
        rounding = diff_round if differentiable else torch.round

        self.compress = CompressJpeg(rounding=rounding)
        self.decompress = DeCompressJpeg(rounding=rounding)

    def forward(self, x, quality):
        """
        Args:
            x (Tensor): Input image, bchw, rgb, [0, 1]
            quality(float | int | Tensor): Quality factor for jpeg
                compression scheme. A tensor supplies one quality per batch
                element; it is not modified.

        Returns:
            Tensor: The compressed-and-restored image, same shape as ``x``,
                rgb, [0, 1].
        """
        if isinstance(quality, (int, float)):
            factor = quality_to_factor(quality)
        else:
            # Work on a private floating-point copy: writing the factors back
            # into `quality` would corrupt the caller's tensor (a second call
            # with the same tensor would then see factors, not qualities), and
            # an integer tensor would truncate fractional factors.
            factor = quality.clone() if quality.is_floating_point() else quality.float()
            for i in range(factor.size(0)):
                factor[i] = quality_to_factor(factor[i])

        h, w = x.size()[-2:]
        # Pad to multiples of 16: chroma is subsampled by 2 and every plane is
        # then split into 8x8 blocks, so both sides must be divisible by 16.
        h_pad = (16 - h % 16) % 16
        w_pad = (16 - w % 16) % 16
        x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)

        y, cb, cr = self.compress(x, factor=factor)
        recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
        # Drop the padding again.
        recovered = recovered[:, :, 0:h, 0:w]
        return recovered
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual comparison script: compresses 'test.png' with cv2's JPEG codec
    # and with DiffJPEG on the GPU, writing the results as PNG files.
    import cv2

    from basicsr.utils import img2tensor, tensor2img

    reference = cv2.imread('test.png') / 255.

    # -------------- cv2 -------------- #
    jpeg_params = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
    _, encoded = cv2.imencode('.jpg', reference * 255., jpeg_params)
    cv2_result = np.float32(cv2.imdecode(encoded, 1))
    cv2.imwrite('cv2_JPEG_20.png', cv2_result)

    # -------------- DiffJPEG -------------- #
    jpeger = DiffJPEG(differentiable=False).cuda()
    single = img2tensor(reference)
    # Two copies of the image so each can get its own quality setting.
    batch = torch.stack([single, single]).cuda()
    quality = batch.new_tensor([20, 40])
    out = jpeger(batch, quality=quality)

    cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
    cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
|
||||
82
basicsr/utils/dist_util.py
Normal file
82
basicsr/utils/dist_util.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
|
||||
import functools
|
||||
import os
|
||||
import subprocess
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize distributed training for the given launcher.

    Sets the multiprocessing start method to 'spawn' when none has been
    chosen yet, then delegates to the launcher-specific initializer.

    Args:
        launcher (str): Either 'pytorch' or 'slurm'.
        backend (str): Backend of torch.distributed. Default: 'nccl'.
        **kwargs: Forwarded to the launcher-specific initializer.

    Raises:
        ValueError: If ``launcher`` is not a supported launcher type.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')
|
||||
|
||||
|
||||
def _init_dist_pytorch(backend, **kwargs):
    """Initialize the process group for a PyTorch-launched job.

    Reads the global rank from the ``RANK`` environment variable, selects the
    CUDA device ``rank % device_count`` and initializes the process group.

    Args:
        backend (str): Backend of torch.distributed.
        **kwargs: Forwarded to ``dist.init_process_group``.
    """
    global_rank = int(os.environ['RANK'])
    device_index = global_rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)
    dist.init_process_group(backend=backend, **kwargs)
|
||||
|
||||
|
||||
def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    The master port is taken from ``port`` when given, otherwise from an
    existing ``MASTER_PORT`` environment variable, otherwise ``29500``.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    env = os.environ
    proc_id = int(env['SLURM_PROCID'])
    ntasks = int(env['SLURM_NTASKS'])
    node_list = env['SLURM_NODELIST']

    gpu_count = torch.cuda.device_count()
    local_rank = proc_id % gpu_count
    torch.cuda.set_device(local_rank)

    # The first hostname reported by scontrol is used as the master address.
    addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')

    # specify master port
    if port is not None:
        env['MASTER_PORT'] = str(port)
    else:
        # Keep an existing MASTER_PORT; otherwise fall back to 29500,
        # the torch.distributed default port.
        env.setdefault('MASTER_PORT', '29500')

    env['MASTER_ADDR'] = addr
    env['WORLD_SIZE'] = str(ntasks)
    env['LOCAL_RANK'] = str(local_rank)
    env['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
|
||||
|
||||
|
||||
def get_dist_info():
    """Return the rank and world size of the current process.

    Returns:
        tuple[int, int]: ``(rank, world_size)``. Falls back to ``(0, 1)``
        when torch.distributed is unavailable or not initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
|
||||
|
||||
|
||||
def master_only(func):
    """Decorator that runs ``func`` only on the rank-0 process.

    On any other rank the wrapped call does nothing and returns None.

    Args:
        func (callable): Function to guard.

    Returns:
        callable: The wrapped function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if get_dist_info()[0] != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
|
||||
98
basicsr/utils/download_util.py
Normal file
98
basicsr/utils/download_util.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import math
|
||||
import os
|
||||
import requests
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
from tqdm import tqdm
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .misc import sizeof_fmt
|
||||
|
||||
|
||||
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """
    URL = 'https://docs.google.com/uc?export=download'
    session = requests.Session()
    query = {'id': file_id}

    def fetch(**extra):
        # All requests share the session, the query dict and streaming mode.
        return session.get(URL, params=query, stream=True, **extra)

    response = fetch()
    token = get_confirm_token(response)
    if token:
        # Repeat the request with the confirmation token attached.
        query['confirm'] = token
        response = fetch()

    # get file size: a tiny ranged request whose Content-Range header,
    # when present, carries the total size after the '/'.
    probe = fetch(headers={'Range': 'bytes=0-2'})
    content_range = probe.headers.get('Content-Range')
    if content_range is None:
        file_size = None
    else:
        file_size = int(content_range.split('/')[1])

    save_response_content(response, save_path, file_size)
|
||||
|
||||
|
||||
def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, if any.

    Args:
        response: Object whose ``cookies.items()`` yields (name, value) pairs.

    Returns:
        The cookie value, or None when no cookie name starts with
        ``download_warning``.
    """
    matches = (value for name, value in response.cookies.items() if name.startswith('download_warning'))
    return next(matches, None)
|
||||
|
||||
|
||||
def save_response_content(response, destination, file_size=None, chunk_size=32768):
    """Stream a response body to a file, optionally showing progress.

    Args:
        response: Object exposing ``iter_content(chunk_size)``.
        destination (str): Path of the file to write (opened in 'wb' mode).
        file_size (int | None): Total size in bytes. When given, a tqdm
            progress bar is shown; when None, no progress is reported.
        chunk_size (int): Chunk size requested from ``iter_content``.
            Default: 32768.
    """
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    try:
        with open(destination, 'wb') as f:
            downloaded_size = 0
            for chunk in response.iter_content(chunk_size):
                # Count the bytes actually received: the last chunk is usually
                # shorter than chunk_size and keep-alive chunks are empty.
                downloaded_size += len(chunk)
                if pbar is not None:
                    pbar.update(1)
                    pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
    finally:
        # Close the bar even if the download or the write fails.
        if pbar is not None:
            pbar.close()
|
||||
|
||||
|
||||
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, downloading it only if it is not cached.

    Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): Directory in which the file is stored. If None, the
            ``checkpoints`` folder under the pytorch hub dir is used.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name
            in the url. Default: None.

    Returns:
        str: The absolute path to the (possibly already existing) file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        model_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    name_from_url = os.path.basename(urlparse(url).path)
    target_name = name_from_url if file_name is None else file_name
    cached_file = os.path.abspath(os.path.join(model_dir, target_name))

    if os.path.exists(cached_file):
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
|
||||
167
basicsr/utils/file_client.py
Normal file
167
basicsr/utils/file_client.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract base class for storage backends.

    Subclasses must implement ``get()``, which reads a file as a byte
    stream, and ``get_text()``, which reads a file as text.
    """

    @abstractmethod
    def get(self, filepath):
        """Return the content stored at ``filepath`` as bytes."""

    @abstractmethod
    def get_text(self, filepath):
        """Return the content stored at ``filepath`` as text."""
|
||||
|
||||
|
||||
class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            # Extend the import path before trying to import ``mc``.
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError('Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
        # Reusable buffer object that the client fills on every ``get``.
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        """Fetch the value stored under ``filepath`` from memcached."""
        import mc
        key = str(filepath)
        self._client.Get(key, self._mc_buffer)
        return mc.ConvertBuffer(self._mc_buffer)

    def get_text(self, filepath):
        """Text access is not supported by this backend."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class HardDiskBackend(BaseStorageBackend):
    """Storage backend that reads files directly from disk."""

    def get(self, filepath):
        """Return the whole file at ``filepath`` as bytes."""
        with open(str(filepath), 'rb') as stream:
            return stream.read()

    def get_text(self, filepath):
        """Return the whole file at ``filepath`` as a string."""
        with open(str(filepath), 'r') as stream:
            return stream.read()
|
||||
|
||||
|
||||
class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | path-like | list | tuple): Lmdb database path(s).
            A single path is treated as a one-element list.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.

    Attributes:
        db_paths (list[str]): Lmdb database paths.
        _client (dict): Maps each client key to its lmdb env.
    """

    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]

        if isinstance(db_paths, (list, tuple)):
            self.db_paths = [str(v) for v in db_paths]
        else:
            # Any single path (str or path-like). Previously a non-str,
            # non-list value left ``self.db_paths`` undefined.
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {}
        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing different lmdb envs.

        Returns:
            The value returned by the lmdb transaction for the
            ascii-encoded key.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath):
        """Text access is not supported by this backend."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class FileClient(object):
    """A general file client to access files in different backends.

    The client loads a file or text in a specified backend from its path
    and returns it as a binary file.

    Attributes:
        backend (str): The storage backend type. Options are "disk",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    # Maps backend names to the classes that implement them.
    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        if backend not in self._backends:
            supported = list(self._backends.keys())
            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                             f' are {supported}')
        backend_cls = self._backends[backend]
        self.backend = backend
        self.client = backend_cls(**kwargs)

    def get(self, filepath, client_key='default'):
        """Read ``filepath`` through the backend.

        ``client_key`` is forwarded only for the lmdb backend, where it
        selects one of several lmdb environments.
        """
        if self.backend != 'lmdb':
            return self.client.get(filepath)
        return self.client.get(filepath, client_key)

    def get_text(self, filepath):
        """Read ``filepath`` as text through the backend."""
        return self.client.get_text(filepath)
|
||||
170
basicsr/utils/flow_util.py
Normal file
170
basicsr/utils/flow_util.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
    """Read an optical flow map.

    Args:
        flow_path (str): Flow path.
        quantize (bool): whether to read quantized pair, if set to True,
            remaining args will be passed to :func:`dequantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated,
            can be either 0 or 1. Ignored if quantize is False.

    Returns:
        ndarray: Optical flow as a float32 numpy array; (h, w, 2) for
        unquantized files.

    Raises:
        IOError: If the file is not a valid flow file.
    """
    if quantize:
        assert concat_axis in [0, 1]
        packed = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
        if packed.ndim != 2:
            raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {packed.ndim}.')
        assert packed.shape[concat_axis] % 2 == 0
        # dx and dy are stored as the two halves of one image.
        dx, dy = np.split(packed, 2, axis=concat_axis)
        return dequantize_flow(dx, dy, *args, **kwargs).astype(np.float32)

    with open(flow_path, 'rb') as f:
        try:
            header = f.read(4).decode('utf-8')
        except Exception:
            raise IOError(f'Invalid flow file: {flow_path}')
        if header != 'PIEH':
            raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')

        # Layout after the magic: int32 width, int32 height, float32 data.
        w = np.fromfile(f, np.int32, 1).squeeze()
        h = np.fromfile(f, np.int32, 1).squeeze()
        raw = np.fromfile(f, np.float32, w * h * 2)
        flow = raw.reshape((h, w, 2))

    return flow.astype(np.float32)
|
||||
|
||||
|
||||
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
    """Write optical flow to file.

    If the flow is not quantized, it is saved with a 'PIEH' header followed
    by int32 width, int32 height and the float32 flow values. Otherwise dx
    and dy are quantized, concatenated along ``concat_axis`` into a single
    image, and written with ``cv2.imwrite``.

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        filename (str): Output filepath.
        quantize (bool): Whether to quantize the flow and save it as an
            image. If set to True, remaining args will be passed to
            :func:`quantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated,
            can be either 0 or 1. Ignored if quantize is False.
    """
    if not quantize:
        with open(filename, 'wb') as f:
            f.write('PIEH'.encode('utf-8'))
            np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
            flow = flow.astype(np.float32)
            flow.tofile(f)
            f.flush()
    else:
        assert concat_axis in [0, 1]
        dx, dy = quantize_flow(flow, *args, **kwargs)
        dxdy = np.concatenate((dx, dy), axis=concat_axis)
        # abspath() makes a bare filename resolve to the current directory;
        # os.makedirs('') would raise FileNotFoundError.
        out_dir = os.path.abspath(os.path.dirname(filename))
        os.makedirs(out_dir, exist_ok=True)
        cv2.imwrite(filename, dxdy)
|
||||
|
||||
|
||||
def quantize_flow(flow, max_val=0.02, norm=True):
    """Quantize flow to uint8 levels.

    After this step, the size of flow will be much smaller, and can be
    dumped as jpeg images.

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        max_val (float): Maximum value of flow, values beyond
            [-max_val, max_val] will be truncated.
        norm (bool): Whether to divide flow values by image width/height.

    Returns:
        tuple[ndarray]: Quantized dx and dy.
    """
    height, width, _ = flow.shape
    components = []
    # Channel 0 is dx (normalized by width), channel 1 is dy (by height).
    for channel, extent in ((0, width), (1, height)):
        comp = flow[..., channel]
        if norm:
            comp = comp / extent  # avoid inplace operations
        # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
        components.append(quantize(comp, -max_val, max_val, 255, np.uint8))
    return tuple(components)
|
||||
|
||||
|
||||
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
    """Recover from quantized flow.

    Args:
        dx (ndarray): Quantized dx.
        dy (ndarray): Quantized dy.
        max_val (float): Maximum value used when quantizing.
        denorm (bool): Whether to multiply flow values with width/height.

    Returns:
        ndarray: Dequantized flow, dx and dy stacked along the last axis.
    """
    assert dx.shape == dy.shape
    assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)

    flow_x = dequantize(dx, -max_val, max_val, 255)
    flow_y = dequantize(dy, -max_val, max_val, 255)

    if denorm:
        # Both components share one shape: axis 1 is width, axis 0 is height.
        flow_x *= flow_x.shape[1]
        flow_y *= flow_x.shape[0]
    return np.dstack((flow_x, flow_y))
|
||||
|
||||
|
||||
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
    """Quantize an array of (-inf, inf) to [0, levels-1].

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels; must be an int greater than 1.
        dtype (np.type): The type of the quantized array.

    Returns:
        ndarray: Quantized array.

    Raises:
        ValueError: If ``levels`` is invalid or ``min_val >= max_val``.
    """
    if not isinstance(levels, int) or levels <= 1:
        raise ValueError(f'levels must be a positive integer, but got {levels}')
    if min_val >= max_val:
        raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    span = max_val - min_val
    shifted = np.clip(arr, min_val, max_val) - min_val
    bins = np.floor(levels * shifted / span).astype(dtype)
    # A value equal to max_val would land in bin ``levels``; cap it.
    return np.minimum(bins, levels - 1)
|
||||
|
||||
|
||||
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
    """Dequantize an array.

    Each level index is mapped to the centre of its bin in
    [min_val, max_val].

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels; must be an int greater than 1.
        dtype (np.type): The type of the dequantized array.

    Returns:
        ndarray: Dequantized array.

    Raises:
        ValueError: If ``levels`` is invalid or ``min_val >= max_val``.
    """
    if not isinstance(levels, int) or levels <= 1:
        raise ValueError(f'levels must be a positive integer, but got {levels}')
    if min_val >= max_val:
        raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    bin_centres = (arr + 0.5).astype(dtype)
    return bin_centres * (max_val - min_val) / levels + min_val
|
||||
83
basicsr/utils/img_process_util.py
Normal file
83
basicsr/utils/img_process_util.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def filter2D(img, kernel):
    """PyTorch version of cv2.filter2D

    Args:
        img (Tensor): (b, c, h, w)
        kernel (Tensor): (b, k, k) with odd k; a leading size of 1 applies
            the same kernel to every image in the batch.

    Returns:
        Tensor: Filtered images of shape (b, c, h, w).

    Raises:
        ValueError: If the kernel size is even.
    """
    k = kernel.size(-1)
    if k % 2 != 1:
        raise ValueError('Wrong kernel size')

    b, c, h, w = img.size()
    half = k // 2
    # Reflect-pad so the convolution output keeps the input's h and w.
    padded = F.pad(img, (half, half, half, half), mode='reflect')
    ph, pw = padded.size()[-2:]

    if kernel.size(0) == 1:
        # apply the same kernel to all batch images
        flat = padded.view(b * c, 1, ph, pw)
        shared = kernel.view(1, 1, k, k)
        return F.conv2d(flat, shared, padding=0).view(b, c, h, w)

    # One kernel per batch item: fold batch and channel into conv groups.
    flat = padded.view(1, b * c, ph, pw)
    per_channel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
    return F.conv2d(flat, per_channel, groups=b * c).view(b, c, h, w)
|
||||
|
||||
|
||||
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening.

    Input image: I; Blurry image: B.
    1. sharp = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) * 255 > threshold, else: 0
    3. Blur mask with the same Gaussian kernel to get a soft mask.
    4. Out = Mask * sharp + (1 - Mask) * I

    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur; an even value is
            increased by one. Default: 50.
        threshold (int): Residual threshold on the 0-255 scale. Default: 10.

    Returns:
        Numpy array: Sharpened image.
    """
    ksize = radius + 1 if radius % 2 == 0 else radius
    kernel_shape = (ksize, ksize)

    residual = img - cv2.GaussianBlur(img, kernel_shape, 0)
    hard_mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(hard_mask, kernel_shape, 0)

    sharp = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img
|
||||
|
||||
|
||||
class USMSharp(torch.nn.Module):
    """USM sharpening as a torch module, using a fixed Gaussian kernel.

    Args:
        radius (int): Gaussian kernel size; an even value is increased by
            one. Default: 50.
        sigma (float): Sigma passed to ``cv2.getGaussianKernel``. Default: 0.
    """

    def __init__(self, radius=50, sigma=0):
        super().__init__()
        self.radius = radius + 1 if radius % 2 == 0 else radius
        gauss_1d = cv2.getGaussianKernel(self.radius, sigma)
        # Outer product of the 1D kernel with itself gives the 2D kernel.
        gauss_2d = np.dot(gauss_1d, gauss_1d.transpose())
        kernel = torch.FloatTensor(gauss_2d).unsqueeze_(0)
        self.register_buffer('kernel', kernel)

    def forward(self, img, weight=0.5, threshold=10):
        """Sharpen ``img``.

        Args:
            img (Tensor): Input images; passed to ``filter2D``.
            weight (float): Sharp weight. Default: 0.5.
            threshold (int): Residual threshold on the 0-255 scale.
                Default: 10.

        Returns:
            Tensor: Soft-mask blend of the sharpened and original images.
        """
        residual = img - filter2D(img, self.kernel)

        hard_mask = (torch.abs(residual) * 255 > threshold).float()
        soft_mask = filter2D(hard_mask, self.kernel)

        sharp = torch.clip(img + weight * residual, 0, 1)
        return soft_mask * sharp + (1 - soft_mask) * img
|
||||
172
basicsr/utils/img_util.py
Normal file
172
basicsr/utils/img_util.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
|
||||
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images, indexed as (h, w, c).
        bgr2rgb (bool): Whether to change bgr to rgb (applied only to
            3-channel images).
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images in (c, h, w) layout. A list is
        returned for list input, a single tensor otherwise.
    """

    def convert(array):
        if bgr2rgb and array.shape[2] == 3:
            if array.dtype == 'float64':
                # Cast float64 down to float32 before the color conversion.
                array = array.astype('float32')
            array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(array.transpose(2, 0, 1))
        return tensor.float() if float32 else tensor

    if isinstance(imgs, list):
        return [convert(item) for item in imgs]
    return convert(imgs)
|
||||
|
||||
|
||||
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].
    The input tensors are not modified.

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
        ndarray of shape (H x W). The channel order is BGR when ``rgb2bgr``
        is True. A single array is returned for a single-tensor input and a
        list for a list input.

    Raises:
        TypeError: If the input is not a tensor or list of tensors, or a
            tensor has an unsupported number of dimensions.
    """
    # Remember the input kind before wrapping it in a list; checking it after
    # the wrap made single-tensor inputs always return a list.
    single_input = torch.is_tensor(tensor)
    if not (single_input or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    tensors = [tensor] if single_input else tensor
    result = []
    for _tensor in tensors:
        # Out-of-place clamp: squeeze/float/detach/cpu may all return views of
        # the caller's data, so clamp_ would modify the input tensor.
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            # Tile the batch into one grid image.
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)

    if single_input:
        return result[0]
    return result
|
||||
|
||||
|
||||
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
    """This implementation is slightly faster than tensor2img.

    It now only supports torch tensor with shape (1, c, h, w). Values are
    clamped to ``min_max``, rescaled to [0, 255] and cast to uint8 (the cast
    truncates rather than rounds). The input tensor is not modified.

    Args:
        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        ndarray: uint8 image of shape (h, w, c).
    """
    # Out-of-place clamp: detach() shares storage with the input, so clamp_
    # would overwrite the caller's tensor.
    output = tensor.squeeze(0).detach().clamp(*min_max).permute(1, 2, 0)
    output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
    output = output.type(torch.uint8).cpu().numpy()
    if rgb2bgr:
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
    return output
|
||||
|
||||
|
||||
def imfrombytes(content, flag='color', float32=False):
    """Read an image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, the values
            are also divided by 255. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    imread_flags = {
        'color': cv2.IMREAD_COLOR,
        'grayscale': cv2.IMREAD_GRAYSCALE,
        'unchanged': cv2.IMREAD_UNCHANGED,
    }
    encoded = np.frombuffer(content, np.uint8)
    decoded = cv2.imdecode(encoded, imread_flags[flag])
    if not float32:
        return decoded
    return decoded.astype(np.float32) / 255.
|
||||
|
||||
|
||||
def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Raises:
        IOError: If ``cv2.imwrite`` reports failure.
    """
    if auto_mkdir:
        parent = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(parent, exist_ok=True)
    written = cv2.imwrite(file_path, img, params)
    if written:
        return
    raise IOError('Failed in writing images.')
|
||||
|
||||
|
||||
def crop_border(imgs, crop_border):
    """Crop borders of images.

    Args:
        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
        crop_border (int): Crop border for each end of height and width.

    Returns:
        list[ndarray] | ndarray: Cropped images, in the same container kind
        as the input. The input is returned unchanged when ``crop_border``
        is 0.
    """
    if crop_border == 0:
        return imgs

    # The same window is applied to the first two axes.
    window = slice(crop_border, -crop_border)
    if isinstance(imgs, list):
        return [v[window, window, ...] for v in imgs]
    return imgs[window, window, ...]
|
||||
199
basicsr/utils/lmdb_util.py
Normal file
199
basicsr/utils/lmdb_util.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import cv2
|
||||
import lmdb
|
||||
import sys
|
||||
from multiprocessing import Pool
|
||||
from os import path as osp
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def make_lmdb_from_imgs(data_path,
                        lmdb_path,
                        img_path_list,
                        keys,
                        batch=5000,
                        compress_level=1,
                        multiprocessing_read=False,
                        n_thread=40,
                        map_size=None):
    """Make lmdb from images.

    Contents of lmdb. The file structure is:

    ::

        example.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records 1)image name (with extension),
    2)image shape, and 3)compression level, separated by a white space.

    For example, the meta information could be:
    `000_00000000.png (720,1280,3) 1`, which means:
    1) image name (with extension): 000_00000000.png;
    2) image shape: (720,1280,3);
    3) compression level: 1

    The entries of ``keys`` are used as the lmdb keys.

    If `multiprocessing_read` is True, it will read all the images to memory
    using multiprocessing. Thus, your server needs to have enough memory.

    Args:
        data_path (str): Data path for reading images.
        lmdb_path (str): Lmdb save path. Must end with '.lmdb' and must not
            exist yet (the process exits with status 1 if it does).
        img_path_list (list[str]): Image path list, relative to ``data_path``.
        keys (list[str]): Used for lmdb keys.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
        multiprocessing_read (bool): Whether use multiprocessing to read all
            the images to memory. Default: False.
        n_thread (int): Number of worker processes for multiprocessing.
        map_size (int | None): Map size for lmdb env. If None, use ten times
            the encoded size of the first image multiplied by the number of
            images. Default: None

    Raises:
        ValueError: If ``lmdb_path`` does not end with '.lmdb'.
    """

    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
                                             f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Totoal images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):
        print(f'Folder {lmdb_path} already exists. Exit.')
        sys.exit(1)

    if multiprocessing_read:
        # read all the images to memory (multiprocessing)
        dataset = {}  # use dict to keep the order for multiprocessing
        shapes = {}
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = tqdm(total=len(img_path_list), unit='image')

        def callback(arg):
            """get the image data and update pbar."""
            key, dataset[key], shapes[key] = arg
            pbar.update(1)
            pbar.set_description(f'Read {key}')

        pool = Pool(n_thread)
        for path, key in zip(img_path_list, keys):
            pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
        pool.close()
        pool.join()
        pbar.close()
        print(f'Finish reading {len(img_path_list)} images.')

    # create lmdb environment
    if map_size is None:
        # obtain data size for one image
        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)
        map_size = data_size * 10

    env = lmdb.open(lmdb_path, map_size=map_size)

    # write data to lmdb
    pbar = tqdm(total=len(img_path_list), unit='chunk')
    try:
        # The context manager closes meta_info.txt even if encoding or an
        # lmdb write raises part-way through.
        with open(osp.join(lmdb_path, 'meta_info.txt'), 'w') as txt_file:
            txn = env.begin(write=True)
            for idx, (path, key) in enumerate(zip(img_path_list, keys)):
                pbar.update(1)
                pbar.set_description(f'Write {key}')
                key_byte = key.encode('ascii')
                if multiprocessing_read:
                    img_byte = dataset[key]
                    h, w, c = shapes[key]
                else:
                    _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
                    h, w, c = img_shape

                txn.put(key_byte, img_byte)
                # write meta information
                txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
                if idx % batch == 0:
                    txn.commit()
                    txn = env.begin(write=True)
            txn.commit()
    finally:
        # Release the progress bar and the lmdb environment on every path.
        pbar.close()
        env.close()
    print('\nFinish writing lmdb.')
|
||||
|
||||
|
||||
def read_img_worker(path, key, compress_level):
    """Read one image from disk and re-encode it as PNG bytes.

    Args:
        path (str): Image path.
        key (str): Image key. It is returned unchanged so that callers can
            match the result to its request.
        compress_level (int): Compress level when encoding images (passed as
            ``cv2.IMWRITE_PNG_COMPRESSION``).

    Returns:
        str: Image key.
        byte: Image byte (the PNG-encoded buffer returned by ``cv2.imencode``).
        tuple[int]: Image shape as ``(h, w, c)``; ``c`` is 1 for 2-D images.

    Raises:
        IOError: If the image at ``path`` cannot be read.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface as an opaque AttributeError below.
        raise IOError(f'Failed to read image: {path}')

    if img.ndim == 2:
        h, w = img.shape
        c = 1
    else:
        h, w, c = img.shape
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (h, w, c))
|
||||
|
||||
|
||||
class LmdbMaker():
    """Incremental writer that stores encoded images in an LMDB folder.

    Alongside the database a ``meta_info.txt`` file is written with one line
    per stored image: ``<key>.png (h,w,c) <compress_level>``.

    Args:
        lmdb_path (str): Lmdb save path. Must end with '.lmdb' and must not
            exist yet.
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level recorded in the meta file.
            Default: 1.
    """

    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            # refuse to overwrite an existing database
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        self.txn = self.env.begin(write=True)
        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
        self.counter = 0

    def put(self, img_byte, key, img_shape):
        """Store one encoded image and append its meta-information line."""
        self.counter += 1
        self.txn.put(key.encode('ascii'), img_byte)

        # write meta information
        height, width, channels = img_shape
        meta_line = f'{key}.png ({height},{width},{channels}) {self.compress_level}\n'
        self.txt_file.write(meta_line)

        batch_is_full = self.counter % self.batch == 0
        if batch_is_full:
            # flush the current transaction and start a fresh one
            self.txn.commit()
            self.txn = self.env.begin(write=True)

    def close(self):
        """Commit pending writes, then close the environment and meta file."""
        self.txn.commit()
        self.env.close()
        self.txt_file.close()
|
||||
213
basicsr/utils/logger.py
Normal file
213
basicsr/utils/logger.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
from .dist_util import get_dist_info, master_only
|
||||
|
||||
initialized_logger = {}
|
||||
|
||||
|
||||
class AvgTimer():
    """Timer that tracks the latest interval and a windowed running average.

    Args:
        window (int): Once more than ``window`` intervals have been
            accumulated, the accumulator is cleared. Default: 200.
    """

    def __init__(self, window=200):
        self.window = window  # average window
        self.current_time = 0
        self.total_time = 0
        self.count = 0
        self.avg_time = 0
        self.start()

    def start(self):
        """Mark the beginning of timing."""
        now = time.time()
        self.tic = now
        self.start_time = now

    def record(self):
        """Close the current interval, update the average and open a new one."""
        self.count += 1
        self.toc = time.time()
        elapsed = self.toc - self.tic
        self.current_time = elapsed
        self.total_time += elapsed
        # calculate average time
        self.avg_time = self.total_time / self.count

        # reset the accumulator once the window is exceeded
        if self.count > self.window:
            self.count, self.total_time = 0, 0

        self.tic = time.time()

    def get_current_time(self):
        """Return the duration of the most recently recorded interval."""
        return self.current_time

    def get_avg_time(self):
        """Return the average interval computed at the last ``record`` call."""
        return self.avg_time
|
||||
|
||||
|
||||
class MessageLogger():
    """Message logger for printing.

    Args:
        opt (dict): Config. It contains the following keys:
            name (str): Exp name.
            logger (dict): Contains 'print_freq' for the logger interval and
                'use_tb_logger' (whether to also write to tensorboard).
            train (dict): Contains 'total_iter' (int) for total iters.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """

    def __init__(self, opt, start_iter=1, tb_logger=None):
        self.exp_name = opt['name']
        logger_opt = opt['logger']
        self.interval = logger_opt['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.use_tb_logger = logger_opt['use_tb_logger']
        self.tb_logger = tb_logger
        self.start_time = time.time()
        self.logger = get_root_logger()

    def reset_start_time(self):
        """Restart the clock used for the ETA estimate."""
        self.start_time = time.time()

    @master_only
    def __call__(self, log_vars):
        """Format and emit one logging message.

        Note that the consumed keys are popped from ``log_vars``.

        Args:
            log_vars (dict): It contains the following keys:
                epoch (int): Epoch number.
                iter (int): Current iter.
                lrs (list): List for learning rates.

                time (float): Iter time (optional).
                data_time (float): Data time for each iter; read only when
                    'time' is present.
                Every remaining item is printed as ``key: value``.
        """
        # epoch, iter, learning rates
        epoch = log_vars.pop('epoch')
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        lr_text = ''.join(f'{lr:.3e},' for lr in lrs)
        pieces = [f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:({lr_text})] ']

        # time and estimated time
        if 'time' in log_vars:
            iter_time = log_vars.pop('time')
            data_time = log_vars.pop('data_time')

            elapsed = time.time() - self.start_time
            sec_per_iter = elapsed / (current_iter - self.start_iter + 1)
            eta_sec = sec_per_iter * (self.max_iters - current_iter - 1)
            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
            pieces.append(f'[eta: {eta_str}, ')
            pieces.append(f'time (data): {iter_time:.3f} ({data_time:.3f})] ')

        # other items, especially losses
        write_tb = self.use_tb_logger and 'debug' not in self.exp_name
        for k, v in log_vars.items():
            pieces.append(f'{k}: {v:.4e} ')
            if write_tb:
                # keys starting with 'l_' are grouped under 'losses/'
                tag = f'losses/{k}' if k.startswith('l_') else k
                self.tb_logger.add_scalar(tag, v, current_iter)
        self.logger.info(''.join(pieces))
|
||||
|
||||
|
||||
@master_only
def init_tb_logger(log_dir):
    """Create a tensorboard ``SummaryWriter`` that writes to ``log_dir``.

    Args:
        log_dir (str): Directory passed to ``SummaryWriter``.

    Returns:
        SummaryWriter: The created writer.
    """
    # imported inside the function body rather than at module level
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(log_dir=log_dir)
|
||||
|
||||
|
||||
@master_only
def init_wandb_logger(opt):
    """We now only use wandb to sync tensorboard log.

    Args:
        opt (dict): Config. Reads ``opt['name']`` and
            ``opt['logger']['wandb']`` ('project' is required, 'resume_id' is
            optional); the whole dict is passed to wandb as the run config.
    """
    import wandb
    logger = get_root_logger()

    wandb_opt = opt['logger']['wandb']
    project = wandb_opt['project']
    resume_id = wandb_opt.get('resume_id')
    if resume_id:
        # continue an earlier run
        wandb_id, resume = resume_id, 'allow'
        logger.warning(f'Resume wandb logger with id={wandb_id}.')
    else:
        wandb_id, resume = wandb.util.generate_id(), 'never'

    wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)

    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
|
||||
|
||||
|
||||
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added.

    Initialization happens only once per ``logger_name``: later calls return
    the existing logger and ignore ``log_level`` and ``log_file``.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    logger.addHandler(stream_handler)
    logger.propagate = False

    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    else:
        # Apply the requested level on rank 0 even without a log file.
        # Previously it was only set together with the file handler, so a
        # console-only logger fell back to the inherited default level and
        # dropped INFO messages.
        logger.setLevel(log_level)
        if log_file is not None:
            # add file handler
            file_handler = logging.FileHandler(log_file, 'w')
            file_handler.setFormatter(logging.Formatter(format_str))
            file_handler.setLevel(log_level)
            logger.addHandler(file_handler)
    initialized_logger[logger_name] = True
    return logger
|
||||
|
||||
|
||||
def get_env_info():
    """Get environment information.

    Currently, only log the software version.

    Returns:
        str: An ASCII-art banner followed by the BasicSR, PyTorch and
            TorchVision version strings, one per line.
    """
    # these imports are local to the function
    import torch
    import torchvision

    from basicsr.version import __version__
    # NOTE(review): the banner's internal spacing looks collapsed (it does not
    # line up as ASCII art) -- confirm against the upstream file before
    # relying on its appearance.
    msg = r"""
____ _ _____ ____
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
/ __ |/ __ `// ___// // ___/\__ \ / /_/ /
/ /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
/_____/ \__,_//____//_/ \___//____//_/ |_|
______ __ __ __ __
/ ____/____ ____ ____/ / / / __ __ _____ / /__ / /
/ / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
"""
    # append one tab-indented line per package version
    msg += ('\nVersion Information: '
            f'\n\tBasicSR: {__version__}'
            f'\n\tPyTorch: {torch.__version__}'
            f'\n\tTorchVision: {torchvision.__version__}')
    return msg
|
||||
178
basicsr/utils/matlab_functions.py
Normal file
178
basicsr/utils/matlab_functions.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def cubic(x):
    """Piecewise-cubic kernel used by calculate_weights_indices.

    With ``a = |x|`` the value is ``1.5a^3 - 2.5a^2 + 1`` for ``a <= 1``,
    ``-0.5a^3 + 2.5a^2 - 4a + 2`` for ``1 < a <= 2`` and 0 elsewhere.

    Args:
        x (Tensor): Distances at which the kernel is evaluated.

    Returns:
        Tensor: Kernel values, same shape as ``x``.
    """
    dist = torch.abs(x)
    dist_sq = dist**2
    dist_cu = dist**3

    near_mask = (dist <= 1).type_as(dist)
    far_mask = ((dist > 1) * (dist <= 2)).type_as(dist)

    near_poly = 1.5 * dist_cu - 2.5 * dist_sq + 1
    far_poly = -0.5 * dist_cu + 2.5 * dist_sq - 4 * dist + 2
    return near_poly * near_mask + far_poly * far_mask
|
||||
|
||||
|
||||
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Calculate weights and indices, used for imresize function.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel (str): Kernel name. It is not read in this function; the cubic
            kernel is always applied.
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.

    Returns:
        tuple: ``(weights, indices, sym_len_s, sym_len_e)``. ``weights`` and
            ``indices`` each have one row per output position; ``indices`` is
            shifted so that its minimum is 0. ``sym_len_s`` and ``sym_len_e``
            are ints that ``imresize`` uses as the number of samples to mirror
            before the start and after the end of the input.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. only consider the
    # first and last column.
    # NOTE(review): weights_zero_tmp holds the per-column COUNT of zero
    # entries, so the tests below fire as soon as a column contains any zero,
    # not only when the whole column is zero -- confirm this is intended.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # amount by which the indices run past the start / the end of the input
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    # shift so the smallest index becomes 0
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
|
||||
|
||||
|
||||
@torch.no_grad()
def imresize(img, scale, antialiasing=True):
    """imresize function same as MATLAB.

    It now only supports bicubic.
    The same scale applies for both height and width.

    Args:
        img (Tensor | Numpy array):
            Tensor: Input image with shape (c, h, w), [0, 1] range.
            Numpy: Input image with shape (h, w, c), [0, 1] range.
            A 2-D input is also accepted; a channel axis is added internally
            and removed again before returning.
        scale (float): Scale factor. The same scale applies for both height
            and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
            Default: True.

    Returns:
        Tensor | Numpy array: Output image, [0, 1] range, w/o round. The type
            follows the input: a Tensor with shape (c, h, w) for Tensor input,
            a Numpy array with shape (h, w, c) for Numpy input. The output
            height and width are ``ceil(in_h * scale)`` and
            ``ceil(in_w * scale)``.
    """
    squeeze_flag = False
    if type(img).__module__ == np.__name__:  # numpy type
        numpy_type = True
        if img.ndim == 2:
            img = img[:, :, None]
            squeeze_flag = True
        # (h, w, c) -> (c, h, w) float tensor
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    else:
        numpy_type = False
        if img.ndim == 2:
            img = img.unsqueeze(0)
            squeeze_flag = True

    in_c, in_h, in_w = img.size()
    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
    kernel_width = 4
    kernel = 'cubic'

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
                                                                             antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
                                                                             antialiasing)
    # process H dimension
    # symmetric copying
    # the buffer is allocated uninitialized; the three copies below fill it
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)

    # mirror the first sym_len_hs rows in front of the image
    sym_patch = img[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    # mirror the last sym_len_he rows behind the image
    sym_patch = img[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        # first padded row contributing to output row i
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    # mirror the first sym_len_ws columns
    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    # mirror the last sym_len_we columns
    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        # first padded column contributing to output column i
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    # undo the input normalisation: drop the added channel axis and, for
    # numpy input, go back to a (h, w, c) array
    if squeeze_flag:
        out_2 = out_2.squeeze(0)
    if numpy_type:
        out_2 = out_2.numpy()
        if not squeeze_flag:
            out_2 = out_2.transpose(1, 2, 0)

    return out_2
|
||||
141
basicsr/utils/misc.py
Normal file
141
basicsr/utils/misc.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import torch
|
||||
from os import path as osp
|
||||
|
||||
from .dist_util import master_only
|
||||
|
||||
|
||||
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA).

    Args:
        seed (int): Seed value handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
||||
|
||||
|
||||
def get_time_str():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
|
||||
|
||||
|
||||
def mkdir_and_rename(path):
    """Create ``path``; an existing path is first renamed with a timestamp.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        # move the old folder out of the way instead of overwriting it
        archived = ''.join((path, '_archived_', get_time_str()))
        print(f'Path already exists. Rename it to {archived}', flush=True)
        os.rename(path, archived)
    os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
@master_only
def make_exp_dirs(opt):
    """Make dirs for experiments.

    The root folder ('experiments_root' when training, 'results_root'
    otherwise) is created through ``mkdir_and_rename``. Every other entry of
    ``opt['path']`` is created as a directory, except keys containing
    'strict_load', 'pretrain_network', 'resume' or 'param_key'.

    Args:
        opt (dict): Options. Reads ``opt['path']`` and ``opt['is_train']``.
    """
    path_opt = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(path_opt.pop(root_key))

    # these entries are settings, not folders to create
    skipped_markers = ('strict_load', 'pretrain_network', 'resume', 'param_key')
    for key, path in path_opt.items():
        if any(marker in key for marker in skipped_markers):
            continue
        os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Hidden files (names starting with '.') are never returned.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths
        (relative to ``dir_path`` unless ``full_path`` is set).

    Raises:
        TypeError: If ``suffix`` is neither None, a string nor a tuple.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if entry.is_file():
                if entry.name.startswith('.'):
                    # Hidden files are skipped. They used to fall through to
                    # the recursive branch, where os.scandir() on a regular
                    # file raised NotADirectoryError.
                    continue
                return_path = entry.path if full_path else osp.relpath(entry.path, root)
                if suffix is None or return_path.endswith(suffix):
                    yield return_path
            elif recursive and entry.is_dir():
                # only descend into real directories
                yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
||||
|
||||
|
||||
def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    When ``opt['path']['resume_state']`` is set, the ``pretrain_network_*``
    paths are pointed at the checkpoints of ``resume_iter`` (networks listed
    in 'ignore_resume_networks' are left alone) and every ``param_key*`` equal
    to 'params_ema' is switched to 'params'. ``opt`` is modified in place.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    path_opt = opt['path']
    if not path_opt['resume_state']:
        return

    # get all the networks
    networks = [key for key in opt.keys() if key.startswith('network_')]
    if any(path_opt.get(f'pretrain_{network}') is not None for network in networks):
        print('pretrain_network path will be ignored during resuming.')

    # set pretrained model paths
    for network in networks:
        name = f'pretrain_{network}'
        basename = network.replace('network_', '')
        ignored = path_opt.get('ignore_resume_networks')
        if ignored is not None and network in ignored:
            continue
        path_opt[name] = osp.join(path_opt['models'], f'net_{basename}_{resume_iter}.pth')
        print(f"Set {name} to {path_opt[name]}")

    # change param_key to params in resume
    for param_key in [key for key in path_opt.keys() if key.startswith('param_key')]:
        if path_opt[param_key] == 'params_ema':
            path_opt[param_key] = 'params'
            print(f'Set {param_key} to params')
|
||||
|
||||
|
||||
def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    The value is divided by 1024 until its magnitude drops below 1024 and is
    then printed with one decimal and the matching prefix.

    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size.
    """
    remaining = size
    for prefix in ('', 'K', 'M', 'G', 'T', 'P', 'E', 'Z'):
        if abs(remaining) < 1024.0:
            return f'{remaining:3.1f} {prefix}{suffix}'
        remaining = remaining / 1024.0
    # anything still >= 1024 after 'Z' is reported with the 'Y' prefix
    return f'{remaining:3.1f} Y{suffix}'
|
||||
210
basicsr/utils/options.py
Normal file
210
basicsr/utils/options.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
from os import path as osp
|
||||
|
||||
from basicsr.utils import set_random_seed
|
||||
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
|
||||
|
||||
|
||||
def ordered_yaml():
    """Support OrderedDict for yaml.

    Registers an OrderedDict representer on the Dumper and a mapping
    constructor on the Loader, so that YAML mappings load as OrderedDict. The
    C-accelerated classes are used when they can be imported.

    Returns:
        tuple: yaml Loader and Dumper.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    def _represent_ordered(dumper, data):
        return dumper.represent_dict(data.items())

    def _construct_ordered(loader, node):
        return OrderedDict(loader.construct_pairs(node))

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
    Dumper.add_representer(OrderedDict, _represent_ordered)
    Loader.add_constructor(mapping_tag, _construct_ordered)
    return Loader, Dumper
|
||||
|
||||
|
||||
def yaml_load(f):
    """Load yaml file or string.

    Args:
        f (str): File path or a python string. If it names an existing file
            the file is parsed, otherwise the string itself is parsed as YAML.

    Returns:
        dict: Loaded dict.
    """
    # NOTE(review): this uses the full (non-safe) yaml Loader; confirm that
    # only trusted option files/strings reach this function.
    if not os.path.isfile(f):
        return yaml.load(f, Loader=ordered_yaml()[0])
    with open(f, 'r') as stream:
        loader_cls = ordered_yaml()[0]
        return yaml.load(stream, Loader=loader_cls)
|
||||
|
||||
|
||||
def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Nested dicts are rendered recursively between ``key:[`` and ``]`` with two
    more spaces of indentation per level. Keys of any type are accepted
    (they are converted with ``str``), e.g. integer keys from a YAML file.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing.
    """
    indent = ' ' * (indent_level * 2)
    # collect the pieces and join once instead of repeated string +=
    parts = ['\n']
    for k, v in opt.items():
        if isinstance(v, dict):
            parts.append(f'{indent}{k!s}:[')
            parts.append(dict2str(v, indent_level + 1))
            parts.append(f'{indent}]\n')
        else:
            parts.append(f'{indent}{k!s}: {v!s}\n')
    return ''.join(parts)
|
||||
|
||||
|
||||
def _postprocess_yml_value(value):
|
||||
# None
|
||||
if value == '~' or value.lower() == 'none':
|
||||
return None
|
||||
# bool
|
||||
if value.lower() == 'true':
|
||||
return True
|
||||
elif value.lower() == 'false':
|
||||
return False
|
||||
# !!float number
|
||||
if value.startswith('!!float'):
|
||||
return float(value.replace('!!float', ''))
|
||||
# number
|
||||
if value.isdigit():
|
||||
return int(value)
|
||||
elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
|
||||
return float(value)
|
||||
# list
|
||||
if value.startswith('['):
|
||||
return eval(value)
|
||||
# str
|
||||
return value
|
||||
|
||||
|
||||
def parse_options(root_path, is_train=True):
    """Parse the command line and build the option dict.

    The YAML file given by ``-opt`` is loaded and then completed with the
    distributed settings, the random seed, ``--force_yml`` overrides, dataset
    defaults and the experiment/result paths below ``root_path``.

    Args:
        root_path (str): Folder under which 'experiments/<name>' (training)
            or 'results/<name>' (testing) is placed.
        is_train (bool): Whether the options are for training. Default: True.

    Returns:
        tuple: ``(opt, args)`` -- the completed option dict and the parsed
            argparse namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
    args = parser.parse_args()

    # parse yml to dict
    opt = yaml_load(args.opt)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    # offset by rank so that each process gets a different seed
    set_random_seed(seed + opt['rank'])

    # force to update yml options
    if args.force_yml is not None:
        for entry in args.force_yml:
            # now do not support creating new keys
            # split on the first '=' only, so the value may itself contain '='
            keys, value = entry.split('=', 1)
            keys, value = keys.strip(), value.strip()
            value = _postprocess_yml_value(value)
            # Walk the nested dicts directly. The previous implementation
            # built an assignment statement from the key text and ran it with
            # exec(), which allowed crafted keys to execute arbitrary code.
            *parent_keys, leaf_key = keys.split(':')
            target = opt
            for key in parent_keys:
                target = target[key]
            target[leaf_key] = value

    opt['auto_resume'] = args.auto_resume
    opt['is_train'] = is_train

    # debug setting
    if args.debug and not opt['name'].startswith('debug'):
        opt['name'] = 'debug_' + opt['name']

    if opt['num_gpu'] == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()

    # datasets
    for phase, dataset in opt['datasets'].items():
        # for multiple datasets, e.g., val_1, val_2; test_1, test_2
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # paths
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')

        # change some options for debug mode
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(root_path, 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt, args
|
||||
|
||||
|
||||
@master_only
def copy_opt_file(opt_file, experiments_root):
    """Copy the option file into the experiment root and stamp it.

    The copy gets a comment header with the generation time and the command
    line (``sys.argv``) prepended.

    Args:
        opt_file (str): Path of the option yml file.
        experiments_root (str): Folder that receives the copy.
    """
    import sys
    import time
    from shutil import copyfile

    command_line = ' '.join(sys.argv)
    target = osp.join(experiments_root, osp.basename(opt_file))
    copyfile(opt_file, target)

    with open(target, 'r+') as f:
        original_text = f.read()
        header = f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {command_line}\n\n'
        # rewrite from the start: header first, then the original content
        f.seek(0)
        f.write(header + original_text)
|
||||
83
basicsr/utils/plot_util.py
Normal file
83
basicsr/utils/plot_util.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import re
|
||||
|
||||
|
||||
def read_data_from_tensorboard(log_path, tag):
    """Get raw data (steps and values) from tensorboard events.

    The list of available scalar tags is printed as a side effect.

    Args:
        log_path (str): Path to the tensorboard log.
        tag (str): tag to be read.

    Returns:
        tuple[list]: ``(steps, values)`` of the scalar events for ``tag``.
    """
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    # tensorboard event
    accumulator = EventAccumulator(log_path)
    accumulator.Reload()
    print('tag list: ', accumulator.Tags()['scalars'])

    events = accumulator.Scalars(tag)
    steps = [int(event.step) for event in events]
    values = [event.value for event in events]
    return steps, values
|
||||
|
||||
|
||||
def read_data_from_txt_2v(path, pattern, step_one=False):
    """Read data from txt with 2 returned values (usually [step, value]).

    Each stripped line is matched from its start against ``pattern``; group 1
    is read as an int step and group 2 as a float value.

    Args:
        path (str): path to the txt file.
        pattern (str): re (regular expression) pattern.
        step_one (bool): add 1 to steps. Default: False.

    Returns:
        tuple[list]: ``(steps, values)`` for the matching lines, in file order.
    """
    with open(path) as f:
        stripped_lines = [raw.strip() for raw in f.readlines()]

    regex = re.compile(pattern)
    steps, values = [], []
    for text in stripped_lines:
        found = regex.match(text)
        if not found:
            continue
        steps.append(int(found.group(1)))
        values.append(float(found.group(2)))

    if step_one:
        steps = [step + 1 for step in steps]
    return steps, values
|
||||
|
||||
|
||||
def read_data_from_txt_1v(path, pattern):
    """Read data from txt with 1 returned values.

    Each stripped line is matched from its start against ``pattern``; group 1
    of every matching line is read as a float.

    Args:
        path (str): path to the txt file.
        pattern (str): re (regular expression) pattern.

    Returns:
        list[float]: Values of the matching lines, in file order.
    """
    with open(path) as f:
        stripped_lines = [raw.strip() for raw in f.readlines()]

    regex = re.compile(pattern)
    data = []
    for text in stripped_lines:
        found = regex.match(text)
        if found:
            data.append(float(found.group(1)))
    return data
|
||||
|
||||
|
||||
def smooth_data(values, smooth_weight):
    """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does).

    Each output is ``previous_output * smooth_weight + (1 - smooth_weight) *
    value``, with the first input value used as the initial previous output.

    Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704  # noqa: E501

    Args:
        values (list): A list of values to be smoothed.
        smooth_weight (float): Smooth weight.

    Returns:
        list: The smoothed values, same length as ``values`` (an empty list
            for empty input).
    """
    values_sm = []
    if len(values) == 0:
        # nothing to smooth; previously this raised IndexError on values[0]
        return values_sm

    last_sm_value = values[0]
    for value in values:
        value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value
        values_sm.append(value_sm)
        last_sm_value = value_sm
    return values_sm
|
||||
293
basicsr/utils/realesrgan_utils.py
Normal file
293
basicsr/utils/realesrgan_utils.py
Normal file
@@ -0,0 +1,293 @@
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from torch.nn import functional as F
|
||||
|
||||
# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
        model (nn.Module): The defined network. Although the default is None, a network must be supplied
            because the checkpoint weights are loaded into it.
        tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
            input images into tiles, and then process each of them. Finally, they will be merged into one image.
            0 denotes for do not use tile. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
        device (torch.device | None): Device to run on. When given it takes precedence over ``gpu_id``.
            Default: None (pick CUDA when available, otherwise CPU).
        gpu_id (int | None): Index of the CUDA device to use when ``device`` is None. Default: None.

    Raises:
        ValueError: If ``model`` is None.
    """

    def __init__(self,
                 scale,
                 model_path,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        # Fail early with a clear message instead of an AttributeError on
        # `None.load_state_dict` after the checkpoint has been read.
        if model is None:
            raise ValueError('RealESRGANer requires a network instance: `model` must not be None.')

        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if device is not None:
            self.device = device
        elif gpu_id is not None and torch.cuda.is_available():
            # `is not None` (rather than truthiness) so that GPU index 0 is honored.
            self.device = torch.device(f'cuda:{gpu_id}')
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # if the model_path starts with https, it will first download models to the folder: weights/realesrgan
        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)

        # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer to use params_ema
        keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible.

        Converts the HWC numpy image into a 1xCxHxW tensor on ``self.device`` and stores it in ``self.img``.

        Args:
            img (ndarray): Image with shape (h, w, c).
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad (right and bottom only)
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        """Run the network on the whole pre-processed image and store the result in ``self.output``."""
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one images.

        Modified from: https://github.com/ata4/esrgan-launcher

        Raises:
            RuntimeError: Re-raised (after being printed) when the network fails on a tile, so that a
                partially filled output is never returned as if it were valid.
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    # Without re-raising, `output_tile` would be unbound (first tile) or
                    # left over from the previous tile, silently corrupting the result.
                    raise

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        """Crop away the (upscaled) mod padding and pre-padding and return ``self.output``."""
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    def _upscale_to_bgr(self, img):
        """Run pre-process, (tiled) inference and post-process on a 3-channel image.

        Returns the result as an HxWx3 float numpy array clamped to [0, 1] with the channel order reversed.
        """
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        result = self.post_process()
        result = result.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        return np.transpose(result[[2, 1, 0], :, :], (1, 2, 0))

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        """Upsample an image.

        Args:
            img (ndarray): Input image: 2-D (gray), or 3-D with 3 or 4 channels. Color inputs are converted
                with ``cv2.COLOR_BGR2RGB`` before inference. Values above 256 mark the input as 16-bit.
            outscale (float | None): Final scale relative to the input size. When it differs from the
                network scale the result is resized with Lanczos interpolation. Default: None.
            alpha_upsampler (str): 'realesrgan' upsamples the alpha channel with the network; any other
                value uses ``cv2.resize`` with linear interpolation. Default: 'realesrgan'.

        Returns:
            tuple: ``(output, img_mode)`` where ``output`` is a uint8 (or uint16 for 16-bit inputs) array
                and ``img_mode`` is one of 'L', 'RGB' or 'RGBA'.
        """
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        output_img = self._upscale_to_bgr(img)
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                output_alpha = self._upscale_to_bgr(alpha)
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
|
||||
|
||||
|
||||
class PrefetchReader(threading.Thread):
    """Background thread that reads images ahead of the consumer.

    The thread pushes each decoded image into a bounded queue and finishes by
    pushing ``None``. Iterating the reader pops items from that queue until the
    ``None`` marker is seen.

    NOTE(review): because ``None`` is the end marker, a ``None`` returned by
    ``cv2.imread`` for some path would also end iteration early -- confirm how
    callers treat unreadable files.

    Args:
        img_list (list[str]): A image list of image paths to be read.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.img_list = img_list
        # Bounded queue: `put` blocks once this many images are waiting.
        self.que = queue.Queue(num_prefetch_queue)

    def run(self):
        """Read every path in order, then enqueue the ``None`` end marker."""
        for path in self.img_list:
            self.que.put(cv2.imread(path, cv2.IMREAD_UNCHANGED))

        self.que.put(None)

    def __next__(self):
        """Return the next prefetched image; stop when the end marker arrives."""
        item = self.que.get()
        if item is not None:
            return item
        raise StopIteration

    def __iter__(self):
        return self
|
||||
|
||||
|
||||
class IOConsumer(threading.Thread):
    """Worker thread that writes images taken from a queue to disk.

    Each queue message is a mapping with an ``'output'`` image and a
    ``'save_path'``; the string ``'quit'`` tells the worker to stop.

    Args:
        opt: Options object; stored on the instance but not read by this class.
        que: Queue the messages are taken from.
        qid: Identifier of this worker, used in the completion message.
    """

    def __init__(self, opt, que, qid):
        super().__init__()
        self.opt = opt
        self.qid = qid
        self._queue = que

    def run(self):
        """Save queued images until the ``'quit'`` message is received."""
        while True:
            task = self._queue.get()
            stop_requested = isinstance(task, str) and task == 'quit'
            if stop_requested:
                break

            image, target = task['output'], task['save_path']
            cv2.imwrite(target, image)
        print(f'IO worker {self.qid} is done.')
|
||||
88
basicsr/utils/registry.py
Normal file
88
basicsr/utils/registry.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
|
||||
|
||||
|
||||
class Registry():
|
||||
"""
|
||||
The registry that provides name -> object mapping, to support third-party
|
||||
users' custom modules.
|
||||
|
||||
To create a registry (e.g. a backbone registry):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
BACKBONE_REGISTRY = Registry('BACKBONE')
|
||||
|
||||
To register an object:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
class MyBackbone():
|
||||
...
|
||||
|
||||
Or:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
BACKBONE_REGISTRY.register(MyBackbone)
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""
|
||||
Args:
|
||||
name (str): the name of this registry
|
||||
"""
|
||||
self._name = name
|
||||
self._obj_map = {}
|
||||
|
||||
def _do_register(self, name, obj, suffix=None):
|
||||
if isinstance(suffix, str):
|
||||
name = name + '_' + suffix
|
||||
|
||||
assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
|
||||
f"in '{self._name}' registry!")
|
||||
self._obj_map[name] = obj
|
||||
|
||||
def register(self, obj=None, suffix=None):
|
||||
"""
|
||||
Register the given object under the the name `obj.__name__`.
|
||||
Can be used as either a decorator or not.
|
||||
See docstring of this class for usage.
|
||||
"""
|
||||
if obj is None:
|
||||
# used as a decorator
|
||||
def deco(func_or_class):
|
||||
name = func_or_class.__name__
|
||||
self._do_register(name, func_or_class, suffix)
|
||||
return func_or_class
|
||||
|
||||
return deco
|
||||
|
||||
# used as a function call
|
||||
name = obj.__name__
|
||||
self._do_register(name, obj, suffix)
|
||||
|
||||
def get(self, name, suffix='basicsr'):
|
||||
ret = self._obj_map.get(name)
|
||||
if ret is None:
|
||||
ret = self._obj_map.get(name + '_' + suffix)
|
||||
print(f'Name {name} is not found, use name: {name}_{suffix}!')
|
||||
if ret is None:
|
||||
raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
|
||||
return ret
|
||||
|
||||
def __contains__(self, name):
|
||||
return name in self._obj_map
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._obj_map.items())
|
||||
|
||||
def keys(self):
|
||||
return self._obj_map.keys()
|
||||
|
||||
|
||||
DATASET_REGISTRY = Registry('dataset')
|
||||
ARCH_REGISTRY = Registry('arch')
|
||||
MODEL_REGISTRY = Registry('model')
|
||||
LOSS_REGISTRY = Registry('loss')
|
||||
METRIC_REGISTRY = Registry('metric')
|
||||
Reference in New Issue
Block a user