Bring in blur, as it will be used somewhere in the cascading DDPM in the decoder eventually, once I figure it out

This commit is contained in:
Phil Wang
2022-04-14 09:16:09 -07:00
parent e1b0c140f1
commit 97e951221b
2 changed files with 14 additions and 0 deletions

View File

@@ -11,6 +11,8 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters import filter2d
from dalle2_pytorch.tokenizer import tokenizer
# use x-clip
@@ -731,6 +733,17 @@ class Unet(nn.Module):
return self.final_conv(x)
class Blur(nn.Module):
    """Fixed 3x3 binomial blur (anti-aliasing) filter, applied with kornia's filter2d.

    The 1-D binomial coefficients [1, 2, 1] are registered as a buffer so they
    follow the module's device/dtype; the 2-D kernel is built as an outer
    product at forward time. With normalized=True, filter2d rescales the
    kernel to sum to 1, so the raw integer coefficients are fine as-is.
    """

    def __init__(self):
        super().__init__()
        # 1-D binomial coefficients; outer product below gives the 2-D kernel.
        filt = torch.Tensor([1, 2, 1])
        self.register_buffer('filt', filt)

    def forward(self, x):
        filt = self.filt
        # Outer product shaped (1, 3, 3): kornia's filter2d expects a batched
        # (B, kH, kW) kernel. Fixes two defects in the original: the second
        # factor referenced a misspelled name (`flit` -> NameError at runtime)
        # and the product was an unbatched (3, 3) kernel.
        filt = rearrange(filt, 'j -> 1 1 j') * rearrange(filt, 'i -> 1 i 1')
        return filter2d(x, filt, normalized = True)
class Decoder(nn.Module):
def __init__(
self,

View File

@@ -25,6 +25,7 @@ setup(
'click',
'einops>=0.4',
'einops-exts>=0.0.3',
'kornia>=0.5.4',
'pillow',
'torch>=1.10',
'torchvision',