Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-14 19:14:25 +01:00
Compare commits: 1 commit
Commit 49de72040c: track the training step as an integer buffer, persist it in trainer checkpoints and restore it on load; DecoderTrainer additionally checkpoints one scaler/optimizer pair per unet.
@@ -196,7 +196,7 @@ class EMA(nn.Module):
         self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0

         self.register_buffer('initted', torch.Tensor([False]))
-        self.register_buffer('step', torch.tensor([0.]))
+        self.register_buffer('step', torch.tensor([0]))

     def restore_ema_model_device(self):
         device = self.initted.device
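A note on the change above: registering the step counter with an integer dtype keeps it an exact index rather than a float, and a registered buffer travels with state_dict(), so the counter survives checkpointing. The same pattern is applied to DiffusionPriorTrainer in the next hunk. A minimal standalone sketch of the idea (the class and its update method are illustrative, not from the repo):

import torch
import torch.nn as nn

class StepCounter(nn.Module):
    def __init__(self):
        super().__init__()
        # integer buffer: an exact step index, saved/loaded with state_dict()
        self.register_buffer('step', torch.tensor([0]))

    def update(self):
        self.step += 1

counter = StepCounter()
counter.update()
restored = StepCounter()
restored.load_state_dict(counter.state_dict())
assert restored.step.item() == 1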
@@ -292,7 +292,7 @@ class DiffusionPriorTrainer(nn.Module):

         self.max_grad_norm = max_grad_norm

-        self.register_buffer('step', torch.tensor([0.]))
+        self.register_buffer('step', torch.tensor([0]))

     def save(self, path, overwrite = True):
         path = Path(path)
@@ -303,7 +303,8 @@ class DiffusionPriorTrainer(nn.Module):
             scaler = self.scaler.state_dict(),
             optimizer = self.optimizer.state_dict(),
             model = self.diffusion_prior.state_dict(),
-            version = get_pkg_version()
+            version = get_pkg_version(),
+            step = self.step.item()
         )

         if self.use_ema:
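Calling .item() turns the one-element step buffer into a plain Python int, so the checkpoint stays a flat dict of serializable entries. A hedged sketch of the save side under that layout (the helper function and path handling are illustrative; the version field is omitted):

import torch

def save_checkpoint(trainer, path):
    # same dict layout as the diff: component state_dicts plus the scalar step
    save_obj = dict(
        scaler = trainer.scaler.state_dict(),
        optimizer = trainer.optimizer.state_dict(),
        model = trainer.diffusion_prior.state_dict(),
        step = trainer.step.item()
    )
    torch.save(save_obj, str(path))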
@@ -321,6 +322,7 @@ class DiffusionPriorTrainer(nn.Module):
             print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')

         self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
+        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

         if only_model:
             return
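On load, copy_ writes the saved value into the existing registered buffer in place, preserving the buffer's identity, dtype, and device; torch.ones_like broadcasts the Python int back to the buffer's shape. The effect in isolation (values are made up for illustration):

import torch

step = torch.tensor([0])        # stands in for the registered buffer
loaded_step = 1200              # the int pulled out of the checkpoint

step.copy_(torch.ones_like(step) * loaded_step)
assert step.item() == 1200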
@@ -459,12 +461,18 @@ class DecoderTrainer(nn.Module):
         path.parent.mkdir(parents = True, exist_ok = True)

         save_obj = dict(
-            scaler = self.scaler.state_dict(),
-            optimizer = self.optimizer.state_dict(),
             model = self.decoder.state_dict(),
-            version = get_pkg_version()
+            version = get_pkg_version(),
+            step = self.step.item()
         )

+        for ind in range(0, self.num_unets):
+            scaler_key = f'scaler{ind}'
+            optimizer_key = f'optim{ind}'
+            scaler = getattr(self, scaler_key)
+            optimizer = getattr(self, optimizer_key)
+            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
+
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
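The save loop above assumes each unet's optimizer and scaler were attached as numbered attributes at construction time, so getattr can recover them by index. A minimal sketch of the registration side this relies on (the constructor shape and optimizer choice are assumptions, not taken from the repo):

import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.optim import Adam

class MultiUnetTrainer(nn.Module):
    def __init__(self, unets, lr = 1e-4, amp = False):
        super().__init__()
        self.num_unets = len(unets)
        for ind, unet in enumerate(unets):
            # numbered attributes pair with the scaler{ind}/optim{ind}
            # keys used by save() and load()
            setattr(self, f'optim{ind}', Adam(unet.parameters(), lr = lr))
            setattr(self, f'scaler{ind}', GradScaler(enabled = amp))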
@@ -480,12 +488,19 @@ class DecoderTrainer(nn.Module):
             print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')

         self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
+        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

         if only_model:
             return

-        self.scaler.load_state_dict(loaded_obj['scaler'])
-        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        for ind in range(0, self.num_unets):
+            scaler_key = f'scaler{ind}'
+            optimizer_key = f'optim{ind}'
+            scaler = getattr(self, scaler_key)
+            optimizer = getattr(self, optimizer_key)
+
+            scaler.load_state_dict(loaded_obj[scaler_key])
+            optimizer.load_state_dict(loaded_obj[optimizer_key])

         if self.use_ema:
             assert 'ema' in loaded_obj
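Taken together, save and load are symmetric over the scaler{ind}/optim{ind} keys. A hedged usage sketch of the resulting round trip (the trainer construction and path are placeholders, not from the diff):

# save: model weights, step, and every per-unet scaler/optimizer state in one file
trainer.save('./decoder.pt')

# resume everything, or pass only_model = True to restore the weights alone
trainer.load('./decoder.pt', only_model = False)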