mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
also allow an image embedding to be passed into the diffusion model, in case one wants to generate the image embedding once and then train multiple unets in one iteration
This commit is contained in:
@@ -1214,7 +1214,7 @@ class Decoder(nn.Module):
|
||||
|
||||
return img
|
||||
|
||||
def forward(self, image, text = None, unet_number = None):
|
||||
def forward(self, image, text = None, image_embed = None, unet_number = None):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
assert 1 <= unet_number <= len(self.unets)
|
||||
@@ -1230,7 +1230,9 @@ class Decoder(nn.Module):
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
|
||||
image_embed = self.get_image_embed(image)
|
||||
if not exists(image_embed):
|
||||
image_embed = self.get_image_embed(image)
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
lowres_cond_img = image if index > 0 else None
|
||||
|
||||
Reference in New Issue
Block a user