mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
add learned padding tokens, same strategy as dalle1, for diffusion prior, and get rid of masking in causal transformer
This commit is contained in:
@@ -782,17 +782,13 @@ class CausalTransformer(nn.Module):
|
|||||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||||
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
||||||
|
|
||||||
def forward(
|
def forward(self, x):
|
||||||
self,
|
|
||||||
x,
|
|
||||||
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
|
|
||||||
):
|
|
||||||
n, device = x.shape[1], x.device
|
n, device = x.shape[1], x.device
|
||||||
|
|
||||||
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
||||||
|
|
||||||
for attn, ff in self.layers:
|
for attn, ff in self.layers:
|
||||||
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
x = attn(x, attn_bias = attn_bias) + x
|
||||||
x = ff(x) + x
|
x = ff(x) + x
|
||||||
|
|
||||||
out = self.norm(x)
|
out = self.norm(x)
|
||||||
@@ -806,7 +802,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
num_time_embeds = 1,
|
num_time_embeds = 1,
|
||||||
num_image_embeds = 1,
|
num_image_embeds = 1,
|
||||||
num_text_embeds = 1,
|
num_text_embeds = 1,
|
||||||
attend_all_text_encodings = True,
|
max_text_len = 256,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -832,7 +828,10 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||||
|
|
||||||
self.attend_all_text_encodings = attend_all_text_encodings
|
# dalle1 learned padding strategy
|
||||||
|
|
||||||
|
self.max_text_len = max_text_len
|
||||||
|
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
|
||||||
|
|
||||||
def forward_with_cond_scale(
|
def forward_with_cond_scale(
|
||||||
self,
|
self,
|
||||||
@@ -872,11 +871,28 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
|
|
||||||
if not exists(text_encodings):
|
if not exists(text_encodings):
|
||||||
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
|
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
|
||||||
|
|
||||||
|
mask = torch.any(text_encodings != 0., dim = -1)
|
||||||
|
|
||||||
if self.attend_all_text_encodings:
|
# replace any padding in the text encodings with learned padding tokens unique across position
|
||||||
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
|
|
||||||
else:
|
text_encodings = text_encodings[:, :self.max_text_len]
|
||||||
mask = torch.any(text_encodings != 0., dim = -1)
|
mask = mask[:, :self.max_text_len]
|
||||||
|
|
||||||
|
text_len = text_encodings.shape[-2]
|
||||||
|
remainder = self.max_text_len - text_len
|
||||||
|
|
||||||
|
if remainder > 0:
|
||||||
|
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
|
||||||
|
mask = F.pad(mask, (0, remainder), value = False)
|
||||||
|
|
||||||
|
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
|
||||||
|
|
||||||
|
text_encodings = torch.where(
|
||||||
|
rearrange(mask, 'b n -> b n 1'),
|
||||||
|
text_encodings,
|
||||||
|
null_text_embeds
|
||||||
|
)
|
||||||
|
|
||||||
# classifier free guidance
|
# classifier free guidance
|
||||||
|
|
||||||
@@ -910,7 +926,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
|
|
||||||
# attend
|
# attend
|
||||||
|
|
||||||
tokens = self.causal_transformer(tokens, mask = mask)
|
tokens = self.causal_transformer(tokens)
|
||||||
|
|
||||||
# get learned query, which should predict the image embedding (per DDPM timestep)
|
# get learned query, which should predict the image embedding (per DDPM timestep)
|
||||||
|
|
||||||
|
|||||||
@@ -129,11 +129,11 @@ class AdapterConfig(BaseModel):
|
|||||||
class DiffusionPriorNetworkConfig(BaseModel):
|
class DiffusionPriorNetworkConfig(BaseModel):
|
||||||
dim: int
|
dim: int
|
||||||
depth: int
|
depth: int
|
||||||
|
max_text_len: int = None
|
||||||
num_timesteps: int = None
|
num_timesteps: int = None
|
||||||
num_time_embeds: int = 1
|
num_time_embeds: int = 1
|
||||||
num_image_embeds: int = 1
|
num_image_embeds: int = 1
|
||||||
num_text_embeds: int = 1
|
num_text_embeds: int = 1
|
||||||
attend_all_text_encodings: bool = True
|
|
||||||
dim_head: int = 64
|
dim_head: int = 64
|
||||||
heads: int = 8
|
heads: int = 8
|
||||||
ff_mult: int = 4
|
ff_mult: int = 4
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.22.3'
|
__version__ = '0.23.1'
|
||||||
|
|||||||
Reference in New Issue
Block a user