diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 67bf52d..6cf8ac7 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -29,6 +29,9 @@ from x_clip import CLIP
 def exists(val):
     return val is not None
 
+def identity(t, *args, **kwargs):
+    return t
+
 def default(val, d):
     if exists(val):
         return val
@@ -635,12 +638,14 @@ class DiffusionPriorNetwork(nn.Module):
         self,
         dim,
         num_timesteps = None,
+        l2norm_output = False,  # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
         **kwargs
     ):
         super().__init__()
         self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
+        self.l2norm_output = l2norm_output
 
     def forward_with_cond_scale(
         self,
@@ -719,7 +724,8 @@
 
         pred_image_embed = tokens[..., -1, :]
 
-        return pred_image_embed
+        output_fn = l2norm if self.l2norm_output else identity
+        return output_fn(pred_image_embed)
 
 class DiffusionPrior(BaseGaussianDiffusion):
     def __init__(
diff --git a/setup.py b/setup.py
index 4b09fac..d153108 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.90',
+  version = '0.0.91',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',