mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
fix remaining issues with deriving cond_on_text_encodings from child unet settings
This commit is contained in:
@@ -1817,9 +1817,7 @@ class Decoder(nn.Module):
|
||||
unets = cast_tuple(unet)
|
||||
num_unets = len(unets)
|
||||
|
||||
self.unconditional = unconditional
|
||||
self.condition_on_text_encodings = unets[0].cond_on_text_encodings
|
||||
assert not (self.condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
self.unconditional = unconditional
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
@@ -1859,6 +1857,10 @@ class Decoder(nn.Module):
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval())
|
||||
|
||||
# determine from unets whether conditioning on text encoding is needed
|
||||
|
||||
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
|
||||
# create noise schedulers per unet
|
||||
|
||||
if not exists(beta_schedule):
|
||||
|
||||
Reference in New Issue
Block a user