ema module fixes (#139)

This commit is contained in:
zion
2022-06-03 19:43:51 -07:00
committed by GitHub
parent 708809ed6c
commit 83517849e5

View File

@@ -178,7 +178,7 @@ class EMA(nn.Module):
     def __init__(
         self,
         model,
-        beta = 0.99,
+        beta = 0.9999,
         update_after_step = 1000,
         update_every = 10,
     ):
@@ -198,7 +198,8 @@ class EMA(nn.Module):
         self.ema_model.to(device)
 
     def copy_params_from_model_to_ema(self):
-        self.ema_model.state_dict(self.online_model.state_dict())
+        for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
+            ma_param.data.copy_(current_param.data)
 
     def update(self):
         step = self.step.item()
@@ -217,19 +218,17 @@ class EMA(nn.Module):
 
         self.update_moving_average(self.ema_model, self.online_model)
 
+    @torch.no_grad()
     def update_moving_average(self, ma_model, current_model):
-        def calculate_ema(beta, old, new):
-            if not exists(old):
-                return new
-            return old * beta + (1 - beta) * new
+        for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
+            difference = ma_params.data - current_params.data
+            difference.mul_(1.0 - self.beta)
+            ma_params.sub_(difference)
 
-        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
-            old_weight, up_weight = ma_params.data, current_params.data
-            ma_params.data.copy_(calculate_ema(self.beta, old_weight, up_weight))
-
-        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
-            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
-            ma_buffer.copy_(new_buffer_value)
+        for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
+            difference = ma_buffer - current_buffer
+            difference.mul_(1.0 - self.beta)
+            ma_buffer.sub_(difference)
 
     def __call__(self, *args, **kwargs):
         return self.ema_model(*args, **kwargs)