mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix readme and a small bug in DALLE2 class
This commit is contained in:
@@ -371,6 +371,7 @@ loss.backward()
|
|||||||
unet1 = Unet(
|
unet1 = Unet(
|
||||||
dim = 128,
|
dim = 128,
|
||||||
image_embed_dim = 512,
|
image_embed_dim = 512,
|
||||||
|
text_embed_dim = 512,
|
||||||
cond_dim = 128,
|
cond_dim = 128,
|
||||||
channels = 3,
|
channels = 3,
|
||||||
dim_mults=(1, 2, 4, 8),
|
dim_mults=(1, 2, 4, 8),
|
||||||
|
|||||||
@@ -2938,7 +2938,7 @@ class DALLE2(nn.Module):
|
|||||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
|
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
|
||||||
|
|
||||||
text_cond = text if self.decoder_need_text_cond else None
|
text_cond = text if self.decoder_need_text_cond else None
|
||||||
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
|
images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
|
||||||
|
|
||||||
if return_pil_images:
|
if return_pil_images:
|
||||||
images = list(map(self.to_pil, images.unbind(dim = 0)))
|
images = list(map(self.to_pil, images.unbind(dim = 0)))
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.2.0'
|
__version__ = '1.2.1'
|
||||||
|
|||||||
Reference in New Issue
Block a user