Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-20 10:14:19 +01:00
from my vision transformer experience, an attention head dimension of 32 is sufficient for image feature maps
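In concrete terms, the new defaults below are attn_heads = 8 and attn_dim_head = 32, so, assuming the attention modules follow the usual heads × dim_head inner projection, every attention layer works in an 8 * 32 = 256-dimensional inner space regardless of the feature map's channel count. A minimal sketch of that sizing arithmetic (the names mirror the new constructor arguments; nothing here is taken from the repo itself):

# sizing implied by the new Unet defaults in this commit (illustration only)
attn_heads = 8
attn_dim_head = 32
inner_dim = attn_heads * attn_dim_head  # 8 heads * 32 dims per head = 256
print(inner_dim)  # 256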
@@ -464,11 +464,11 @@ class DiffusionPrior(nn.Module):
         net,
         *,
         clip,
-        timesteps=1000,
-        cond_drop_prob=0.2,
-        loss_type="l1",
-        predict_x0=True,
-        beta_schedule="cosine",
+        timesteps = 1000,
+        cond_drop_prob = 0.2,
+        loss_type = "l1",
+        predict_x0 = True,
+        beta_schedule = "cosine",
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -825,6 +825,8 @@ class Unet(nn.Module):
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
+        attn_dim_head = 32,
+        attn_heads = 8,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         lowres_cond_upsample_mode = 'bilinear',
         blur_sigma = 0.1,
@@ -888,6 +890,10 @@ class Unet(nn.Module):
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
         self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))

+        # attention related params
+
+        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
+
         # layers

         self.downs = nn.ModuleList([])
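The hunk above collects the two hyperparameters into a single attn_kwargs dict so that every attention layer in the Unet is configured from one place. A minimal sketch of that kwargs-sharing pattern, using a toy attention class whose signature is an assumption rather than the repo's actual Attention:

from torch import nn

# toy stand-in used only to illustrate the kwargs-sharing pattern above;
# its signature is an assumption, not DALLE2-pytorch's actual Attention class
class ToyAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 32):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

attn_heads, attn_dim_head = 8, 32
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)

# the same dict is splatted into each attention layer, so one pair of
# hyperparameters configures attention at every level of the Unet
attn = ToyAttention(512, **attn_kwargs)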
@@ -901,7 +907,7 @@ class Unet(nn.Module):

             self.downs.append(nn.ModuleList([
                 ConvNextBlock(dim_in, dim_out, norm = ind != 0),
-                Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
+                Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -909,7 +915,7 @@ class Unet(nn.Module):
         mid_dim = dims[-1]

         self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
-        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
+        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
         self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)

         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
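The middle-block attention wraps Attention in EinopsToAndFrom('b c h w', 'b (h w) c', ...), i.e. the feature map is flattened into a token sequence before attention and reshaped back afterwards. A minimal sketch of just that shape logic with einops (the wrapper's real implementation lives in the repo; this only illustrates the rearrangement):

import torch
from einops import rearrange

# shape logic behind EinopsToAndFrom('b c h w', 'b (h w) c', ...):
# flatten the spatial grid into a token sequence for attention, then restore it
fmap = torch.randn(2, 256, 16, 16)                  # (batch, channels, height, width)
tokens = rearrange(fmap, 'b c h w -> b (h w) c')    # (2, 16 * 16 = 256 tokens, 256 channels)
# ... attention over `tokens` would run here ...
restored = rearrange(tokens, 'b (h w) c -> b c h w', h = 16, w = 16)
assert restored.shape == fmap.shape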
@@ -918,7 +924,7 @@ class Unet(nn.Module):

             self.ups.append(nn.ModuleList([
                 ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
-                Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
+                Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
                 Upsample(dim_in)
             ]))
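For completeness, a hedged sketch of how the new knobs might be passed when constructing the Unet. Apart from attn_heads and attn_dim_head, the argument names are either taken from this diff or assumed, the values are placeholders, and the real constructor may require further arguments not shown here:

from dalle2_pytorch import Unet  # assuming Unet is exported by the package

unet = Unet(
    dim = 128,                 # assumed base channel width, illustrative value
    dim_mults = (1, 2, 4, 8),  # from this diff
    channels = 3,              # from this diff
    lowres_cond = False,       # from this diff
    attn_heads = 8,            # new in this commit
    attn_dim_head = 32         # new in this commit; per-head dimension of 32
)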