Compare commits


3 Commits
0.4.1 ... 0.4.3

Author     SHA1         Message                                                                                          Date
Phil Wang  271a376eaf   0.4.3                                                                                            2022-05-22 15:10:28 -07:00
Phil Wang  e527002472   take care of saving and loading functions on the diffusion prior and decoder training classes   2022-05-22 15:10:15 -07:00
Phil Wang  c12e067178   let the pydantic config base model take care of loading configuration from json path            2022-05-22 14:47:23 -07:00
5 changed files with 91 additions and 6 deletions

View File

@@ -1077,6 +1077,7 @@ This library would not have gotten to this working state without the help of
 - [x] cross embed layers for downsampling, as an option
 - [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 - [x] use pydantic for config drive training
+- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1091,7 +1092,6 @@ This library would not have gotten to this working state without the help of
 - [ ] decoder needs one day worth of refactor for tech debt
 - [ ] allow for unet to be able to condition non-cross attention style as well
 - [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
-- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
 - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89

 ## Citations

View File

@@ -1,3 +1,4 @@
+import json
 from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
@@ -111,6 +112,12 @@ class TrainDecoderConfig(BaseModel):
     tracker: TrackerConfig
     load: DecoderLoadConfig

+    @classmethod
+    def from_json_path(cls, json_path):
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls(**config)
+
     @property
     def img_preproc(self):
         def _get_transformation(transformation_name, **kwargs):
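
With this classmethod in place, a script no longer needs to open and parse the config JSON itself. A minimal usage sketch (the config path matches the default used by the training CLI below; the field accesses are illustrative, assuming a config file that satisfies the pydantic schema):

from dalle2_pytorch.train_configs import TrainDecoderConfig

# load and validate the training configuration in one call
config = TrainDecoderConfig.from_json_path('./train_decoder_config.json')

# pydantic has already parsed the nested sections into typed sub-models
tracker_config = config.tracker   # TrackerConfig
load_config = config.load         # DecoderLoadConfig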

View File

@@ -1,5 +1,6 @@
import time import time
import copy import copy
from pathlib import Path
from math import ceil from math import ceil
from functools import partial, wraps from functools import partial, wraps
from collections.abc import Iterable from collections.abc import Iterable
@@ -55,6 +56,10 @@ def num_to_groups(num, divisor):
arr.append(remainder) arr.append(remainder)
return arr return arr
def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle2_pytorch').version
# decorators # decorators
def cast_torch_tensor(fn): def cast_torch_tensor(fn):
@@ -289,6 +294,44 @@ class DiffusionPriorTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0.])) self.register_buffer('step', torch.tensor([0.]))
def save(self, path, overwrite = True):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = get_pkg_version()
)
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
torch.save(save_obj, str(path))
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
if only_model:
return
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
def update(self): def update(self):
if exists(self.max_grad_norm): if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer) self.scaler.unscale_(self.optimizer)
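
The new checkpointing methods on DiffusionPriorTrainer can be exercised roughly as follows. This is a sketch only: it assumes a trainer has already been constructed around a DiffusionPrior instance (trainer construction and the checkpoint path are not part of this diff):

# assumes `trainer` is an existing DiffusionPriorTrainer; the path is illustrative
trainer.save('./checkpoints/prior.pt')                      # saves model, optimizer, scaler, EMA weights (if enabled), and package version

trainer.load('./checkpoints/prior.pt')                      # full restore; prints a warning if the saved package version differs
trainer.load('./checkpoints/prior.pt', only_model = True)   # restore only the diffusion prior weights, skip optimizer/scaler state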
@@ -410,6 +453,44 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('step', torch.tensor([0.]))

+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            scaler = self.scaler.state_dict(),
+            optimizer = self.optimizer.state_dict(),
+            model = self.decoder.state_dict(),
+            version = get_pkg_version()
+        )
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
+
+        torch.save(save_obj, str(path))
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path))
+
+        if get_pkg_version() != loaded_obj['version']:
+            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
+
+        self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
+
+        if only_model:
+            return
+
+        self.scaler.load_state_dict(loaded_obj['scaler'])
+        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
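
DecoderTrainer gains the same pair of methods. A minimal sketch of a save and resume round trip (the trainer construction, the use of EMA, and the path are assumptions for illustration, not taken from this diff):

# assumes `decoder_trainer` is an existing DecoderTrainer constructed with EMA enabled
decoder_trainer.save('./decoder.pt', overwrite = True)   # with overwrite = False, the save asserts if the file already exists

decoder_trainer.load('./decoder.pt', strict = True)      # restores decoder, optimizer, scaler, and EMA unet weights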

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.1',
+  version = '0.4.3',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

View File

@@ -5,7 +5,6 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
 from dalle2_pytorch.train_configs import TrainDecoderConfig
 from dalle2_pytorch.utils import Timer
-import json
 import torchvision
 import torch
 from torchmetrics.image.fid import FrechetInceptionDistance
@@ -449,9 +448,7 @@ def initialize_training(config):
 @click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
 def main(config_file):
     print("Recalling config from {}".format(config_file))
-    with open(config_file) as f:
-        config = json.load(f)
-    config = TrainDecoderConfig(**config)
+    config = TrainDecoderConfig.from_json_path(config_file)
     initialize_training(config)