Compare commits

...

1 Commits

3 changed files with 3 additions and 7 deletions

View File

@@ -1387,7 +1387,8 @@ class Unet(nn.Module):
self.text_to_cond = None
if cond_on_text_encodings:
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting

View File

@@ -335,11 +335,6 @@ class DecoderTrainer(nn.Module):
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
if use_ema:
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
self.ema_unets = nn.ModuleList([])
self.amp = amp

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.36',
version = '0.2.37',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',