diff --git a/README.md b/README.md index c62c818..69c94ff 100644 --- a/README.md +++ b/README.md @@ -867,7 +867,7 @@ ex. ```python import torch -from dalle2_pytorch import Unet, Decoder +from dalle2_pytorch import Unet, Decoder, DecoderTrainer # unet for the cascading ddpm @@ -890,20 +890,24 @@ decoder = Decoder( unconditional = True ).cuda() -# mock images (get a lot of this) +# decoder trainer + +decoder_trainer = DecoderTrainer(decoder) + +# images (get a lot of this) images = torch.randn(1, 3, 512, 512).cuda() # feed images into decoder for i in (1, 2): - loss = decoder(images, unet_number = i) - loss.backward() + loss = decoder_trainer(images, unet_number = i) + decoder_trainer.update(unet_number = i) -# do the above for many many many many steps +# do the above for many many many many images # then it will learn to generate images -images = decoder.sample(batch_size = 2) # (2, 3, 512, 512) +images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512) ``` ## Dataloaders