mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
add yet another transformer stability measure
This commit is contained in:
@@ -760,6 +760,7 @@ class CausalTransformer(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
ff_mult = 4,
|
||||
norm_in = False,
|
||||
norm_out = True,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
@@ -768,6 +769,8 @@ class CausalTransformer(nn.Module):
|
||||
rotary_emb = True
|
||||
):
|
||||
super().__init__()
|
||||
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
|
||||
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
|
||||
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
||||
@@ -785,6 +788,8 @@ class CausalTransformer(nn.Module):
|
||||
def forward(self, x):
|
||||
n, device = x.shape[1], x.device
|
||||
|
||||
x = self.init_norm(x)
|
||||
|
||||
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
|
||||
@@ -137,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
ff_mult: int = 4
|
||||
norm_in: bool = False
|
||||
norm_out: bool = True
|
||||
attn_dropout: float = 0.
|
||||
ff_dropout: float = 0.
|
||||
|
||||
@@ -1 +1 @@
-__version__ = '0.23.1'
+__version__ = '0.23.2'
|
||||
|
||||
Reference in New Issue
Block a user