From ffd342e9d06acf3d28165609fe927e2f6be6498a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 3 Jun 2022 09:34:57 -0700 Subject: [PATCH] allow for an option to constrain the variance interpolation fraction coming out from the unet for learned variance, if it is turned on --- dalle2_pytorch/dalle2_pytorch.py | 5 +++++ dalle2_pytorch/version.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 952c9f2..ca3eba7 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1745,6 +1745,7 @@ class Decoder(BaseGaussianDiffusion): clip_x_start = True, clip_adapter_overrides = dict(), learned_variance = True, + learned_variance_constrain_frac = False, vb_loss_weight = 0.001, unconditional = False, auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader @@ -1805,6 +1806,7 @@ class Decoder(BaseGaussianDiffusion): learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False) self.learned_variance = learned_variance + self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1 self.vb_loss_weight = vb_loss_weight # construct unets and vaes @@ -1945,6 +1947,9 @@ class Decoder(BaseGaussianDiffusion): max_log = extract(torch.log(self.betas), t, x.shape) var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized) + if self.learned_variance_constrain_frac: + var_interp_frac = var_interp_frac.sigmoid() + posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log posterior_variance = posterior_log_variance.exp() diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 085bc85..6300a70 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.6.8' +__version__ = '0.6.9'