diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 3e111a4..1a9b3e2 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -2498,7 +2498,10 @@ class Decoder(nn.Module):
         img = None
         is_cuda = next(self.parameters()).is_cuda
 
-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):
+        num_unets = len(self.unets)
+        cond_scale = cast_tuple(cond_scale, num_unets)
+
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps, cond_scale)):
 
             context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
 
@@ -2520,7 +2523,7 @@ class Decoder(nn.Module):
                     shape,
                     image_embed = image_embed,
                     text_encodings = text_encodings,
-                    cond_scale = cond_scale,
+                    cond_scale = unet_cond_scale,
                     predict_x_start = predict_x_start,
                     learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 3a1985b..d5d3ec6 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.23.8'
+__version__ = '0.23.9'
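
Note on the change above: Decoder.sample previously applied a single cond_scale to every unet in the cascade; this patch broadcasts the argument with cast_tuple so a caller may supply one classifier-free-guidance scale per unet. Below is a minimal sketch of the resulting behavior, assuming cast_tuple works as it does elsewhere in this codebase (a scalar is broadcast to a tuple of the requested length, a tuple is length-checked and passed through); the exact signature in dalle2_pytorch may differ, and the decoder.sample calls are illustrative only:

    def cast_tuple(val, length = 1):
        # broadcast a scalar to a tuple of the requested length; pass tuples through
        out = val if isinstance(val, tuple) else ((val,) * length)
        assert len(out) == length
        return out

    assert cast_tuple(3.5, 2) == (3.5, 3.5)         # one scale shared by both unets
    assert cast_tuple((3.5, 1.0), 2) == (3.5, 1.0)  # per-unet scales pass through unchanged

    # After this patch, both call forms are accepted (call shape illustrative):
    # images = decoder.sample(image_embed = image_embed, cond_scale = 3.5)         # same scale for every unet
    # images = decoder.sample(image_embed = image_embed, cond_scale = (3.5, 1.0))  # guide the base unet, leave the upsampler unguided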