Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-14 16:54:46 +01:00)

Compare commits

3 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 271a376eaf |  |
|  | e527002472 |  |
|  | c12e067178 |  |
@@ -1077,6 +1077,7 @@ This library would not have gotten to this working state without the help of
 - [x] cross embed layers for downsampling, as an option
 - [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 - [x] use pydantic for config drive training
+- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1091,7 +1092,6 @@ This library would not have gotten to this working state without the help of
 - [ ] decoder needs one day worth of refactor for tech debt
 - [ ] allow for unet to be able to condition non-cross attention style as well
 - [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
-- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
 - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89

 ## Citations

@@ -1,3 +1,4 @@
+import json
 from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
@@ -111,6 +112,12 @@ class TrainDecoderConfig(BaseModel):
     tracker: TrackerConfig
     load: DecoderLoadConfig

+    @classmethod
+    def from_json_path(cls, json_path):
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls(**config)
+
     @property
     def img_preproc(self):
         def _get_transformation(transformation_name, **kwargs):
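The new `from_json_path` classmethod lets the decoder training configuration be parsed and validated straight from a JSON file. A minimal usage sketch, assuming a config file exists at an illustrative path:

    from dalle2_pytorch.train_configs import TrainDecoderConfig

    # parse the JSON file and validate it through the pydantic model
    config = TrainDecoderConfig.from_json_path('./train_decoder_config.json')

The validated sub-configs (e.g. `tracker`, `load`) are then available as attributes, as in the training-script change further down.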
@@ -1,5 +1,6 @@
 import time
 import copy
+from pathlib import Path
 from math import ceil
 from functools import partial, wraps
 from collections.abc import Iterable
@@ -55,6 +56,10 @@ def num_to_groups(num, divisor):
         arr.append(remainder)
     return arr

+def get_pkg_version():
+    from pkg_resources import get_distribution
+    return get_distribution('dalle2_pytorch').version
+
 # decorators

 def cast_torch_tensor(fn):
@@ -289,6 +294,44 @@ class DiffusionPriorTrainer(nn.Module):

         self.register_buffer('step', torch.tensor([0.]))

+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            scaler = self.scaler.state_dict(),
+            optimizer = self.optimizer.state_dict(),
+            model = self.diffusion_prior.state_dict(),
+            version = get_pkg_version()
+        )
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
+
+        torch.save(save_obj, str(path))
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path))
+
+        if get_pkg_version() != loaded_obj['version']:
+            print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
+
+        self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
+
+        if only_model:
+            return
+
+        self.scaler.load_state_dict(loaded_obj['scaler'])
+        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
+
     def update(self):
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
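The `save`/`load` pair added to `DiffusionPriorTrainer` bundles the model weights, optimizer and grad-scaler state, the package version, and (when EMA is enabled) the EMA weights into one checkpoint file. A hedged sketch of how it might be called, assuming `prior_trainer` is an already constructed `DiffusionPriorTrainer` and the path is just an example:

    # write a full training checkpoint (model, optimizer, scaler, EMA if enabled)
    prior_trainer.save('./checkpoints/prior.pt')

    # later: restore everything, including optimizer and scaler state
    prior_trainer.load('./checkpoints/prior.pt')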
@@ -410,6 +453,44 @@ class DecoderTrainer(nn.Module):

         self.register_buffer('step', torch.tensor([0.]))

+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            scaler = self.scaler.state_dict(),
+            optimizer = self.optimizer.state_dict(),
+            model = self.decoder.state_dict(),
+            version = get_pkg_version()
+        )
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
+
+        torch.save(save_obj, str(path))
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path))
+
+        if get_pkg_version() != loaded_obj['version']:
+            print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
+
+        self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
+
+        if only_model:
+            return
+
+        self.scaler.load_state_dict(loaded_obj['scaler'])
+        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
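`DecoderTrainer` gets the same checkpointing interface, with `self.decoder` and `self.ema_unets` in place of the prior and its EMA copy. The `only_model` flag restores just the decoder weights and skips the optimizer, scaler, and EMA state, which is useful when loading for sampling rather than resuming training. A sketch under the assumption that `decoder_trainer` already exists and the path is illustrative:

    decoder_trainer.save('./checkpoints/decoder.pt')

    # restore only the decoder weights; optimizer / scaler / EMA restore is skipped
    decoder_trainer.load('./checkpoints/decoder.pt', only_model = True)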
setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.1',
+  version = '0.4.3',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -5,7 +5,6 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
 from dalle2_pytorch.train_configs import TrainDecoderConfig
 from dalle2_pytorch.utils import Timer

-import json
 import torchvision
 import torch
 from torchmetrics.image.fid import FrechetInceptionDistance
@@ -449,9 +448,7 @@ def initialize_training(config):
 @click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
 def main(config_file):
     print("Recalling config from {}".format(config_file))
-    with open(config_file) as f:
-        config = json.load(f)
-    config = TrainDecoderConfig(**config)
+    config = TrainDecoderConfig.from_json_path(config_file)
     initialize_training(config)
