mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix issue with ema class
This commit is contained in:
@@ -35,7 +35,7 @@ class EMA(nn.Module):
|
|||||||
|
|
||||||
self.update_moving_average(self.ema_model, self.online_model)
|
self.update_moving_average(self.ema_model, self.online_model)
|
||||||
|
|
||||||
def update_moving_average(ma_model, current_model):
|
def update_moving_average(self, ma_model, current_model):
|
||||||
def calculate_ema(beta, old, new):
|
def calculate_ema(beta, old, new):
|
||||||
if not exists(old):
|
if not exists(old):
|
||||||
return new
|
return new
|
||||||
|
|||||||
Reference in New Issue
Block a user