fix remaining issues with deriving cond_on_text_encodings from child unet settings

This commit is contained in:
Phil Wang
2022-06-26 21:07:42 -07:00
parent 868c001199
commit b90364695d
5 changed files with 22 additions and 13 deletions

View File

@@ -1817,9 +1817,7 @@ class Decoder(nn.Module):
unets = cast_tuple(unet)
num_unets = len(unets)
self.unconditional = unconditional
self.condition_on_text_encodings = unets[0].cond_on_text_encodings
assert not (self.condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
self.unconditional = unconditional
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
@@ -1859,6 +1857,10 @@ class Decoder(nn.Module):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# determine from unets whether conditioning on text encoding is needed
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
# create noise schedulers per unet
if not exists(beta_schedule):