add yet another transformer stability measure

This commit is contained in:
Phil Wang
2022-07-12 17:49:16 -07:00
parent 3ee3c56d2a
commit 349aaca56f
3 changed files with 7 additions and 1 deletion

View File

@@ -760,6 +760,7 @@ class CausalTransformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        norm_in = False,
         norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
@@ -768,6 +769,8 @@ class CausalTransformer(nn.Module):
         rotary_emb = True
     ):
         super().__init__()
+        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
         self.rel_pos_bias = RelPosBias(heads = heads)
         rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
@@ -785,6 +788,8 @@ class CausalTransformer(nn.Module):
     def forward(self, x):
         n, device = x.shape[1], x.device
+        x = self.init_norm(x)
         attn_bias = self.rel_pos_bias(n, n + 1, device = device)
         for attn, ff in self.layers:

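Conceptually, norm_in just places a LayerNorm in front of the embedding stream before it reaches any attention or feedforward block, the stability measure used in BLOOM and YaLM. Below is a minimal sketch of the idea using torch's stock nn.TransformerEncoderLayer rather than this repository's attention modules; the TinyCausalTransformer name and its arguments are hypothetical, for illustration only.

import torch
from torch import nn

class TinyCausalTransformer(nn.Module):
    # Illustrative sketch only (not the repo's CausalTransformer): it applies an
    # optional LayerNorm to the incoming embeddings (norm_in) before the
    # attention / feedforward layers.
    def __init__(self, dim, depth = 2, heads = 4, norm_in = False):
        super().__init__()
        # normalize incoming embeddings if norm_in is enabled, otherwise pass through
        self.init_norm = nn.LayerNorm(dim) if norm_in else nn.Identity()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model = dim, nhead = heads, batch_first = True)
            for _ in range(depth)
        ])

    def forward(self, x):
        x = self.init_norm(x)  # the new knob: normalize embeddings first
        n = x.shape[1]
        # boolean causal mask: True marks future positions that may not be attended to
        causal_mask = torch.triu(torch.ones(n, n, device = x.device), diagonal = 1).bool()
        for layer in self.layers:
            x = layer(x, src_mask = causal_mask)
        return x

# usage: flip norm_in on to enable the input LayerNorm
model = TinyCausalTransformer(dim = 64, norm_in = True)
out = model(torch.randn(2, 16, 64))  # (batch, seq, dim) -> same shape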
View File

@@ -137,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
     dim_head: int = 64
     heads: int = 8
     ff_mult: int = 4
+    norm_in: bool = False
     norm_out: bool = True
     attn_dropout: float = 0.
     ff_dropout: float = 0.
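Since the fields of DiffusionPriorNetworkConfig appear to mirror the CausalTransformer kwargs above, opting into the new measure from a config is just a matter of setting the new field. A hypothetical fragment (dict form; only the fields visible in this diff are shown, and the real config has additional fields):

# hypothetical config fragment; field names are taken from the diff context above
prior_network_kwargs = {
    'dim_head': 64,
    'heads': 8,
    'ff_mult': 4,
    'norm_in': True,   # enable the new input LayerNorm (default stays False)
    'norm_out': True,
    'attn_dropout': 0.,
    'ff_dropout': 0.,
}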

View File

@@ -1 +1 @@
-__version__ = '0.23.1'
+__version__ = '0.23.2'