class Blur(nn.Module):
    """Blur images with a fixed 3x3 kernel built from the taps [1, 2, 1].

    The 2D kernel is the outer product of the 1D taps with themselves and is
    handed to kornia's ``filter2d`` with ``normalized = True``.
    """

    def __init__(self):
        super().__init__()
        # Registered as a buffer (not a parameter): it is stored on the module
        # under the name 'filt' but is not a trainable weight.
        self.register_buffer('filt', torch.tensor([1., 2., 1.]))

    def _kernel2d(self):
        """Return the (1, 3, 3) outer product of the 1D taps with themselves.

        The leading singleton axis is there because kornia documents the
        ``filter2d`` kernel as (1, kH, kW) or (B, kH, kW); a bare (3, 3)
        tensor is not a documented shape.
        """
        filt = self.filt
        # Row vector (1, 1, 3) times column vector (1, 3, 1) broadcasts to (1, 3, 3).
        return filt[None, None, :] * filt[None, :, None]

    def forward(self, x):
        """Return ``x`` filtered with the 3x3 kernel.

        Args:
            x: input tensor passed straight to ``filter2d``
               (kornia documents it as (B, C, H, W) -- confirm against callers).

        Returns:
            Whatever ``filter2d`` returns for ``x`` and the kernel.
        """
        # The original read `flit` here, an undefined name, so every call
        # raised NameError.
        return filter2d(x, self._kernel2d(), normalized = True)