mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 10:14:19 +01:00
complete the main contribution of the paper, the diffusion prior network, minus the diffusion training setup
This commit is contained in:
@@ -149,6 +149,56 @@ class Transformer(nn.Module):
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class PriorNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_timesteps = 1000,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = Transformer(**kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embed,
|
||||
*,
|
||||
diffusion_timesteps,
|
||||
text_encodings,
|
||||
text_embed,
|
||||
mask = None,
|
||||
):
|
||||
batch = image_embed.shape[0]
|
||||
|
||||
# in section 2.2, last paragraph
|
||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||
|
||||
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 4), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.time_embeddings(diffusion_timesteps)
|
||||
|
||||
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
||||
|
||||
tokens = torch.cat((
|
||||
text_encodings,
|
||||
text_embed,
|
||||
time_embed,
|
||||
learned_queries
|
||||
), dim = -2)
|
||||
|
||||
tokens = self.causal_transformer(tokens, mask = mask)
|
||||
|
||||
# get learned query, which should predict the image embedding (per DDPM timestep)
|
||||
|
||||
pred_image_embed = tokens[..., -1, :]
|
||||
|
||||
return pred_image_embed
|
||||
|
||||
class DiffusionPrior(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user