allows one to shortcut sampling at a specific unet number, if one were to be training in stages

This commit is contained in:
Phil Wang
2022-04-30 16:05:13 -07:00
parent ebe01749ed
commit d1a697ac23
3 changed files with 13 additions and 4 deletions

View File

@@ -783,7 +783,7 @@ for unet_number in (1, 2):
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
## CLI (wip)