mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 02:04:19 +01:00
same for text encodings for decoder ddpm training
This commit is contained in:
@@ -1214,7 +1214,7 @@ class Decoder(nn.Module):
|
||||
|
||||
return img
|
||||
|
||||
def forward(self, image, text = None, image_embed = None, unet_number = None):
|
||||
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
assert 1 <= unet_number <= len(self.unets)
|
||||
@@ -1233,7 +1233,7 @@ class Decoder(nn.Module):
|
||||
if not exists(image_embed):
|
||||
image_embed = self.get_image_embed(image)
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
|
||||
lowres_cond_img = image if index > 0 else None
|
||||
ddpm_image = resize_image_to(image, target_image_size)
|
||||
|
||||
Reference in New Issue
Block a user