from my vision transformer experience, dimension of attention head of 32 is sufficient for image feature maps

This commit is contained in:
Phil Wang
2022-04-20 11:40:24 -07:00
parent b8e8d3c164
commit faebf4c8b8
2 changed files with 15 additions and 9 deletions

View File

@@ -464,11 +464,11 @@ class DiffusionPrior(nn.Module):
net, net,
*, *,
clip, clip,
timesteps=1000, timesteps = 1000,
cond_drop_prob=0.2, cond_drop_prob = 0.2,
loss_type="l1", loss_type = "l1",
predict_x0=True, predict_x0 = True,
beta_schedule="cosine", beta_schedule = "cosine",
): ):
super().__init__() super().__init__()
assert isinstance(clip, CLIP) assert isinstance(clip, CLIP)
@@ -825,6 +825,8 @@ class Unet(nn.Module):
out_dim = None, out_dim = None,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
channels = 3, channels = 3,
attn_dim_head = 32,
attn_heads = 8,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear', lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1, blur_sigma = 0.1,
@@ -888,6 +890,10 @@ class Unet(nn.Module):
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim)) self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim)) self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# layers # layers
self.downs = nn.ModuleList([]) self.downs = nn.ModuleList([])
@@ -901,7 +907,7 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([ self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0), ConvNextBlock(dim_in, dim_out, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(), Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim), ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity() Downsample(dim_out) if not is_last else nn.Identity()
])) ]))
@@ -909,7 +915,7 @@ class Unet(nn.Module):
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim) self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim) self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -918,7 +924,7 @@ class Unet(nn.Module):
self.ups.append(nn.ModuleList([ self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim), ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(), Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim), ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in) Upsample(dim_in)
])) ]))

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.30', version = '0.0.31',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',