mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix condition_on_text_encodings in dalle2 orchestrator class, fix readme
This commit is contained in:
@@ -581,7 +581,8 @@ unet1 = Unet(
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
@@ -598,12 +599,11 @@ decoder = Decoder(
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
text_cond_drop_prob = 0.5
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
|
||||
@@ -1930,10 +1930,6 @@ class Decoder(nn.Module):
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval())
|
||||
|
||||
# determine from unets whether conditioning on text encoding is needed
|
||||
|
||||
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
|
||||
# create noise schedulers per unet
|
||||
|
||||
if not exists(beta_schedule):
|
||||
@@ -2012,6 +2008,10 @@ class Decoder(nn.Module):
|
||||
def device(self):
|
||||
return self._dummy.device
|
||||
|
||||
@property
|
||||
def condition_on_text_encodings(self):
|
||||
return any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
index = unet_number - 1
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.16.14'
|
||||
__version__ = '0.16.15'
|
||||
|
||||
Reference in New Issue
Block a user