mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
fix for small validation bug for sampling steps
This commit is contained in:
@@ -2047,7 +2047,7 @@ class Decoder(nn.Module):
|
|||||||
self.noise_schedulers = nn.ModuleList([])
|
self.noise_schedulers = nn.ModuleList([])
|
||||||
|
|
||||||
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
|
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
|
||||||
assert sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
|
assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
|
||||||
|
|
||||||
noise_scheduler = NoiseScheduler(
|
noise_scheduler = NoiseScheduler(
|
||||||
beta_schedule = unet_beta_schedule,
|
beta_schedule = unet_beta_schedule,
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.19.2'
|
__version__ = '0.19.3'
|
||||||
|
|||||||
Reference in New Issue
Block a user