allows one to shortcut sampling at a specific unet number, in case one is training in stages

Phil Wang
2022-04-30 16:05:13 -07:00
parent ebe01749ed
commit d1a697ac23
3 changed files with 13 additions and 4 deletions


@@ -1540,7 +1540,13 @@ class Decoder(BaseGaussianDiffusion):
     @torch.no_grad()
     @eval_decorator
-    def sample(self, image_embed, text = None, cond_scale = 1.):
+    def sample(
+        self,
+        image_embed,
+        text = None,
+        cond_scale = 1.,
+        stop_at_unet_number = None
+    ):
         batch_size = image_embed.shape[0]
 
         text_encodings = text_mask = None
@@ -1552,7 +1558,7 @@ class Decoder(BaseGaussianDiffusion):
         img = None
 
-        for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
             context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
@@ -1584,6 +1590,9 @@ class Decoder(BaseGaussianDiffusion):
                 img = vae.decode(img)
 
+            if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
+                break
+
         return img
 
     def forward(
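
For context, a minimal usage sketch of the new parameter. This is hedged: decoder and image_embed are assumed to be set up as in the project README (a Decoder wired with a cascade of two unets); only sample and stop_at_unet_number come from this diff, everything else is illustrative.

# assuming `decoder` is a Decoder built with two unets, per the README,
# and `image_embed` is a batch of CLIP image embeddings of shape (batch, dim)

# before this commit: sample always runs the full cascade, unet 1 then unet 2
full_images = decoder.sample(image_embed)

# after this commit: stop once unet 1 has sampled (and its vae has decoded),
# skipping unet 2 entirely - useful while a later-stage unet is still untrained
stage_one_images = decoder.sample(image_embed, stop_at_unet_number = 1)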