fix readme and a small bug in DALLE2 class

This commit is contained in:
Phil Wang
2022-07-28 08:33:51 -07:00
parent 07abfcf45b
commit 36fb46a95e
3 changed files with 3 additions and 2 deletions

View File

@@ -371,6 +371,7 @@ loss.backward()
unet1 = Unet( unet1 = Unet(
dim = 128, dim = 128,
image_embed_dim = 512, image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),

View File

@@ -2938,7 +2938,7 @@ class DALLE2(nn.Module):
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale) image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
text_cond = text if self.decoder_need_text_cond else None text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale) images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
if return_pil_images: if return_pil_images:
images = list(map(self.to_pil, images.unbind(dim = 0))) images = list(map(self.to_pil, images.unbind(dim = 0)))

View File

@@ -1 +1 @@
__version__ = '1.2.0' __version__ = '1.2.1'