diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 3c25715..82c59ab 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -116,12 +116,14 @@ class Attention(nn.Module):
         dim,
         dim_head = 64,
         heads = 8,
-        dropout = 0.
+        dropout = 0.,
+        causal = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
 
         inner_dim = dim_head * heads
 
+        self.causal = causal
         self.norm = RMSNorm(dim)
         self.dropout = nn.Dropout(dropout)
@@ -154,8 +156,9 @@ class Attention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
 
-        causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
-        sim = sim.masked_fill(causal_mask, max_neg_value)
+        if self.causal:
+            causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
+            sim = sim.masked_fill(causal_mask, max_neg_value)
 
         sim = sim - sim.amax(dim = -1, keepdim = True)
         attn = sim.softmax(dim = -1)
@@ -165,7 +168,7 @@ class Attention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
-class Transformer(nn.Module):
+class CausalTransformer(nn.Module):
     def __init__(
         self,
         *,
@@ -184,7 +187,7 @@ class Transformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
             ]))
 
@@ -211,7 +214,7 @@ class DiffusionPriorNetwork(nn.Module):
         super().__init__()
         self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
-        self.causal_transformer = Transformer(**kwargs)
+        self.causal_transformer = CausalTransformer(**kwargs)
 
     def forward_with_cond_scale(
         self,
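
For context, the behavioral change above can be illustrated with a minimal, self-contained sketch in plain PyTorch. It does not use the repository's classes; the tensor shapes and variable names are illustrative only. When the new causal flag is set, the upper triangle of the attention matrix is filled with a large negative value before the softmax, so each position can attend only to itself and earlier positions.

import torch

n = 4
sim = torch.randn(1, 1, n, n)                       # attention logits, shape (batch, heads, i, j)
max_neg_value = -torch.finfo(sim.dtype).max

causal = True                                       # mirrors Attention(causal = True)
if causal:
    # mask out j > i so tokens cannot attend to future positions
    causal_mask = torch.ones((n, n), dtype = torch.bool).triu(1)
    sim = sim.masked_fill(causal_mask, max_neg_value)

attn = sim.softmax(dim = -1)
print(attn[0, 0])                                   # upper triangle is ~0 after the softmax

With causal = False the mask is skipped entirely, which becomes the default for a bare Attention module, while CausalTransformer passes causal = True explicitly.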