mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
bring in blur, as it will be used somewhere in the cascading DDPM in the decoder eventually, once i figure it out
This commit is contained in:
@@ -11,6 +11,8 @@ from einops.layers.torch import Rearrange
|
||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
|
||||
from kornia.filters import filter2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
|
||||
# use x-clip
|
||||
@@ -731,6 +733,17 @@ class Unet(nn.Module):
|
||||
|
||||
return self.final_conv(x)
|
||||
|
||||
class Blur(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
filt = torch.Tensor([1, 2, 1])
|
||||
self.register_buffer('filt', filt)
|
||||
|
||||
def forward(self, x):
|
||||
filt = self.filt
|
||||
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
|
||||
return filter2d(x, filt, normalized = True)
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user