From 4ab527e779ba0e9266bd14942f2641fd0e8b1ab1 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 27 Apr 2022 20:11:43 -0700 Subject: [PATCH] some extra asserts for text encoding of diffusion prior and decoder --- dalle2_pytorch/dalle2_pytorch.py | 6 ++++-- setup.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 2a0477c..bbb9a53 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -783,6 +783,7 @@ class DiffusionPrior(BaseGaussianDiffusion): text_cond = dict(text_embed = text_embed) if self.condition_on_text_encodings: + assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified' text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask} # timestep conditioning from ddpm @@ -792,8 +793,7 @@ class DiffusionPrior(BaseGaussianDiffusion): # calculate forward loss - loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs) - return loss + return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs) # decoder @@ -1418,6 +1418,7 @@ class Decoder(BaseGaussianDiffusion): _, text_encodings = self.clip.embed_text(text) assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' + assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' img = None @@ -1485,6 +1486,7 @@ class Decoder(BaseGaussianDiffusion): _, text_encodings = self.clip.embed_text(text) assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' + assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None image = resize_image_to(image, target_image_size) diff --git a/setup.py b/setup.py index 03dc703..636e449 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.59', + version = '0.0.60', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',