mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
final tweak to EMA class
This commit is contained in:
@@ -195,7 +195,11 @@ class EMA(nn.Module):
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
|
||||
if (self.step % self.update_every) != 0:
|
||||
return
|
||||
|
||||
if self.step <= self.update_after_step:
|
||||
self.copy_params_from_model_to_ema()
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
|
||||
Reference in New Issue
Block a user