fix condition_on_text_encodings in dalle2 orchestrator class, fix readme

This commit is contained in:
Phil Wang
2022-07-07 07:42:13 -07:00
parent b3e646fd3b
commit 88f516b5db
3 changed files with 8 additions and 8 deletions

View File

@@ -581,7 +581,8 @@ unet1 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
dim_mults=(1, 2, 4, 8),
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()
unet2 = Unet(
@@ -598,8 +599,7 @@ decoder = Decoder(
clip = clip,
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
text_cond_drop_prob = 0.5
).cuda()
for unet_number in (1, 2):

View File

@@ -1930,10 +1930,6 @@ class Decoder(nn.Module):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# determine from unets whether conditioning on text encoding is needed
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
# create noise schedulers per unet
if not exists(beta_schedule):
@@ -2012,6 +2008,10 @@ class Decoder(nn.Module):
def device(self):
return self._dummy.device
@property
def condition_on_text_encodings(self):
return any([unet.cond_on_text_encodings for unet in self.unets])
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1

View File

@@ -1 +1 @@
__version__ = '0.16.14'
__version__ = '0.16.15'