mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
some cleanup
This commit is contained in:
@@ -58,8 +58,15 @@ def num_to_groups(num, divisor):
|
|||||||
arr.append(remainder)
|
arr.append(remainder)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
def get_pkg_version():
|
def clamp(value, min_value = None, max_value = None):
|
||||||
return __version__
|
assert exists(min_value) or exists(max_value)
|
||||||
|
if exists(min_value):
|
||||||
|
value = max(value, min_value)
|
||||||
|
|
||||||
|
if exists(max_value):
|
||||||
|
value = min(value, max_value)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
# decorators
|
# decorators
|
||||||
|
|
||||||
@@ -227,10 +234,17 @@ class EMA(nn.Module):
|
|||||||
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
||||||
ma_param.data.copy_(current_param.data)
|
ma_param.data.copy_(current_param.data)
|
||||||
|
|
||||||
|
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
|
||||||
|
ma_buffer.data.copy_(current_buffer.data)
|
||||||
|
|
||||||
def get_current_decay(self):
|
def get_current_decay(self):
|
||||||
epoch = max(0, self.step.item() - self.update_after_step - 1)
|
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
|
||||||
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
||||||
return 0. if epoch < 0 else min(self.beta, max(self.min_value, value))
|
|
||||||
|
if epoch <= 0:
|
||||||
|
return 0.
|
||||||
|
|
||||||
|
return clamp(value, min_value = self.min_value, max_value = self.beta)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
step = self.step.item()
|
step = self.step.item()
|
||||||
@@ -521,7 +535,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
loaded_obj = torch.load(str(path))
|
loaded_obj = torch.load(str(path))
|
||||||
|
|
||||||
if version.parse(__version__) != loaded_obj['version']:
|
if version.parse(__version__) != loaded_obj['version']:
|
||||||
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
|
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||||
|
|
||||||
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.6.13'
|
__version__ = '0.6.15'
|
||||||
|
|||||||
Reference in New Issue
Block a user