From f545ce18f48ca98acdc62247f6c3c7e0b4cdaaf5 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 20 Jun 2022 09:43:31 -0700 Subject: [PATCH] be able to turn off p2 loss reweighting for upsamplers --- dalle2_pytorch/dalle2_pytorch.py | 7 +++++-- dalle2_pytorch/version.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index a0ed53d..fad79c2 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1866,14 +1866,17 @@ class Decoder(nn.Module): if not exists(beta_schedule): beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1))) + beta_schedule = cast_tuple(beta_schedule, num_unets) + p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets) + self.noise_schedulers = nn.ModuleList([]) - for unet_beta_schedule in beta_schedule: + for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma): noise_scheduler = NoiseScheduler( beta_schedule = unet_beta_schedule, timesteps = timesteps, loss_type = loss_type, - p2_loss_weight_gamma = p2_loss_weight_gamma, + p2_loss_weight_gamma = unet_p2_loss_weight_gamma, p2_loss_weight_k = p2_loss_weight_k ) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index ae4865c..2b3823f 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.11.1' +__version__ = '0.11.2'