From 36c5079bd7c0a3ad4131e7866785220c968a512a Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 15 May 2022 18:56:52 -0700
Subject: [PATCH] LazyLinear is not mature, make users pass in text_embed_dim
 if text conditioning is turned on

---
 dalle2_pytorch/dalle2_pytorch.py | 3 ++-
 dalle2_pytorch/train.py          | 5 -----
 setup.py                         | 2 +-
 3 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index aa58424..a8fc6b8 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1387,7 +1387,8 @@ class Unet(nn.Module):
         self.text_to_cond = None
 
         if cond_on_text_encodings:
-            self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
+            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
+            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
 
         # finer control over whether to condition on image embeddings and text encodings
         # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index 0e70312..f743e6e 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -335,11 +335,6 @@ class DecoderTrainer(nn.Module):
         self.num_unets = len(self.decoder.unets)
 
         self.use_ema = use_ema
-
-        if use_ema:
-            has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
-            assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
-
         self.ema_unets = nn.ModuleList([])
 
         self.amp = amp
diff --git a/setup.py b/setup.py
index 2a9c836..9650985 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.36',
+  version = '0.2.37',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
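
Usage note: a minimal sketch of what this change means for callers. Only text_embed_dim, cond_on_text_encodings and cond_dim appear in the patch itself; the other keyword arguments and the concrete values below are illustrative assumptions, not taken from this commit.

    # after this patch, a Unet that conditions on text encodings must be given
    # the text embedding dimension up front, since nn.LazyLinear is no longer
    # used to infer it on the first forward pass
    from dalle2_pytorch import Unet

    unet = Unet(
        dim = 128,                       # base channel width (illustrative)
        image_embed_dim = 512,           # illustrative
        cond_dim = 128,                  # illustrative
        text_embed_dim = 512,            # now required whenever cond_on_text_encodings = True
        cond_on_text_encodings = True
    )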