LazyLinear is not mature, make users pass in text_embed_dim if text conditioning is turned on

This commit is contained in:
Phil Wang
2022-05-15 18:56:52 -07:00
parent 4a4c7ac9e6
commit 36c5079bd7
3 changed files with 3 additions and 7 deletions

View File

@@ -335,11 +335,6 @@ class DecoderTrainer(nn.Module):
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
if use_ema:
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
self.ema_unets = nn.ModuleList([])
self.amp = amp