complete the main contribution of the paper, the diffusion prior network, minus the diffusion training setup

This commit is contained in:
Phil Wang
2022-04-12 11:43:59 -07:00
parent 83aabd42ca
commit fd38eb83c4
2 changed files with 51 additions and 0 deletions

View File

@@ -149,6 +149,56 @@ class Transformer(nn.Module):
return self.norm(x)
class PriorNetwork(nn.Module):
def __init__(
self,
dim,
num_timesteps = 1000,
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = Transformer(**kwargs)
def forward(
self,
image_embed,
*,
diffusion_timesteps,
text_encodings,
text_embed,
mask = None,
):
batch = image_embed.shape[0]
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
if exists(mask):
mask = F.pad(mask, (0, 4), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
tokens = torch.cat((
text_encodings,
text_embed,
time_embed,
learned_queries
), dim = -2)
tokens = self.causal_transformer(tokens, mask = mask)
# get learned query, which should predict the image embedding (per DDPM timestep)
pred_image_embed = tokens[..., -1, :]
return pred_image_embed
class DiffusionPrior(nn.Module):
def __init__(
self,

View File

@@ -28,6 +28,7 @@ setup(
'pillow',
'torch>=1.10',
'torchvision',
'tqdm',
'x-clip>=0.4.1',
'youtokentome'
],