Compare commits

..

2 Commits

3 changed files with 8 additions and 2 deletions

View File

@@ -760,6 +760,7 @@ class CausalTransformer(nn.Module):
dim_head = 64,
heads = 8,
ff_mult = 4,
norm_in = False,
norm_out = True,
attn_dropout = 0.,
ff_dropout = 0.,
@@ -768,6 +769,8 @@ class CausalTransformer(nn.Module):
rotary_emb = True
):
super().__init__()
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
self.rel_pos_bias = RelPosBias(heads = heads)
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
@@ -785,6 +788,8 @@ class CausalTransformer(nn.Module):
def forward(self, x):
n, device = x.shape[1], x.device
x = self.init_norm(x)
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
for attn, ff in self.layers:
@@ -884,7 +889,7 @@ class DiffusionPriorNetwork(nn.Module):
if remainder > 0:
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
mask = F.pad(mask, (0, remainder), value = 0.)
mask = F.pad(mask, (0, remainder), value = False)
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)

View File

@@ -137,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_in: bool = False
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.

View File

@@ -1 +1 @@
__version__ = '0.23.0'
__version__ = '0.23.2'