mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
ema module fixes (#139)
This commit is contained in:
@@ -178,7 +178,7 @@ class EMA(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
beta = 0.99,
|
beta = 0.9999,
|
||||||
update_after_step = 1000,
|
update_after_step = 1000,
|
||||||
update_every = 10,
|
update_every = 10,
|
||||||
):
|
):
|
||||||
@@ -198,7 +198,8 @@ class EMA(nn.Module):
|
|||||||
self.ema_model.to(device)
|
self.ema_model.to(device)
|
||||||
|
|
||||||
def copy_params_from_model_to_ema(self):
|
def copy_params_from_model_to_ema(self):
|
||||||
self.ema_model.state_dict(self.online_model.state_dict())
|
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
||||||
|
ma_param.data.copy_(current_param.data)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
step = self.step.item()
|
step = self.step.item()
|
||||||
@@ -217,19 +218,17 @@ class EMA(nn.Module):
|
|||||||
|
|
||||||
self.update_moving_average(self.ema_model, self.online_model)
|
self.update_moving_average(self.ema_model, self.online_model)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def update_moving_average(self, ma_model, current_model):
|
def update_moving_average(self, ma_model, current_model):
|
||||||
def calculate_ema(beta, old, new):
|
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
|
||||||
if not exists(old):
|
difference = ma_params.data - current_params.data
|
||||||
return new
|
difference.mul_(1.0 - self.beta)
|
||||||
return old * beta + (1 - beta) * new
|
ma_params.sub_(difference)
|
||||||
|
|
||||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
|
||||||
old_weight, up_weight = ma_params.data, current_params.data
|
difference = ma_buffer - current_buffer
|
||||||
ma_params.data.copy_(calculate_ema(self.beta, old_weight, up_weight))
|
difference.mul_(1.0 - self.beta)
|
||||||
|
ma_buffer.sub_(difference)
|
||||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
|
||||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
|
||||||
ma_buffer.copy_(new_buffer_value)
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.ema_model(*args, **kwargs)
|
return self.ema_model(*args, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user