provide option to l2norm the output of the diffusion prior

This commit is contained in:
Phil Wang
2022-05-02 09:41:03 -07:00
parent 7ee0ecc388
commit 0fc6c9cdf3
2 changed files with 8 additions and 2 deletions

View File

@@ -29,6 +29,9 @@ from x_clip import CLIP
def exists(val):
return val is not None
def identity(t, *args, **kwargs):
return t
def default(val, d):
if exists(val):
return val
@@ -635,12 +638,14 @@ class DiffusionPriorNetwork(nn.Module):
self,
dim,
num_timesteps = None,
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
self.l2norm_output = l2norm_output
def forward_with_cond_scale(
self,
@@ -719,7 +724,8 @@ class DiffusionPriorNetwork(nn.Module):
pred_image_embed = tokens[..., -1, :]
return pred_image_embed
output_fn = l2norm if self.l2norm_output else identity
return output_fn(pred_image_embed)
class DiffusionPrior(BaseGaussianDiffusion):
def __init__(