Compare commits

...

5 Commits

Author SHA1 Message Date
Phil Wang
350a3d6045 0.6.16 2022-06-06 08:45:46 -07:00
Kashif Rasul
1a81670718 fix quadratic_beta_schedule (#141) 2022-06-06 08:45:14 -07:00
Phil Wang
934c9728dc some cleanup 2022-06-04 16:54:15 -07:00
Phil Wang
ce4b0107c1 0.6.13 2022-06-04 13:26:57 -07:00
zion
64c2f9c4eb implement ema warmup from @crowsonkb (#140) 2022-06-04 13:26:34 -07:00
3 changed files with 55 additions and 8 deletions

View File

@@ -373,7 +373,7 @@ def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
def sigmoid_beta_schedule(timesteps):

View File

@@ -58,8 +58,15 @@ def num_to_groups(num, divisor):
arr.append(remainder)
return arr
def get_pkg_version():
return __version__
def clamp(value, min_value = None, max_value = None):
assert exists(min_value) or exists(max_value)
if exists(min_value):
value = max(value, min_value)
if exists(max_value):
value = min(value, max_value)
return value
# decorators
@@ -175,12 +182,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
# exponential moving average wrapper
class EMA(nn.Module):
"""
Implements exponential moving average shadowing for your model.
Utilizes an inverse decay schedule to manage longer term training runs.
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 1.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
def __init__(
self,
model,
beta = 0.9999,
update_after_step = 1000,
update_after_step = 10000,
update_every = 10,
inv_gamma = 1.0,
power = 2/3,
min_value = 0.0,
):
super().__init__()
self.beta = beta
@@ -190,6 +219,10 @@ class EMA(nn.Module):
self.update_every = update_every
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0]))
@@ -201,6 +234,18 @@ class EMA(nn.Module):
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
ma_param.data.copy_(current_param.data)
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
ma_buffer.data.copy_(current_buffer.data)
def get_current_decay(self):
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
if epoch <= 0:
return 0.
return clamp(value, min_value = self.min_value, max_value = self.beta)
def update(self):
step = self.step.item()
self.step += 1
@@ -220,14 +265,16 @@ class EMA(nn.Module):
@torch.no_grad()
def update_moving_average(self, ma_model, current_model):
current_decay = self.get_current_decay()
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
difference = ma_params.data - current_params.data
difference.mul_(1.0 - self.beta)
difference.mul_(1.0 - current_decay)
ma_params.sub_(difference)
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
difference = ma_buffer - current_buffer
difference.mul_(1.0 - self.beta)
difference.mul_(1.0 - current_decay)
ma_buffer.sub_(difference)
def __call__(self, *args, **kwargs):
@@ -488,7 +535,7 @@ class DecoderTrainer(nn.Module):
loaded_obj = torch.load(str(path))
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

View File

@@ -1 +1 @@
__version__ = '0.6.12'
__version__ = '0.6.16'