From 6fee4fce6ead667cc9d3a3847e7ad0880d8c49c8 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 18 Apr 2022 14:00:38 -0700
Subject: [PATCH] also allow for image embedding to be passed into the
 diffusion model, in the case one wants to generate image embedding once and
 then train multiple unets in one iteration

---
 dalle2_pytorch/dalle2_pytorch.py | 6 ++++--
 setup.py                         | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 24b4e64..dddc206 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1214,7 +1214,7 @@ class Decoder(nn.Module):
 
         return img
 
-    def forward(self, image, text = None, unet_number = None):
+    def forward(self, image, text = None, image_embed = None, unet_number = None):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
         assert 1 <= unet_number <= len(self.unets)
@@ -1230,7 +1230,9 @@ class Decoder(nn.Module):
 
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
 
-        image_embed = self.get_image_embed(image)
+        if not exists(image_embed):
+            image_embed = self.get_image_embed(image)
+
         text_encodings = self.get_text_encodings(text) if exists(text) else None
 
         lowres_cond_img = image if index > 0 else None
diff --git a/setup.py b/setup.py
index c1bcccb..00bb232 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.22',
+  version = '0.0.23',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
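
Usage note (not part of the patch): the point of the new image_embed keyword is that the CLIP image embedding can be computed once per batch and reused across all unets of the cascading DDPM, instead of being recomputed inside forward() for every unet. The sketch below illustrates that training pattern. Only forward(image, text = ..., image_embed = ..., unet_number = ...), get_image_embed() and the .unets attribute come from the patch itself; the CLIP/Unet/Decoder construction, the hyperparameters, and the assumption that forward() returns a training loss are illustrative only and may not match this exact version of the library.

    import torch
    from dalle2_pytorch import CLIP, Unet, Decoder

    # Illustrative model construction (hyperparameters are assumptions; the
    # Decoder may require further arguments, e.g. a training image size,
    # depending on the library version).
    clip = CLIP(
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = 49408,
        text_enc_depth = 1,
        text_seq_len = 256,
        text_heads = 8,
        visual_enc_depth = 1,
        visual_image_size = 256,
        visual_patch_size = 32,
        visual_heads = 8
    )

    unet1 = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, dim_mults = (1, 2, 4, 8))
    unet2 = Unet(dim = 16, image_embed_dim = 512, cond_dim = 128, dim_mults = (1, 2, 4, 8, 16))

    decoder = Decoder(
        unet = (unet1, unet2),  # cascading DDPM: more than one unet
        clip = clip,
        timesteps = 100
    )

    # Hypothetical training batch
    images = torch.randn(4, 3, 256, 256)
    texts  = torch.randint(0, 49408, (4, 256))

    # Compute the CLIP image embedding once per batch ...
    with torch.no_grad():
        image_embed = decoder.get_image_embed(images)

    # ... then train every unet in the cascade on the same batch, passing the
    # precomputed embedding via the new keyword so that forward() skips its
    # own get_image_embed() call (optimizer steps omitted).
    for unet_number in range(1, len(decoder.unets) + 1):
        loss = decoder(
            images,
            text = texts,
            image_embed = image_embed,   # new keyword added by this patch
            unet_number = unet_number
        )
        loss.backward()

This avoids running the CLIP visual encoder once per unet when several unets are trained in the same iteration, which is the motivation stated in the commit subject.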