mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix condition_on_text_encodings in dalle2 orchestrator class, fix readme
This commit is contained in:
@@ -581,7 +581,8 @@ unet1 = Unet(
|
|||||||
image_embed_dim = 512,
|
image_embed_dim = 512,
|
||||||
cond_dim = 128,
|
cond_dim = 128,
|
||||||
channels = 3,
|
channels = 3,
|
||||||
dim_mults=(1, 2, 4, 8)
|
dim_mults=(1, 2, 4, 8),
|
||||||
|
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
unet2 = Unet(
|
unet2 = Unet(
|
||||||
@@ -598,12 +599,11 @@ decoder = Decoder(
|
|||||||
clip = clip,
|
clip = clip,
|
||||||
timesteps = 100,
|
timesteps = 100,
|
||||||
image_cond_drop_prob = 0.1,
|
image_cond_drop_prob = 0.1,
|
||||||
text_cond_drop_prob = 0.5,
|
text_cond_drop_prob = 0.5
|
||||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
for unet_number in (1, 2):
|
for unet_number in (1, 2):
|
||||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# do above for many steps
|
# do above for many steps
|
||||||
|
|||||||
@@ -1930,10 +1930,6 @@ class Decoder(nn.Module):
|
|||||||
self.unets.append(one_unet)
|
self.unets.append(one_unet)
|
||||||
self.vaes.append(one_vae.copy_for_eval())
|
self.vaes.append(one_vae.copy_for_eval())
|
||||||
|
|
||||||
# determine from unets whether conditioning on text encoding is needed
|
|
||||||
|
|
||||||
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
|
|
||||||
|
|
||||||
# create noise schedulers per unet
|
# create noise schedulers per unet
|
||||||
|
|
||||||
if not exists(beta_schedule):
|
if not exists(beta_schedule):
|
||||||
@@ -2012,6 +2008,10 @@ class Decoder(nn.Module):
|
|||||||
def device(self):
|
def device(self):
|
||||||
return self._dummy.device
|
return self._dummy.device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def condition_on_text_encodings(self):
|
||||||
|
return any([unet.cond_on_text_encodings for unet in self.unets])
|
||||||
|
|
||||||
def get_unet(self, unet_number):
|
def get_unet(self, unet_number):
|
||||||
assert 0 < unet_number <= len(self.unets)
|
assert 0 < unet_number <= len(self.unets)
|
||||||
index = unet_number - 1
|
index = unet_number - 1
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.16.14'
|
__version__ = '0.16.15'
|
||||||
|
|||||||
Reference in New Issue
Block a user