fix for small validation bug for sampling steps

This commit is contained in:
Phil Wang
2022-07-09 17:31:54 -07:00
parent 47ae17b36e
commit 4878762627
2 changed files with 2 additions and 2 deletions

View File

@@ -2047,7 +2047,7 @@ class Decoder(nn.Module):
self.noise_schedulers = nn.ModuleList([])
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
assert sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,