mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
remove convnext blocks, they are ill-suited for generative work, as validated by early experimental results at https://github.com/lucidrains/video-diffusion-pytorch
@@ -999,68 +999,6 @@ class ResnetBlock(nn.Module):
         h = self.block2(h)
         return h + self.res_conv(x)
 
-class ConvNextBlock(nn.Module):
-    """ https://arxiv.org/abs/2201.03545 """
-
-    def __init__(
-        self,
-        dim,
-        dim_out,
-        *,
-        cond_dim = None,
-        time_cond_dim = None,
-        mult = 2
-    ):
-        super().__init__()
-        need_projection = dim != dim_out
-
-        self.cross_attn = None
-
-        if exists(cond_dim):
-            self.cross_attn = EinopsToAndFrom(
-                'b c h w',
-                'b (h w) c',
-                CrossAttention(
-                    dim = dim,
-                    context_dim = cond_dim
-                )
-            )
-
-        self.time_mlp = None
-
-        if exists(time_cond_dim):
-            self.time_mlp = nn.Sequential(
-                nn.GELU(),
-                nn.Linear(time_cond_dim, dim)
-            )
-
-        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
-
-        inner_dim = int(dim_out * mult)
-        self.net = nn.Sequential(
-            ChanLayerNorm(dim),
-            nn.Conv2d(dim, inner_dim, 3, padding = 1),
-            nn.GELU(),
-            nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
-        )
-
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
-
-    def forward(self, x, cond = None, time = None):
-        h = self.ds_conv(x)
-
-        if exists(time) and exists(self.time_mlp):
-            t = self.time_mlp(time)
-            h = rearrange(t, 'b c -> b c 1 1') + h
-
-        if exists(self.cross_attn):
-            assert exists(cond)
-            h = self.cross_attn(h, context = cond) + h
-
-        h = self.net(h)
-
-        return h + self.res_conv(x)
-
 class CrossAttention(nn.Module):
     def __init__(
         self,
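For reference, the ResnetBlock that is now the only block type appears in this hunk only through its last two lines. A minimal sketch of that block, assuming the usual GroupNorm -> activation -> conv layout used across lucidrains' diffusion repositories (the names Block and SimpleResnetBlock are illustrative, and the real block's cond_dim / time_cond_dim conditioning is omitted):

import torch
from torch import nn

# one norm -> activation -> 3x3 conv sub-unit; the real ResnetBlock
# stacks two of these, with groups exposed as a hyperparameter
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.norm = nn.GroupNorm(groups, dim)
        self.act = nn.SiLU()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x):
        return self.proj(self.act(self.norm(x)))

class SimpleResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, groups = 8):
        super().__init__()
        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        # 1x1 projection on the skip path only when channel counts differ,
        # matching the need_projection logic in the removed ConvNextBlock
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        return h + self.res_conv(x)

x = torch.randn(1, 64, 32, 32)
out = SimpleResnetBlock(64, 128, groups = 8)(x)  # torch.Size([1, 128, 32, 32])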
@@ -1200,7 +1138,6 @@ class Unet(nn.Module):
         init_conv_kernel_size = 7,
-        block_type = 'resnet',
         block_resnet_groups = 8,
-        block_convnext_mult = 2,
         **kwargs
     ):
         super().__init__()
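With block_type and block_convnext_mult removed from the signature, block_resnet_groups is the only block-level knob left. A sketch of the call-site difference (argument values illustrative, not taken from this diff):

# before this commit a caller could select the block type:
#   unet = Unet(dim = 128, block_type = 'convnext', block_convnext_mult = 2, ...)
#
# after it, only the resnet path is configurable:
#   unet = Unet(dim = 128, block_resnet_groups = 8, ...)
#
# note that because the signature ends in **kwargs, stale block_type /
# block_convnext_mult keywords are swallowed silently rather than raising,
# so callers should drop them explicitly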
@@ -1276,14 +1213,9 @@ class Unet(nn.Module):
 
         attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
 
-        # whether to use resnet or the (improved?) convnext blocks
+        # resnet block klass
 
-        if block_type == 'resnet':
-            block_klass = partial(ResnetBlock, groups = block_resnet_groups)
-        elif block_type == 'convnext':
-            block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
-        else:
-            raise ValueError(f'unimplemented block type {block_type}')
+        block_klass = partial(ResnetBlock, groups = block_resnet_groups)
 
         # layers
 
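The surviving assignment uses functools.partial to bake the groups hyperparameter into the block constructor, so the layer-building code further down can call block_klass(dim_in, dim_out) uniformly without threading groups everywhere. A self-contained illustration of the idiom (make_resnet_block is a stand-in, not the real ResnetBlock):

from functools import partial

# stand-in with the same calling convention assumed for ResnetBlock:
# positional dim / dim_out plus a keyword-only groups hyperparameter
def make_resnet_block(dim, dim_out, *, groups = 8):
    return f'ResnetBlock({dim} -> {dim_out}, groups = {groups})'

block_klass = partial(make_resnet_block, groups = 8)

# every call site can now ignore the groups hyperparameter
print(block_klass(64, 128))   # ResnetBlock(64 -> 128, groups = 8)
print(block_klass(128, 128))  # ResnetBlock(128 -> 128, groups = 8)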