mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
final tweak to EMA class
This commit is contained in:
@@ -195,7 +195,11 @@ class EMA(nn.Module):
|
|||||||
def update(self):
|
def update(self):
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
|
if (self.step % self.update_every) != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.step <= self.update_after_step:
|
||||||
|
self.copy_params_from_model_to_ema()
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.initted:
|
if not self.initted:
|
||||||
|
|||||||
Reference in New Issue
Block a user