mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
allows one to shortcut sampling at a specific unet number, if one were to be training in stages
This commit is contained in:
@@ -783,7 +783,7 @@ for unet_number in (1, 2):
|
|||||||
# you can sample from the exponentially moving averaged unets as so
|
# you can sample from the exponentially moving averaged unets as so
|
||||||
|
|
||||||
mock_image_embed = torch.randn(4, 512).cuda()
|
mock_image_embed = torch.randn(4, 512).cuda()
|
||||||
images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||||
```
|
```
|
||||||
|
|
||||||
## CLI (wip)
|
## CLI (wip)
|
||||||
|
|||||||
@@ -1540,7 +1540,13 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@eval_decorator
|
@eval_decorator
|
||||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
def sample(
|
||||||
|
self,
|
||||||
|
image_embed,
|
||||||
|
text = None,
|
||||||
|
cond_scale = 1.,
|
||||||
|
stop_at_unet_number = None
|
||||||
|
):
|
||||||
batch_size = image_embed.shape[0]
|
batch_size = image_embed.shape[0]
|
||||||
|
|
||||||
text_encodings = text_mask = None
|
text_encodings = text_mask = None
|
||||||
@@ -1552,7 +1558,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
img = None
|
img = None
|
||||||
|
|
||||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||||
|
|
||||||
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
||||||
|
|
||||||
@@ -1584,6 +1590,9 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
img = vae.decode(img)
|
img = vae.decode(img)
|
||||||
|
|
||||||
|
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
|
||||||
|
break
|
||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Reference in New Issue
Block a user