add diffusion prior trainer, which automatically takes care of the exponential moving average (training and sampling), as well as mixed precision, gradient clipping

This commit is contained in:
Phil Wang
2022-05-06 08:11:09 -07:00
parent 878b555ef7
commit 98df1ba51e
5 changed files with 154 additions and 4 deletions

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP