Wrap the CLIP image and text latents in l2norm(...) before they are returned
(DiffusionPrior and Decoder hunks), and bump the package version 0.0.4 -> 0.0.5.
NOTE(review): indentation inside the hunks was reconstructed from a
whitespace-collapsed copy; confirm against the target files before applying.

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 6842c7d..0b0db6c 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -374,12 +374,13 @@ class DiffusionPrior(nn.Module):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
-        return image_embed
+        return l2norm(image_embed)
 
     def get_text_cond(self, text):
         text_encodings = self.clip.text_transformer(text)
         text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
+        text_embed = l2norm(text_embed)
         return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
 
     def q_mean_variance(self, x_start, t):
@@ -750,7 +751,7 @@ class Decoder(nn.Module):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
-        return image_embed
+        return l2norm(image_embed)
 
     def q_mean_variance(self, x_start, t):
         mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
diff --git a/setup.py b/setup.py
index 1f99d3e..d7a1c96 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.4',
+  version = '0.0.5',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',