mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
bug fixes for text conditioning update (#175)
This commit is contained in:
@@ -1781,13 +1781,6 @@ class Decoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.unconditional = unconditional
|
||||
|
||||
# text conditioning
|
||||
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# clip
|
||||
|
||||
self.clip = None
|
||||
@@ -1819,12 +1812,18 @@ class Decoder(nn.Module):
|
||||
|
||||
self.channels = channels
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
# verify conditioning method
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
num_unets = len(unets)
|
||||
|
||||
self.unconditional = unconditional
|
||||
self.condition_on_text_encodings = unets[0].cond_on_text_encodings
|
||||
assert not (self.condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
|
||||
|
||||
Reference in New Issue
Block a user