mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
bring in the exponential moving average wrapper, to get ready for training
This commit is contained in:
@@ -0,0 +1,53 @@
|
|||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
# exponential moving average wrapper
|
||||||
|
|
||||||
|
class EMA(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
beta = 0.99,
|
||||||
|
ema_update_after_step = 1000,
|
||||||
|
ema_update_every = 10,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.beta = beta
|
||||||
|
self.online_model = model
|
||||||
|
self.ema_model = copy.deepcopy(model)
|
||||||
|
|
||||||
|
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
|
||||||
|
self.ema_update_every = ema_update_every
|
||||||
|
|
||||||
|
self.register_buffer('initted', torch.Tensor([False]))
|
||||||
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
self.step += 1
|
||||||
|
|
||||||
|
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.initted:
|
||||||
|
self.ema_model.state_dict(self.online_model.state_dict())
|
||||||
|
self.initted.data.copy_(torch.Tensor([True]))
|
||||||
|
|
||||||
|
self.update_moving_average(self.ema_model, self.online_model)
|
||||||
|
|
||||||
|
def update_moving_average(ma_model, current_model):
|
||||||
|
def calculate_ema(beta, old, new):
|
||||||
|
if not exists(old):
|
||||||
|
return new
|
||||||
|
return old * beta + (1 - beta) * new
|
||||||
|
|
||||||
|
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||||
|
old_weight, up_weight = ma_params.data, current_params.data
|
||||||
|
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
||||||
|
|
||||||
|
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
||||||
|
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
||||||
|
ma_buffer.copy_(new_buffer_value)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.ema_model(*args, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user