update working unconditional decoder example

This commit is contained in:
Phil Wang
2022-05-16 12:50:07 -07:00
parent 164d9be444
commit c3d4a7ffe4

View File

@@ -867,7 +867,7 @@ ex.
```python ```python
import torch import torch
from dalle2_pytorch import Unet, Decoder from dalle2_pytorch import Unet, Decoder, DecoderTrainer
# unet for the cascading ddpm # unet for the cascading ddpm
@@ -890,20 +890,24 @@ decoder = Decoder(
unconditional = True unconditional = True
).cuda() ).cuda()
# mock images (get a lot of this) # decoder trainer
decoder_trainer = DecoderTrainer(decoder)
# images (get a lot of this)
images = torch.randn(1, 3, 512, 512).cuda() images = torch.randn(1, 3, 512, 512).cuda()
# feed images into decoder # feed images into decoder
for i in (1, 2): for i in (1, 2):
loss = decoder(images, unet_number = i) loss = decoder_trainer(images, unet_number = i)
loss.backward() decoder_trainer.update(unet_number = i)
# do the above for many many many many steps # do the above for many many many many images
# then it will learn to generate images # then it will learn to generate images
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512) images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
``` ```
## Dataloaders ## Dataloaders