mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 12:04:24 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
349aaca56f | ||
|
|
3ee3c56d2a |
@@ -760,6 +760,7 @@ class CausalTransformer(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
ff_mult = 4,
|
||||
norm_in = False,
|
||||
norm_out = True,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
@@ -768,6 +769,8 @@ class CausalTransformer(nn.Module):
|
||||
rotary_emb = True
|
||||
):
|
||||
super().__init__()
|
||||
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
|
||||
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
|
||||
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
||||
@@ -785,6 +788,8 @@ class CausalTransformer(nn.Module):
|
||||
def forward(self, x):
|
||||
n, device = x.shape[1], x.device
|
||||
|
||||
x = self.init_norm(x)
|
||||
|
||||
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
@@ -884,7 +889,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
if remainder > 0:
|
||||
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
|
||||
mask = F.pad(mask, (0, remainder), value = 0.)
|
||||
mask = F.pad(mask, (0, remainder), value = False)
|
||||
|
||||
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
|
||||
|
||||
|
||||
@@ -137,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
ff_mult: int = 4
|
||||
norm_in: bool = False
|
||||
norm_out: bool = True
|
||||
attn_dropout: float = 0.
|
||||
ff_dropout: float = 0.
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.23.0'
|
||||
__version__ = '0.23.2'
|
||||
|
||||
Reference in New Issue
Block a user