bug fixes for text conditioning update (#175)

This commit is contained in:
zion
2022-06-26 18:12:32 -05:00
committed by GitHub
parent 032e83b0e0
commit 868c001199
3 changed files with 16 additions and 15 deletions

View File

@@ -1781,13 +1781,6 @@ class Decoder(nn.Module):
):
super().__init__()
self.unconditional = unconditional
# text conditioning
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
self.condition_on_text_encodings = condition_on_text_encodings
# clip
self.clip = None
@@ -1819,12 +1812,18 @@ class Decoder(nn.Module):
self.channels = channels
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
# verify conditioning method
unets = cast_tuple(unet)
num_unets = len(unets)
self.unconditional = unconditional
self.condition_on_text_encodings = unets[0].cond_on_text_encodings
assert not (self.condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper