diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index a0ed53d..fad79c2 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1866,14 +1866,17 @@ class Decoder(nn.Module): if not exists(beta_schedule): beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1))) + beta_schedule = cast_tuple(beta_schedule, num_unets) + p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets) + self.noise_schedulers = nn.ModuleList([]) - for unet_beta_schedule in beta_schedule: + for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma): noise_scheduler = NoiseScheduler( beta_schedule = unet_beta_schedule, timesteps = timesteps, loss_type = loss_type, - p2_loss_weight_gamma = p2_loss_weight_gamma, + p2_loss_weight_gamma = unet_p2_loss_weight_gamma, p2_loss_weight_k = p2_loss_weight_k ) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index ae4865c..2b3823f 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.11.1' +__version__ = '0.11.2'