mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 02:04:19 +01:00
update working unconditional decoder example
This commit is contained in:
16
README.md
16
README.md
@@ -867,7 +867,7 @@ ex.
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
from dalle2_pytorch import Unet, Decoder
|
from dalle2_pytorch import Unet, Decoder, DecoderTrainer
|
||||||
|
|
||||||
# unet for the cascading ddpm
|
# unet for the cascading ddpm
|
||||||
|
|
||||||
@@ -890,20 +890,24 @@ decoder = Decoder(
|
|||||||
unconditional = True
|
unconditional = True
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
# mock images (get a lot of this)
|
# decoder trainer
|
||||||
|
|
||||||
|
decoder_trainer = DecoderTrainer(decoder)
|
||||||
|
|
||||||
|
# images (get a lot of this)
|
||||||
|
|
||||||
images = torch.randn(1, 3, 512, 512).cuda()
|
images = torch.randn(1, 3, 512, 512).cuda()
|
||||||
|
|
||||||
# feed images into decoder
|
# feed images into decoder
|
||||||
|
|
||||||
for i in (1, 2):
|
for i in (1, 2):
|
||||||
loss = decoder(images, unet_number = i)
|
loss = decoder_trainer(images, unet_number = i)
|
||||||
loss.backward()
|
decoder_trainer.update(unet_number = i)
|
||||||
|
|
||||||
# do the above for many many many many steps
|
# do the above for many many many many images
|
||||||
# then it will learn to generate images
|
# then it will learn to generate images
|
||||||
|
|
||||||
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
|
images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Dataloaders
|
## Dataloaders
|
||||||
|
|||||||
Reference in New Issue
Block a user