Commit Graph

31 Commits

Author SHA1 Message Date
Phil Wang
e928ae5c34 default the device to the device that the diffusion prior parameters are on, if the trainer was never given the accelerator nor device 2022-07-06 12:47:48 -07:00
Phil Wang
ec68243479 set ability to do warmup steps for each unet during training 2022-07-05 16:24:16 -07:00
Phil Wang
3afdcdfe86 need to keep track of training steps separately for each unet in decoder trainer 2022-07-05 15:17:59 -07:00
Aidan Dempster
27b0f7ca0d Overhauled the tracker system (#172)
* Overhauled the tracker system
Separated the logging and saving capabilities
Changed creation to be consistent and initializing behavior to be defined by a class initializer instead of in the training script
Added class separation between different types of loaders and savers to make the system more verbose

* Changed the saver system to only save the checkpoint once

* Added better error handling for saving checkpoints

* Fixed an error where wandb would error when passed arbitrary kwargs

* Fixed variable naming issues for improved saver
Added more logging during long pauses

* Fixed which methods need to be dummy to immediately return
Added the ability to set whether you find unused parameters

* Added more logging for when a wandb loader fails
2022-07-01 09:39:40 -07:00
Aidan Dempster
f5760bdb92 Add data flexibility to decoder trainer (#165)
* Added the ability to train decoder with text embeddings

* Added the ability to train using on the fly generated embeddings with clip

* Clip now generates embeddings for whatever is not precomputed
2022-06-25 19:05:20 -07:00
Phil Wang
4b994601ae just make sure decoder learning rate is reasonable and help out budding researchers 2022-06-23 11:29:28 -07:00
Phil Wang
0021535c26 move ema to external repo 2022-06-20 11:48:32 -07:00
Phil Wang
56883910fb cleanup 2022-06-20 11:14:55 -07:00
Phil Wang
67f0740777 small cleanup 2022-06-20 08:59:51 -07:00
Phil Wang
138079ca83 allow for setting beta schedules of unets differently in the decoder, as what was used in the paper was cosine, cosine, linear 2022-06-20 08:56:37 -07:00
Aidan Dempster
58892135d9 Distributed Training of the Decoder (#121)
* Converted decoder trainer to use accelerate

* Fixed issue where metric evaluation would hang on distributed mode

* Implemented functional saving
Loading still fails due to some issue with the optimizer

* Fixed issue with loading decoders

* Fixed issue with tracker config

* Fixed issue with amp
Updated logging to be more logical

* Saving checkpoint now saves position in training as well
Fixed an issue with running out of gpu space due to loading weights into the gpu twice

* Fixed ema for distributed training

* Fixed issue where get_pkg_version was reintroduced

* Changed decoder trainer to upload config as a file

Fixed issue where loading best would error
2022-06-19 09:25:54 -07:00
zion
fe19b508ca Distributed Training of the Prior (#112)
* distributed prior trainer

better EMA support

update load and save methods of prior

* update prior training script

add test evaluation & ema validation

add more tracking metrics

small cleanup
2022-06-19 08:46:14 -07:00
Phil Wang
934c9728dc some cleanup 2022-06-04 16:54:15 -07:00
zion
64c2f9c4eb implement ema warmup from @crowsonkb (#140) 2022-06-04 13:26:34 -07:00
zion
83517849e5 ema module fixes (#139) 2022-06-03 19:43:51 -07:00
Phil Wang
9cc475f6e7 fix update_every within EMA 2022-06-03 10:21:05 -07:00
Phil Wang
1ffeecd0ca lower default ema beta value 2022-05-31 11:55:21 -07:00
Phil Wang
b588286288 fix version 2022-05-30 11:06:34 -07:00
Phil Wang
857b9fbf1e allow for one to stop grouping out weight decayable parameters, to debug optimizer state dict problem 2022-05-24 21:42:32 -07:00
Phil Wang
ae42d03006 allow for saving of additional fields on save method in trainers, and return loaded objects from the load method 2022-05-22 22:14:25 -07:00
Phil Wang
501a8c7c46 small cleanup 2022-05-22 15:39:38 -07:00
Phil Wang
49de72040c fix decoder trainer optimizer loading (since there are multiple for each unet), also save and load step number correctly 2022-05-22 15:21:00 -07:00
Phil Wang
e527002472 take care of saving and loading functions on the diffusion prior and decoder training classes 2022-05-22 15:10:15 -07:00
Phil Wang
bb86ab2404 update sample, and set default gradient clipping value for decoder training 2022-05-16 17:38:30 -07:00
Phil Wang
c7ea8748db default decoder learning rate to what was in the paper 2022-05-16 13:33:54 -07:00
Phil Wang
13382885d9 final update to dalle2 repository for a while - sampling from prior in chunks automatically with max_batch_size keyword given 2022-05-16 12:57:31 -07:00
Phil Wang
164d9be444 use a decorator and take care of sampling in chunks (max_batch_size keyword), in case one is sampling a huge grid of images 2022-05-16 12:34:28 -07:00
Phil Wang
89ff04cfe2 final tweak to EMA class 2022-05-16 11:54:34 -07:00
Phil Wang
f4016f6302 allow for overriding use of EMA during sampling in decoder trainer with use_non_ema keyword, also fix some issues with automatic normalization of images and low res conditioning image if latent diffusion is in play 2022-05-16 11:18:30 -07:00
Phil Wang
dab106d4e5 back to no_grad for now, also keep track and restore unet devices in one_unet_in_gpu contextmanager 2022-05-16 09:36:14 -07:00
Phil Wang
bb151ca6b1 unet_number on decoder trainer only needs to be passed in if there is greater than 1 unet, so that unconditional training of a single ddpm is seamless (experiment in progress locally) 2022-05-16 09:17:17 -07:00
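
Note on commits 164d9be444 and 13382885d9: the chunked sampling they describe can be sketched roughly as below. This is a minimal illustration, not the repository's code. It assumes the batch is a plain Python list rather than a tensor; `sample_in_chunks`, `split_into_chunks`, and `fake_sample` are hypothetical names, and only the `max_batch_size` keyword comes from the commit messages.

```python
from functools import wraps


def split_into_chunks(items, chunk_size):
    """Split a list into consecutive pieces of at most chunk_size items."""
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]


def sample_in_chunks(fn):
    """Hypothetical decorator: run fn on the batch in chunks when
    max_batch_size is given, then join the per-chunk outputs."""
    @wraps(fn)
    def inner(batch, *args, max_batch_size=None, **kwargs):
        # Without the keyword, behave exactly like the wrapped function.
        if max_batch_size is None:
            return fn(batch, *args, **kwargs)

        outputs = []
        for chunk in split_into_chunks(batch, max_batch_size):
            outputs.extend(fn(chunk, *args, **kwargs))
        return outputs
    return inner


calls = []  # records the size of every chunk the sampler actually sees


@sample_in_chunks
def fake_sample(batch):
    """Stand-in for an expensive sampling call (assumption for this sketch)."""
    calls.append(len(batch))
    return [x * 2 for x in batch]


# Ten inputs with max_batch_size=4 are processed in three chunks.
result = fake_sample(list(range(10)), max_batch_size=4)
```

With tensors, the same idea would split along the batch dimension and concatenate the outputs, which keeps peak memory bounded when sampling a large grid of images.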