mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
implement ema warmup from @crowsonkb (#140)
This commit is contained in:
@@ -175,12 +175,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
|
|||||||
# exponential moving average wrapper
|
# exponential moving average wrapper
|
||||||
|
|
||||||
class EMA(nn.Module):
|
class EMA(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements exponential moving average shadowing for your model.
|
||||||
|
|
||||||
|
Utilizes an inverse decay schedule to manage longer term training runs.
|
||||||
|
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
|
||||||
|
|
||||||
|
@crowsonkb's notes on EMA Warmup:
|
||||||
|
|
||||||
|
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
|
||||||
|
good values for models you plan to train for a million or more steps (reaches decay
|
||||||
|
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
|
||||||
|
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
|
||||||
|
215.4k steps).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||||
|
power (float): Exponential factor of EMA warmup. Default: 1.
|
||||||
|
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
beta = 0.9999,
|
beta = 0.9999,
|
||||||
update_after_step = 1000,
|
update_after_step = 10000,
|
||||||
update_every = 10,
|
update_every = 10,
|
||||||
|
inv_gamma = 1.0,
|
||||||
|
power = 2/3,
|
||||||
|
min_value = 0.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
@@ -190,6 +212,10 @@ class EMA(nn.Module):
|
|||||||
self.update_every = update_every
|
self.update_every = update_every
|
||||||
self.update_after_step = update_after_step
|
self.update_after_step = update_after_step
|
||||||
|
|
||||||
|
self.inv_gamma = inv_gamma
|
||||||
|
self.power = power
|
||||||
|
self.min_value = min_value
|
||||||
|
|
||||||
self.register_buffer('initted', torch.Tensor([False]))
|
self.register_buffer('initted', torch.Tensor([False]))
|
||||||
self.register_buffer('step', torch.tensor([0]))
|
self.register_buffer('step', torch.tensor([0]))
|
||||||
|
|
||||||
@@ -201,6 +227,11 @@ class EMA(nn.Module):
|
|||||||
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
||||||
ma_param.data.copy_(current_param.data)
|
ma_param.data.copy_(current_param.data)
|
||||||
|
|
||||||
|
def get_current_decay(self):
|
||||||
|
epoch = max(0, self.step.item() - self.update_after_step - 1)
|
||||||
|
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
||||||
|
return 0. if epoch < 0 else min(self.beta, max(self.min_value, value))
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
step = self.step.item()
|
step = self.step.item()
|
||||||
self.step += 1
|
self.step += 1
|
||||||
@@ -220,14 +251,16 @@ class EMA(nn.Module):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def update_moving_average(self, ma_model, current_model):
|
def update_moving_average(self, ma_model, current_model):
|
||||||
|
current_decay = self.get_current_decay()
|
||||||
|
|
||||||
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
|
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
|
||||||
difference = ma_params.data - current_params.data
|
difference = ma_params.data - current_params.data
|
||||||
difference.mul_(1.0 - self.beta)
|
difference.mul_(1.0 - current_decay)
|
||||||
ma_params.sub_(difference)
|
ma_params.sub_(difference)
|
||||||
|
|
||||||
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
|
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
|
||||||
difference = ma_buffer - current_buffer
|
difference = ma_buffer - current_buffer
|
||||||
difference.mul_(1.0 - self.beta)
|
difference.mul_(1.0 - current_decay)
|
||||||
ma_buffer.sub_(difference)
|
ma_buffer.sub_(difference)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user