diff --git a/README.md b/README.md
index bff3f21..350839a 100644
--- a/README.md
+++ b/README.md
@@ -1077,6 +1077,7 @@ This library would not have gotten to this working state without the help of
 - [x] cross embed layers for downsampling, as an option
 - [x] use an experimental tracker agnostic setup, as done here
 - [x] use pydantic for config drive training
+- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1091,7 +1092,6 @@ This library would not have gotten to this working state without the help of
 - [ ] decoder needs one day worth of refactor for tech debt
 - [ ] allow for unet to be able to condition non-cross attention style as well
 - [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
-- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
 - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
 
 ## Citations
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 743ec96..2019d64 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -1,5 +1,6 @@
 import time
 import copy
+from pathlib import Path
 from math import ceil
 from functools import partial, wraps
 from collections.abc import Iterable
@@ -55,6 +56,10 @@ def num_to_groups(num, divisor):
         arr.append(remainder)
     return arr
 
+def get_pkg_version():
+    from pkg_resources import get_distribution
+    return get_distribution('dalle2_pytorch').version
+
 # decorators
 
 def cast_torch_tensor(fn):
@@ -289,6 +294,44 @@ class DiffusionPriorTrainer(nn.Module):
 
         self.register_buffer('step', torch.tensor([0.]))
 
+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            scaler = self.scaler.state_dict(),
+            optimizer = self.optimizer.state_dict(),
+            model = self.diffusion_prior.state_dict(),
+            version = get_pkg_version()
+        )
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
+
+        torch.save(save_obj, str(path))
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path))
+
+        if get_pkg_version() != loaded_obj['version']:
+            print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
+
+        self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
+
+        if only_model:
+            return
+
+        self.scaler.load_state_dict(loaded_obj['scaler'])
+        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
+
     def update(self):
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
@@ -410,6 +453,44 @@ class DecoderTrainer(nn.Module):
 
         self.register_buffer('step', torch.tensor([0.]))
 
+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            scaler = self.scaler.state_dict(),
+            optimizer = self.optimizer.state_dict(),
+            model = self.decoder.state_dict(),
+            version = get_pkg_version()
+        )
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
+
+        torch.save(save_obj, str(path))
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path))
+
+        if get_pkg_version() != loaded_obj['version']:
+            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
+
+        self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
+
+        if only_model:
+            return
+
+        self.scaler.load_state_dict(loaded_obj['scaler'])
+        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
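A minimal usage sketch for the checkpointing methods introduced in this diff. Trainer construction is omitted; `prior_trainer` and `decoder_trainer` are assumed to be already-built DiffusionPriorTrainer and DecoderTrainer instances, and the checkpoint paths are illustrative, not part of the diff.

# hypothetical usage of the save / load methods added above;
# prior_trainer / decoder_trainer are assumed pre-constructed trainer instances

prior_path = './checkpoints/prior.pt'        # illustrative path
decoder_path = './checkpoints/decoder.pt'    # illustrative path

# checkpoint scaler, optimizer, model weights and package version (plus EMA weights when use_ema is on)
prior_trainer.save(prior_path, overwrite = True)
decoder_trainer.save(decoder_path, overwrite = True)

# resume later: restores model, scaler, optimizer and EMA state, warning on a package version mismatch
prior_trainer.load(prior_path)
decoder_trainer.load(decoder_path)

# for inference only, restore just the model weights
prior_trainer.load(prior_path, only_model = True)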