mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
add yet another transformer stability measure (optional `norm_in` LayerNorm applied to the CausalTransformer input)
This commit is contained in:
@@ -760,6 +760,7 @@ class CausalTransformer(nn.Module):
|
|||||||
dim_head = 64,
|
dim_head = 64,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
ff_mult = 4,
|
ff_mult = 4,
|
||||||
|
norm_in = False,
|
||||||
norm_out = True,
|
norm_out = True,
|
||||||
attn_dropout = 0.,
|
attn_dropout = 0.,
|
||||||
ff_dropout = 0.,
|
ff_dropout = 0.,
|
||||||
@@ -768,6 +769,8 @@ class CausalTransformer(nn.Module):
|
|||||||
rotary_emb = True
|
rotary_emb = True
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
|
||||||
|
|
||||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||||
|
|
||||||
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
||||||
@@ -785,6 +788,8 @@ class CausalTransformer(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
n, device = x.shape[1], x.device
|
n, device = x.shape[1], x.device
|
||||||
|
|
||||||
|
x = self.init_norm(x)
|
||||||
|
|
||||||
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
||||||
|
|
||||||
for attn, ff in self.layers:
|
for attn, ff in self.layers:
|
||||||
|
|||||||
@@ -137,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
|||||||
dim_head: int = 64
|
dim_head: int = 64
|
||||||
heads: int = 8
|
heads: int = 8
|
||||||
ff_mult: int = 4
|
ff_mult: int = 4
|
||||||
|
norm_in: bool = False
|
||||||
norm_out: bool = True
|
norm_out: bool = True
|
||||||
attn_dropout: float = 0.
|
attn_dropout: float = 0.
|
||||||
ff_dropout: float = 0.
|
ff_dropout: float = 0.
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.23.1'
|
__version__ = '0.23.2'
|
||||||
|
|||||||
Reference in New Issue
Block a user