z
|
cd5f2c1de4
|
simulate unrelated captions as a training metric (#66)
* add unrelated embedding metric
* change to torch.roll
Co-authored-by: nousr <z@localhost.com>
Co-authored-by: nousr <>
|
2022-05-07 05:34:59 -07:00 |
|
Phil Wang
|
63029f7388
|
remove l2norm output from train_diffusion_prior.py
|
2022-05-05 19:07:58 -07:00 |
|
Kumar R
|
2d9963d30e
|
Reporting metrics - Cosine similarity. (#55)
* Update train_diffusion_prior.py
* Delete train_diffusion_prior.py
* Cosine similarity logging.
* Update train_diffusion_prior.py
* Report Cosine metrics every N steps.
|
2022-05-04 08:04:36 -07:00 |
|
Kumar R
|
72c16b496e
|
Update train_diffusion_prior.py (#53)
|
2022-05-02 22:44:57 -07:00 |
|
z
|
81d83dd7f2
|
defaults align with paper (#52)
Co-authored-by: nousr <>
|
2022-05-02 13:52:11 -07:00 |
|
Phil Wang
|
aa8d135245
|
allow laion to experiment with normformer in diffusion prior
|
2022-05-02 11:35:00 -07:00 |
|
Romain Beaumont
|
2d25c89f35
|
Fix passing of l2norm_output to DiffusionPriorNetwork (#51)
|
2022-05-02 10:48:16 -07:00 |
|
Phil Wang
|
3fe96c208a
|
add ability to train diffusion prior with l2norm on output image embed
|
2022-05-02 09:53:20 -07:00 |
|
Phil Wang
|
7ee0ecc388
|
mixed precision for training diffusion prior + save optimizer and scaler states
|
2022-05-02 09:31:04 -07:00 |
|
Phil Wang
|
f7df3caaf3
|
address not calculating average eval / test loss when training diffusion prior https://github.com/lucidrains/DALLE2-pytorch/issues/49
|
2022-05-02 08:51:41 -07:00 |
|
Phil Wang
|
d991b8c39c
|
just clip the diffusion prior network parameters
|
2022-05-01 12:01:08 -07:00 |
|
Phil Wang
|
35cd63982d
|
add gradient clipping, make sure weight decay is configurable, make sure learning rate is actually passed into get_optimizer, make sure model is set to training mode at beginning of each epoch
|
2022-05-01 11:55:38 -07:00 |
|
Kumar R
|
53ce6dfdf6
|
All changes implemented, current run happening. Link to wandb run in comments. (#43)
* Train DiffusionPrior with pre-computed embeddings
This is in response to https://github.com/lucidrains/DALLE2-pytorch/issues/29 - more metrics will get added.
|
2022-05-01 11:46:59 -07:00 |
|