mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
bring in blur, as it will be used somewhere in the cascading DDPM in the decoder eventually, once i figure it out
This commit is contained in:
@@ -11,6 +11,8 @@ from einops.layers.torch import Rearrange
|
|||||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||||
from einops_exts.torch import EinopsToAndFrom
|
from einops_exts.torch import EinopsToAndFrom
|
||||||
|
|
||||||
|
from kornia.filters import filter2d
|
||||||
|
|
||||||
from dalle2_pytorch.tokenizer import tokenizer
|
from dalle2_pytorch.tokenizer import tokenizer
|
||||||
|
|
||||||
# use x-clip
|
# use x-clip
|
||||||
@@ -731,6 +733,17 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
return self.final_conv(x)
|
return self.final_conv(x)
|
||||||
|
|
||||||
|
class Blur(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
filt = torch.Tensor([1, 2, 1])
|
||||||
|
self.register_buffer('filt', filt)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
filt = self.filt
|
||||||
|
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
|
||||||
|
return filter2d(x, filt, normalized = True)
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user