Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-20 02:04:19 +01:00)
complete the main contribution of the paper, the diffusion prior network, minus the diffusion training setup
@@ -149,6 +149,56 @@ class Transformer(nn.Module):
         return self.norm(x)
 
+class PriorNetwork(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_timesteps = 1000,
+        **kwargs
+    ):
+        super().__init__()
+        self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
+        self.learned_query = nn.Parameter(torch.randn(dim))
+        self.causal_transformer = Transformer(dim = dim, **kwargs) # dim must be passed explicitly, as it is consumed by PriorNetwork's own signature
+
+    def forward(
+        self,
+        image_embed,
+        *,
+        diffusion_timesteps,
+        text_encodings,
+        text_embed,
+        mask = None,
+    ):
+        batch = image_embed.shape[0]
+
+        # in section 2.2, last paragraph
+        # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
+
+        time_embed = self.time_embeddings(diffusion_timesteps)
+
+        # all three are (batch, dim) - give each a sequence axis so they can be concatenated as individual tokens
+        text_embed, time_embed, image_embed = rearrange_many((text_embed, time_embed, image_embed), 'b d -> b 1 d')
+
+        if exists(mask):
+            mask = F.pad(mask, (0, 4), value = True) # extend mask for the text embedding, time step embedding, noised image embedding, and learned query tokens
+
+        learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
+
+        tokens = torch.cat((
+            text_encodings,
+            text_embed,
+            time_embed,
+            image_embed, # the noised CLIP image embedding, matching the sequence quoted above (and the 4 extra mask positions)
+            learned_queries
+        ), dim = -2)
+
+        tokens = self.causal_transformer(tokens, mask = mask)
+
+        # get learned query, which should predict the image embedding (per DDPM timestep)
+        pred_image_embed = tokens[..., -1, :]
+
+        return pred_image_embed
+
 class DiffusionPrior(nn.Module):
     def __init__(
         self,
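For a quick sanity check, here is a minimal sketch of how the new PriorNetwork could be exercised. All the sizes below, and the assumption that Transformer accepts depth and heads keyword arguments, are illustrative guesses, not values from this commit:

import torch

# illustrative sizes - dim, depth, heads and the sequence length are assumptions
prior_network = PriorNetwork(dim = 512, num_timesteps = 1000, depth = 6, heads = 8)

batch, text_len = 4, 77
image_embed    = torch.randn(batch, 512)            # noised CLIP image embedding
text_embed     = torch.randn(batch, 512)            # CLIP text embedding
text_encodings = torch.randn(batch, text_len, 512)  # per-token encoded text
timesteps      = torch.randint(0, 1000, (batch,))
mask           = torch.ones(batch, text_len).bool()

pred = prior_network(
    image_embed,
    diffusion_timesteps = timesteps,
    text_encodings = text_encodings,
    text_embed = text_embed,
    mask = mask
)
print(pred.shape)  # (4, 512) - predicted image embedding from the learned query token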
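The comment in __init__ also mentions offering a continuous version of the timestep embeddings via a 2-layer MLP. A hypothetical sketch of what such a drop-in module might look like; the name ContinuousTimeEmbedding and the sinusoidal featurization are assumptions, not part of the commit:

import math
import torch
from torch import nn

class ContinuousTimeEmbedding(nn.Module):
    # hypothetical continuous replacement for nn.Embedding(num_timesteps, dim):
    # sinusoidal features of a (possibly fractional) timestep, followed by a 2 layer MLP
    def __init__(self, dim):
        super().__init__()
        assert dim % 2 == 0, 'dim must be even for the sin/cos featurization'
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, times):
        # times: (batch,) float or long tensor of diffusion timesteps
        half_dim = self.dim // 2
        freqs = torch.exp(torch.arange(half_dim, device = times.device) * -(math.log(10000.) / (half_dim - 1)))
        angles = times.float()[:, None] * freqs[None, :]          # (batch, half_dim)
        fourier = torch.cat((angles.sin(), angles.cos()), dim = -1)  # (batch, dim)
        return self.mlp(fourier)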