mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 06:14:22 +01:00
first commit
This commit is contained in:
83
basicsr/utils/img_process_util.py
Normal file
83
basicsr/utils/img_process_util.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def filter2D(img, kernel):
|
||||
"""PyTorch version of cv2.filter2D
|
||||
|
||||
Args:
|
||||
img (Tensor): (b, c, h, w)
|
||||
kernel (Tensor): (b, k, k)
|
||||
"""
|
||||
k = kernel.size(-1)
|
||||
b, c, h, w = img.size()
|
||||
if k % 2 == 1:
|
||||
img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
|
||||
else:
|
||||
raise ValueError('Wrong kernel size')
|
||||
|
||||
ph, pw = img.size()[-2:]
|
||||
|
||||
if kernel.size(0) == 1:
|
||||
# apply the same kernel to all batch images
|
||||
img = img.view(b * c, 1, ph, pw)
|
||||
kernel = kernel.view(1, 1, k, k)
|
||||
return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
|
||||
else:
|
||||
img = img.view(1, b * c, ph, pw)
|
||||
kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
|
||||
return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
|
||||
|
||||
|
||||
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
|
||||
"""USM sharpening.
|
||||
|
||||
Input image: I; Blurry image: B.
|
||||
1. sharp = I + weight * (I - B)
|
||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
||||
3. Blur mask:
|
||||
4. Out = Mask * sharp + (1 - Mask) * I
|
||||
|
||||
|
||||
Args:
|
||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
||||
weight (float): Sharp weight. Default: 1.
|
||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
||||
threshold (int):
|
||||
"""
|
||||
if radius % 2 == 0:
|
||||
radius += 1
|
||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
||||
residual = img - blur
|
||||
mask = np.abs(residual) * 255 > threshold
|
||||
mask = mask.astype('float32')
|
||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
||||
|
||||
sharp = img + weight * residual
|
||||
sharp = np.clip(sharp, 0, 1)
|
||||
return soft_mask * sharp + (1 - soft_mask) * img
|
||||
|
||||
|
||||
class USMSharp(torch.nn.Module):
|
||||
|
||||
def __init__(self, radius=50, sigma=0):
|
||||
super(USMSharp, self).__init__()
|
||||
if radius % 2 == 0:
|
||||
radius += 1
|
||||
self.radius = radius
|
||||
kernel = cv2.getGaussianKernel(radius, sigma)
|
||||
kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
def forward(self, img, weight=0.5, threshold=10):
|
||||
blur = filter2D(img, self.kernel)
|
||||
residual = img - blur
|
||||
|
||||
mask = torch.abs(residual) * 255 > threshold
|
||||
mask = mask.float()
|
||||
soft_mask = filter2D(mask, self.kernel)
|
||||
sharp = img + weight * residual
|
||||
sharp = torch.clip(sharp, 0, 1)
|
||||
return soft_mask * sharp + (1 - soft_mask) * img
|
||||
Reference in New Issue
Block a user