From c3d4a7ffe4be4150a2353b718b165e1874c92766 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 16 May 2022 12:50:07 -0700 Subject: [PATCH] update working unconditional decoder example --- README.md | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index c62c818..69c94ff 100644 --- a/README.md +++ b/README.md @@ -867,7 +867,7 @@ ex. ```python import torch -from dalle2_pytorch import Unet, Decoder +from dalle2_pytorch import Unet, Decoder, DecoderTrainer # unet for the cascading ddpm @@ -890,20 +890,24 @@ decoder = Decoder( unconditional = True ).cuda() -# mock images (get a lot of this) +# decoder trainer + +decoder_trainer = DecoderTrainer(decoder) + +# images (get a lot of this) images = torch.randn(1, 3, 512, 512).cuda() # feed images into decoder for i in (1, 2): - loss = decoder(images, unet_number = i) - loss.backward() + loss = decoder_trainer(images, unet_number = i) + decoder_trainer.update(unet_number = i) -# do the above for many many many many steps +# do the above for many many many many images # then it will learn to generate images -images = decoder.sample(batch_size = 2) # (2, 3, 512, 512) +images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512) ``` ## Dataloaders