allow for using classifier free guidance for some unets but not others, by passing in a tuple of cond_scale during sampling for decoder, just in case it is causing issues for upsamplers

This commit is contained in:
Phil Wang
2022-07-13 13:12:30 -07:00
parent f988207718
commit f141144a6d
2 changed files with 6 additions and 3 deletions

View File

@@ -2498,7 +2498,10 @@ class Decoder(nn.Module):
img = None img = None
is_cuda = next(self.parameters()).is_cuda is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)): num_unets = len(self.unets)
cond_scale = cast_tuple(cond_scale, num_unets)
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps, cond_scale)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context() context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
@@ -2520,7 +2523,7 @@ class Decoder(nn.Module):
shape, shape,
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
cond_scale = cond_scale, cond_scale = unet_cond_scale,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
learned_variance = learned_variance, learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion, clip_denoised = not is_latent_diffusion,

View File

@@ -1 +1 @@
__version__ = '0.23.8' __version__ = '0.23.9'