remove convnext blocks, they are illsuited for generative work, validated by early experimental results at https://github.com/lucidrains/video-diffusion-pytorch

This commit is contained in:
Phil Wang
2022-05-05 07:07:21 -07:00
parent aec5575d09
commit 896f19786d
4 changed files with 3 additions and 187 deletions

View File

@@ -999,68 +999,6 @@ class ResnetBlock(nn.Module):
h = self.block2(h)
return h + self.res_conv(x)
class ConvNextBlock(nn.Module):
""" https://arxiv.org/abs/2201.03545 """
def __init__(
self,
dim,
dim_out,
*,
cond_dim = None,
time_cond_dim = None,
mult = 2
):
super().__init__()
need_projection = dim != dim_out
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim,
context_dim = cond_dim
)
)
self.time_mlp = None
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.GELU(),
nn.Linear(time_cond_dim, dim)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
def forward(self, x, cond = None, time = None):
h = self.ds_conv(x)
if exists(time) and exists(self.time_mlp):
t = self.time_mlp(time)
h = rearrange(t, 'b c -> b c 1 1') + h
if exists(self.cross_attn):
assert exists(cond)
h = self.cross_attn(h, context = cond) + h
h = self.net(h)
return h + self.res_conv(x)
class CrossAttention(nn.Module):
def __init__(
self,
@@ -1200,7 +1138,6 @@ class Unet(nn.Module):
init_conv_kernel_size = 7,
block_type = 'resnet',
block_resnet_groups = 8,
block_convnext_mult = 2,
**kwargs
):
super().__init__()
@@ -1276,14 +1213,9 @@ class Unet(nn.Module):
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# whether to use resnet or the (improved?) convnext blocks
# resnet block klass
if block_type == 'resnet':
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
elif block_type == 'convnext':
block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
else:
raise ValueError(f'unimplemented block type {block_type}')
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
# layers