diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 122da24..528d122 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -188,7 +188,7 @@ class EMA(nn.Module):
         self.ema_model = copy.deepcopy(model)
 
         self.update_every = update_every
-        self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
+        self.update_after_step = update_after_step
 
         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0]))
@@ -201,16 +201,17 @@ class EMA(nn.Module):
         self.ema_model.state_dict(self.online_model.state_dict())
 
     def update(self):
+        step = self.step.item()
         self.step += 1
 
-        if (self.step % self.update_every) != 0:
+        if (step % self.update_every) != 0:
             return
 
-        if self.step <= self.update_after_step:
+        if step <= self.update_after_step:
             self.copy_params_from_model_to_ema()
             return
 
-        if not self.initted:
+        if not self.initted.item():
             self.copy_params_from_model_to_ema()
             self.initted.data.copy_(torch.Tensor([True]))
 
@@ -224,7 +225,7 @@ class EMA(nn.Module):
 
         for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
             old_weight, up_weight = ma_params.data, current_params.data
-            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
+            ma_params.data.copy_(calculate_ema(self.beta, old_weight, up_weight))
 
         for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
             new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 6300a70..6c1f53c 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.6.9'
+__version__ = '0.6.10'