prepare for latent diffusion in the first DDPM of the cascade in the Decoder

This commit is contained in:
Phil Wang
2022-04-21 17:54:31 -07:00
parent 0b4ec34efb
commit ad17c69ab6
5 changed files with 538 additions and 16 deletions

View File

@@ -1095,6 +1095,7 @@ class Decoder(nn.Module):
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
self.image_sizes = image_sizes
+ self.sample_channels = cast_tuple(self.channels, len(image_sizes))
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
@@ -1272,9 +1273,9 @@ class Decoder(nn.Module):
img = None
- for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
+ for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
with self.one_unet_in_gpu(unet = unet):
- shape = (batch_size, channels, image_size, image_size)
+ shape = (batch_size, channel, image_size, image_size)
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
return img