Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
allow for an option to constrain the variance interpolation fraction coming out from the unet for learned variance, if it is turned on
@@ -1745,6 +1745,7 @@ class Decoder(BaseGaussianDiffusion):
         clip_x_start = True,
         clip_adapter_overrides = dict(),
         learned_variance = True,
+        learned_variance_constrain_frac = False,
         vb_loss_weight = 0.001,
         unconditional = False,
         auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
@@ -1805,6 +1806,7 @@ class Decoder(BaseGaussianDiffusion):
 
         learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
         self.learned_variance = learned_variance
+        self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
         self.vb_loss_weight = vb_loss_weight
 
         # construct unets and vaes
@@ -1945,6 +1947,9 @@ class Decoder(BaseGaussianDiffusion):
             max_log = extract(torch.log(self.betas), t, x.shape)
             var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
 
+            if self.learned_variance_constrain_frac:
+                var_interp_frac = var_interp_frac.sigmoid()
+
             posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
             posterior_variance = posterior_log_variance.exp()
 
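The hunk above is the substance of the change. With learned variance turned on, the unet predicts one extra value per pixel which, after being remapped from [-1, 1] to [0, 1] by unnormalize_zero_to_one, acts as an interpolation fraction between the two log-variance bounds of Nichol & Dhariwal's "Improved Denoising Diffusion Probabilistic Models": in the surrounding code, min_log is the (clipped) posterior log variance log(beta_tilde_t) and max_log is log(beta_t). The raw network output is unbounded, so even after the remap the fraction can fall outside [0, 1]; the new learned_variance_constrain_frac option squashes it with a sigmoid. Below is a minimal, self-contained sketch of that computation; interpolated_log_variance and its argument layout are hypothetical, not the repo's API, and note that, as in the diff, the sigmoid is applied after the remap.

import torch

def unnormalize_zero_to_one(t):
    # same remap the repo applies to the unet's extra output channel: [-1, 1] -> [0, 1]
    return (t + 1) * 0.5

def interpolated_log_variance(raw_frac, max_log, min_log, constrain_frac = False):
    # raw_frac: the unbounded extra value predicted by the unet
    frac = unnormalize_zero_to_one(raw_frac)
    if constrain_frac:
        # the new option: sigmoid forces the interpolation fraction into (0, 1)
        frac = frac.sigmoid()
    # interpolate in log space between max_log = log(beta_t) and min_log = log(beta_tilde_t)
    return frac * max_log + (1 - frac) * min_log

# toy check: even a wildly out-of-range prediction stays within the bounds once constrained
raw = torch.tensor([5., -5., 0.])
max_log, min_log = torch.log(torch.tensor(0.02)), torch.log(torch.tensor(0.01))
out = interpolated_log_variance(raw, max_log, min_log, constrain_frac = True)
assert ((out > min_log) & (out < max_log)).all()

Without the constraint, raw = 5. would remap to frac = 3. and extrapolate past max_log, producing a log variance larger than log(beta_t); that runaway behavior is presumably the failure mode this option guards against.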
@@ -1 +1 @@
-__version__ = '0.6.8'
+__version__ = '0.6.9'
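For completeness, a hedged sketch of opting into the new flag when constructing a Decoder, modeled on the repository's unconditional example. The Unet hyperparameters and the other constructor arguments are illustrative and may not match this exact version's signature; only learned_variance and learned_variance_constrain_frac come from this commit.

import torch
from dalle2_pytorch import Unet, Decoder

unet = Unet(
    dim = 16,            # illustrative, kept tiny for the sketch
    dim_mults = (1, 2)
)

decoder = Decoder(
    unet = unet,
    image_size = 64,
    timesteps = 100,
    unconditional = True,                    # no CLIP conditioning, per the signature above
    learned_variance = True,                 # unet also predicts the variance interpolation fraction
    learned_variance_constrain_frac = True   # new in 0.6.9: sigmoid-constrain that fraction
)

images = torch.randn(1, 3, 64, 64)
loss = decoder(images)   # diffusion loss (plus the weighted vb term when variance is learned)
loss.backward()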