also allow for the image embedding to be passed into the diffusion model, in case one wants to generate the image embedding once and then train multiple unets in one iteration

Phil Wang
2022-04-18 14:00:38 -07:00
parent a54e309269
commit 6fee4fce6e
2 changed files with 5 additions and 3 deletions


@@ -1214,7 +1214,7 @@ class Decoder(nn.Module):
         return img
-    def forward(self, image, text = None, unet_number = None):
+    def forward(self, image, text = None, image_embed = None, unet_number = None):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
         assert 1 <= unet_number <= len(self.unets)
@@ -1230,7 +1230,9 @@ class Decoder(nn.Module):
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
-        image_embed = self.get_image_embed(image)
+        if not exists(image_embed):
+            image_embed = self.get_image_embed(image)
         text_encodings = self.get_text_encodings(text) if exists(text) else None
         lowres_cond_img = image if index > 0 else None
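
For context, a minimal usage sketch of the new argument (not part of the commit itself): it assumes `decoder` is an already constructed Decoder instance whose forward returns the training loss, and it reuses the `get_image_embed` method seen in the diff above to compute the embedding a single time before looping over the unets.

import torch

# hypothetical sketch: precompute the image embedding once, then train
# every unet in the cascade against it in a single iteration.
# `decoder` is assumed to be an already constructed Decoder instance.

images = torch.randn(4, 3, 256, 256)            # batch of training images
image_embed = decoder.get_image_embed(images)   # computed a single time

for unet_number in range(1, len(decoder.unets) + 1):
    # forward is assumed to return the diffusion training loss
    loss = decoder(images, image_embed = image_embed, unet_number = unet_number)
    loss.backward()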