mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix issue with ema class
This commit is contained in:
@@ -35,7 +35,7 @@ class EMA(nn.Module):
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
def update_moving_average(ma_model, current_model):
|
||||
def update_moving_average(self, ma_model, current_model):
|
||||
def calculate_ema(beta, old, new):
|
||||
if not exists(old):
|
||||
return new
|
||||
|
||||
Reference in New Issue
Block a user