diff --git a/README.md b/README.md
index 2a5ab42..523d46a 100644
--- a/README.md
+++ b/README.md
@@ -911,4 +911,4 @@ Once built, images will be saved to the same directory the command is invoked
}
```
-*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
+*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 2ab0e7e..2cb9811 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -930,6 +930,72 @@ class SinusoidalPosEmb(nn.Module):
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_out,
+ groups = 8
+ ):
+ super().__init__()
+ self.block = nn.Sequential(
+ nn.Conv2d(dim, dim_out, 3, padding = 1),
+ nn.GroupNorm(groups, dim_out),
+ nn.SiLU()
+ )
+ def forward(self, x):
+ return self.block(x)
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_out,
+ *,
+ cond_dim = None,
+ time_cond_dim = None,
+ groups = 8
+ ):
+ super().__init__()
+
+ self.time_mlp = None
+
+ if exists(time_cond_dim):
+ self.time_mlp = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(time_cond_dim, dim_out)
+ )
+
+ self.cross_attn = None
+
+ if exists(cond_dim):
+ self.cross_attn = EinopsToAndFrom(
+ 'b c h w',
+ 'b (h w) c',
+ CrossAttention(
+ dim = dim_out,
+ context_dim = cond_dim
+ )
+ )
+
+ self.block1 = Block(dim, dim_out, groups = groups)
+ self.block2 = Block(dim_out, dim_out, groups = groups)
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+ def forward(self, x, cond = None, time_emb = None):
+ h = self.block1(x)
+
+ if exists(self.time_mlp) and exists(time_emb):
+ time_emb = self.time_mlp(time_emb)
+ h = rearrange(time_emb, 'b c -> b c 1 1') + h
+
+ if exists(self.cross_attn):
+ assert exists(cond)
+ h = self.cross_attn(h, context = cond) + h
+
+ h = self.block2(h)
+ return h + self.res_conv(x)
+
class ConvNextBlock(nn.Module):
""" https://arxiv.org/abs/2201.03545 """
@@ -940,8 +1006,7 @@ class ConvNextBlock(nn.Module):
*,
cond_dim = None,
time_cond_dim = None,
- mult = 2,
- norm = True
+ mult = 2
):
super().__init__()
need_projection = dim != dim_out
@@ -970,7 +1035,7 @@ class ConvNextBlock(nn.Module):
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
- ChanLayerNorm(dim) if norm else nn.Identity(),
+ ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -1082,7 +1147,11 @@ class LinearAttention(nn.Module):
self.nonlin = nn.GELU()
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
- self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
+
+ self.to_out = nn.Sequential(
+ nn.Conv2d(inner_dim, dim, 1, bias = False),
+ ChanLayerNorm(dim)
+ )
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
@@ -1125,7 +1194,9 @@ class Unet(nn.Module):
max_text_len = 256,
cond_on_image_embeds = False,
init_dim = None,
- init_conv_kernel_size = 7
+ init_conv_kernel_size = 7,
+ block_type = 'resnet',
+ **kwargs
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
@@ -1200,6 +1271,15 @@ class Unet(nn.Module):
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
+ # whether to use resnet or the (improved?) convnext blocks
+
+ if block_type == 'resnet':
+ block_klass = ResnetBlock
+ elif block_type == 'convnext':
+ block_klass = ConvNextBlock
+ else:
+ raise ValueError(f'unimplemented block type {block_type}')
+
# layers
self.downs = nn.ModuleList([])
@@ -1212,32 +1292,32 @@ class Unet(nn.Module):
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
- ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
+ block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
- ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+ block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
- self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
+ self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
- self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
+ self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
- ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+ block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
- ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+ block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
- ConvNextBlock(dim, dim),
+ block_klass(dim, dim),
nn.Conv2d(dim, out_dim, 1)
)
@@ -1368,10 +1448,10 @@ class Unet(nn.Module):
hiddens = []
- for convnext, sparse_attn, convnext2, downsample in self.downs:
- x = convnext(x, c, t)
+ for block1, sparse_attn, block2, downsample in self.downs:
+ x = block1(x, c, t)
x = sparse_attn(x)
- x = convnext2(x, c, t)
+ x = block2(x, c, t)
hiddens.append(x)
x = downsample(x)
@@ -1382,11 +1462,11 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t)
- for convnext, sparse_attn, convnext2, upsample in self.ups:
+ for block1, sparse_attn, block2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
- x = convnext(x, c, t)
+ x = block1(x, c, t)
x = sparse_attn(x)
- x = convnext2(x, c, t)
+ x = block2(x, c, t)
x = upsample(x)
return self.final_conv(x)