fix update_every within EMA

This commit is contained in:
Phil Wang
2022-06-03 10:21:05 -07:00
parent ffd342e9d0
commit 9cc475f6e7
2 changed files with 7 additions and 6 deletions

View File

@@ -188,7 +188,7 @@ class EMA(nn.Module):
self.ema_model = copy.deepcopy(model) self.ema_model = copy.deepcopy(model)
self.update_every = update_every self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0 self.update_after_step = update_after_step
self.register_buffer('initted', torch.Tensor([False])) self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0])) self.register_buffer('step', torch.tensor([0]))
@@ -201,16 +201,17 @@ class EMA(nn.Module):
self.ema_model.state_dict(self.online_model.state_dict()) self.ema_model.state_dict(self.online_model.state_dict())
def update(self): def update(self):
step = self.step.item()
self.step += 1 self.step += 1
if (self.step % self.update_every) != 0: if (step % self.update_every) != 0:
return return
if self.step <= self.update_after_step: if step <= self.update_after_step:
self.copy_params_from_model_to_ema() self.copy_params_from_model_to_ema()
return return
if not self.initted: if not self.initted.item():
self.copy_params_from_model_to_ema() self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True])) self.initted.data.copy_(torch.Tensor([True]))
@@ -224,7 +225,7 @@ class EMA(nn.Module):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight) ma_params.data.copy_(calculate_ema(self.beta, old_weight, up_weight))
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()): for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer) new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)

View File

@@ -1 +1 @@
__version__ = '0.6.9' __version__ = '0.6.10'