diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 1e37539..6fb02e8 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -760,6 +760,7 @@ class CausalTransformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        norm_in = False,
         norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
@@ -768,6 +769,8 @@ class CausalTransformer(nn.Module):
         rotary_emb = True
     ):
         super().__init__()
+        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
+
         self.rel_pos_bias = RelPosBias(heads = heads)
 
         rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
@@ -785,6 +788,8 @@ class CausalTransformer(nn.Module):
     def forward(self, x):
         n, device = x.shape[1], x.device
 
+        x = self.init_norm(x)
+
         attn_bias = self.rel_pos_bias(n, n + 1, device = device)
 
         for attn, ff in self.layers:
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index d2f502e..fc24282 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -137,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
     dim_head: int = 64
     heads: int = 8
     ff_mult: int = 4
+    norm_in: bool = False
     norm_out: bool = True
     attn_dropout: float = 0.
     ff_dropout: float = 0.
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index db714a8..5a983c9 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.23.1'
+__version__ = '0.23.2'