Commit Graph

24 Commits

Author SHA1 Message Date
Phil Wang
56883910fb cleanup 2022-06-20 11:14:55 -07:00
Phil Wang
67f0740777 small cleanup 2022-06-20 08:59:51 -07:00
Phil Wang
138079ca83 allow for setting beta schedules of unets differently in the decoder, as what was used in the paper was cosine, cosine, linear 2022-06-20 08:56:37 -07:00
Aidan Dempster
58892135d9 Distributed Training of the Decoder (#121)
* Converted decoder trainer to use accelerate

* Fixed issue where metric evaluation would hang on distributed mode

* Implemented functional saving
Loading still fails due to some issue with the optimizer

* Fixed issue with loading decoders

* Fixed issue with tracker config

* Fixed issue with amp
Updated logging to be more logical

* Saving checkpoint now saves position in training as well
Fixed an issue with running out of gpu space due to loading weights into the gpu twice

* Fixed ema for distributed training

* Fixed isue where get_pkg_version was reintroduced

* Changed decoder trainer to upload config as a file

Fixed issue where loading best would error
2022-06-19 09:25:54 -07:00
zion
fe19b508ca Distributed Training of the Prior (#112)
* distributed prior trainer

better EMA support

update load and save methods of prior

* update prior training script

add test evalution & ema validation

add more tracking metrics

small cleanup
2022-06-19 08:46:14 -07:00
Phil Wang
934c9728dc some cleanup 2022-06-04 16:54:15 -07:00
zion
64c2f9c4eb implement ema warmup from @crowsonkb (#140) 2022-06-04 13:26:34 -07:00
zion
83517849e5 ema module fixes (#139) 2022-06-03 19:43:51 -07:00
Phil Wang
9cc475f6e7 fix update_every within EMA 2022-06-03 10:21:05 -07:00
Phil Wang
1ffeecd0ca lower default ema beta value 2022-05-31 11:55:21 -07:00
Phil Wang
b588286288 fix version 2022-05-30 11:06:34 -07:00
Phil Wang
857b9fbf1e allow for one to stop grouping out weight decayable parameters, to debug optimizer state dict problem 2022-05-24 21:42:32 -07:00
Phil Wang
ae42d03006 allow for saving of additional fields on save method in trainers, and return loaded objects from the load method 2022-05-22 22:14:25 -07:00
Phil Wang
501a8c7c46 small cleanup 2022-05-22 15:39:38 -07:00
Phil Wang
49de72040c fix decoder trainer optimizer loading (since there are multiple for each unet), also save and load step number correctly 2022-05-22 15:21:00 -07:00
Phil Wang
e527002472 take care of saving and loading functions on the diffusion prior and decoder training classes 2022-05-22 15:10:15 -07:00
Phil Wang
bb86ab2404 update sample, and set default gradient clipping value for decoder training 2022-05-16 17:38:30 -07:00
Phil Wang
c7ea8748db default decoder learning rate to what was in the paper 2022-05-16 13:33:54 -07:00
Phil Wang
13382885d9 final update to dalle2 repository for a while - sampling from prior in chunks automatically with max_batch_size keyword given 2022-05-16 12:57:31 -07:00
Phil Wang
164d9be444 use a decorator and take care of sampling in chunks (max_batch_size keyword), in case one is sampling a huge grid of images 2022-05-16 12:34:28 -07:00
Phil Wang
89ff04cfe2 final tweak to EMA class 2022-05-16 11:54:34 -07:00
Phil Wang
f4016f6302 allow for overriding use of EMA during sampling in decoder trainer with use_non_ema keyword, also fix some issues with automatic normalization of images and low res conditioning image if latent diffusion is in play 2022-05-16 11:18:30 -07:00
Phil Wang
dab106d4e5 back to no_grad for now, also keep track and restore unet devices in one_unet_in_gpu contextmanager 2022-05-16 09:36:14 -07:00
Phil Wang
bb151ca6b1 unet_number on decoder trainer only needs to be passed in if there is greater than 1 unet, so that unconditional training of a single ddpm is seamless (experiment in progress locally) 2022-05-16 09:17:17 -07:00