diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index 84c0957..f9fc1d1 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -58,8 +58,15 @@ def num_to_groups(num, divisor): arr.append(remainder) return arr -def get_pkg_version(): - return __version__ +def clamp(value, min_value = None, max_value = None): + assert exists(min_value) or exists(max_value) + if exists(min_value): + value = max(value, min_value) + + if exists(max_value): + value = min(value, max_value) + + return value # decorators @@ -227,10 +234,17 @@ class EMA(nn.Module): for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())): ma_param.data.copy_(current_param.data) + for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())): + ma_buffer.data.copy_(current_buffer.data) + def get_current_decay(self): - epoch = max(0, self.step.item() - self.update_after_step - 1) + epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0) value = 1 - (1 + epoch / self.inv_gamma) ** - self.power - return 0. if epoch < 0 else min(self.beta, max(self.min_value, value)) + + if epoch <= 0: + return 0. + + return clamp(value, min_value = self.min_value, max_value = self.beta) def update(self): step = self.step.item() @@ -521,7 +535,7 @@ class DecoderTrainer(nn.Module): loaded_obj = torch.load(str(path)) if version.parse(__version__) != loaded_obj['version']: - print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}') + print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}') self.decoder.load_state_dict(loaded_obj['model'], strict = strict) self.step.copy_(torch.ones_like(self.step) * loaded_obj['step']) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 70039a4..1b3c28f 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.6.13' +__version__ = '0.6.15'