diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index e69de29..c8c2647 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -0,0 +1,57 @@
+import copy
+import torch
+from torch import nn
+
+def exists(val):
+    return val is not None
+
+# exponential moving average wrapper
+
+class EMA(nn.Module):
+    def __init__(
+        self,
+        model,
+        beta = 0.99,
+        ema_update_after_step = 1000,
+        ema_update_every = 10,
+    ):
+        super().__init__()
+        self.beta = beta
+        self.online_model = model
+        self.ema_model = copy.deepcopy(model)
+
+        self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
+        self.ema_update_every = ema_update_every
+
+        self.register_buffer('initted', torch.tensor(False))
+        self.register_buffer('step', torch.tensor(0))
+
+    def update(self):
+        self.step += 1
+
+        if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
+            return
+
+        if not self.initted:
+            # on the first qualifying step, copy the online weights into the EMA model
+            self.ema_model.load_state_dict(self.online_model.state_dict())
+            self.initted.data.copy_(torch.tensor(True))
+
+        self.update_moving_average(self.ema_model, self.online_model)
+
+    def update_moving_average(self, ma_model, current_model):
+        def calculate_ema(beta, old, new):
+            if not exists(old):
+                return new
+            return old * beta + (1 - beta) * new
+
+        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+            old_weight, up_weight = ma_params.data, current_params.data
+            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
+
+        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
+            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
+            ma_buffer.copy_(new_buffer_value)
+
+    def __call__(self, *args, **kwargs):
+        return self.ema_model(*args, **kwargs)
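
For reference, a minimal sketch of how this wrapper might sit in a training loop. The network, optimizer, data, and hyperparameter values below are hypothetical stand-ins, not part of this diff; the point is that `update()` is called once per optimizer step and the wrapper applies its own warmup/interval schedule internally, while `EMA.__call__` forwards to the averaged model for evaluation.

```python
import torch
from torch import nn
from dalle2_pytorch.train import EMA

# hypothetical stand-in network; any nn.Module works
net = nn.Linear(16, 1)
ema = EMA(net, beta = 0.99, ema_update_after_step = 100, ema_update_every = 10)

opt = torch.optim.Adam(net.parameters(), lr = 3e-4)

for _ in range(1000):
    x = torch.randn(8, 16)           # dummy batch
    loss = net(x).pow(2).mean()      # dummy objective
    loss.backward()
    opt.step()
    opt.zero_grad()

    ema.update()                     # tick the step counter; copies then averages per its schedule

# evaluate with the smoothed weights
with torch.no_grad():
    out = ema(torch.randn(8, 16))
```

Note the schedule: nothing happens for the first `ema_update_after_step` steps, then the online weights are copied wholesale once, and after that every `ema_update_every`-th step blends the models as `old * beta + (1 - beta) * new`.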