mirror of https://github.com/lucidrains/DALLE2-pytorch.git
LazyLinear is not mature, make users pass in text_embed_dim if text conditioning is turned on
@@ -1387,7 +1387,8 @@ class Unet(nn.Module):
         self.text_to_cond = None
 
         if cond_on_text_encodings:
-            self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
+            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
+            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
 
         # finer control over whether to condition on image embeddings and text encodings
         # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
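For context, a minimal sketch of the changed constructor behavior. The class name UnetSketch, the trimmed-down signature, and the dimensions in the usage line are illustrative assumptions; exists is the small helper the repo uses throughout.

import torch.nn as nn

def exists(val):
    # helper used throughout the repo: True when a value was actually passed
    return val is not None

class UnetSketch(nn.Module):
    # hypothetical stand-in for the relevant slice of Unet.__init__
    def __init__(self, cond_dim, cond_on_text_encodings = False, text_embed_dim = None):
        super().__init__()
        self.text_to_cond = None

        if cond_on_text_encodings:
            # before this commit, a missing text_embed_dim fell back to
            # nn.LazyLinear(cond_dim), deferring weight creation to the first
            # forward pass; the dimension must now be given up front
            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)

# usage with illustrative dimensions
unet = UnetSketch(cond_dim = 512, cond_on_text_encodings = True, text_embed_dim = 768)

The upshot: a caller who forgets text_embed_dim now fails fast at construction time rather than at the first forward pass (or, worse, at EMA setup, as the second hunk shows).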
@@ -335,11 +335,6 @@ class DecoderTrainer(nn.Module):
         self.num_unets = len(self.decoder.unets)
 
         self.use_ema = use_ema
-
-        if use_ema:
-            has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
-            assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
-
         self.ema_unets = nn.ModuleList([])
 
         self.amp = amp
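Why the guard could be deleted: per its own assert message, it existed to keep uninitialized nn.LazyLinear modules out of the exponential-moving-average path, since EMA maintains a deep copy of each unet and lazy parameters that have never been materialized by a forward pass cannot be cloned. With the lazy fallback gone from Unet, the check is dead code. Below is a hedged sketch of the removed logic with a hypothetical helper name (the repo's DecoderTrainer actually wraps each unet in an EMA object, but both approaches rest on copy.deepcopy).

import copy
import torch.nn as nn

def make_ema_copies(unets):
    # the removed check used `type(module) == nn.LazyLinear` rather than
    # isinstance: once a lazy module materializes its weights, PyTorch swaps
    # its class to plain nn.Linear, so only still-uninitialized layers match
    has_lazy_linear = any(type(module) == nn.LazyLinear for unet in unets for module in unet.modules())
    assert not has_lazy_linear, 'materialize or avoid lazy layers before making EMA copies'
    # the deep copies are what break on uninitialized lazy parameters
    return nn.ModuleList([copy.deepcopy(unet) for unet in unets])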