same for text encodings for decoder ddpm training

This commit is contained in:
Phil Wang
2022-04-18 14:41:02 -07:00
parent 6fee4fce6e
commit 82328f16cd
2 changed files with 3 additions and 3 deletions

View File

@@ -1214,7 +1214,7 @@ class Decoder(nn.Module):
return img
def forward(self, image, text = None, image_embed = None, unet_number = None):
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
@@ -1233,7 +1233,7 @@ class Decoder(nn.Module):
if not exists(image_embed):
image_embed = self.get_image_embed(image)
text_encodings = self.get_text_encodings(text) if exists(text) else None
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
lowres_cond_img = image if index > 0 else None
ddpm_image = resize_image_to(image, target_image_size)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.23',
version = '0.0.24',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',