This does not make much sense, as researchers may want to try predicting noise with DiffusionPrior instead of predicting x0

This commit is contained in:
Phil Wang
2022-05-05 07:37:00 -07:00
parent 1d5dc08810
commit 8518684ae9
2 changed files with 2 additions and 5 deletions

View File

@@ -652,14 +652,12 @@ class DiffusionPriorNetwork(nn.Module):
self, self,
dim, dim,
num_timesteps = None, num_timesteps = None,
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
**kwargs **kwargs
): ):
super().__init__() super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim)) self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs) self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
self.l2norm_output = l2norm_output
def forward_with_cond_scale( def forward_with_cond_scale(
self, self,
@@ -738,8 +736,7 @@ class DiffusionPriorNetwork(nn.Module):
pred_image_embed = tokens[..., -1, :] pred_image_embed = tokens[..., -1, :]
output_fn = l2norm if self.l2norm_output else identity return pred_image_embed
return output_fn(pred_image_embed)
class DiffusionPrior(BaseGaussianDiffusion): class DiffusionPrior(BaseGaussianDiffusion):
def __init__( def __init__(

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.104', version = '0.0.105',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',