mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 23:44:22 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e024971dc3 |
@@ -495,10 +495,12 @@ class ViTEncDec(nn.Module):
|
|||||||
layers = layers
|
layers = layers
|
||||||
),
|
),
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
nn.Linear(dim, dim * 4, bias = False),
|
nn.Linear(dim, dim * 2, bias = False),
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
nn.Linear(dim * 4, input_dim, bias = False),
|
nn.Linear(dim * 2, dim, bias = False),
|
||||||
),
|
),
|
||||||
|
nn.LayerNorm(dim),
|
||||||
|
nn.Linear(dim, input_dim),
|
||||||
RearrangeImage(),
|
RearrangeImage(),
|
||||||
Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
|
Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user