mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
prepare for latent diffusion in the first DDPM of the cascade in the Decoder
This commit is contained in:
@@ -1095,6 +1095,7 @@ class Decoder(nn.Module):
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
self.image_sizes = image_sizes
|
||||
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
|
||||
|
||||
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
||||
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
||||
@@ -1272,9 +1273,9 @@ class Decoder(nn.Module):
|
||||
|
||||
img = None
|
||||
|
||||
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
|
||||
for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
|
||||
with self.one_unet_in_gpu(unet = unet):
|
||||
shape = (batch_size, channels, image_size, image_size)
|
||||
shape = (batch_size, channel, image_size, image_size)
|
||||
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
||||
|
||||
return img
|
||||
|
||||
Reference in New Issue
Block a user