From 6a59c7093df2451dd28edfab732b5ae86cdc82bb Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 6 Jul 2022 19:05:50 -0700
Subject: [PATCH] more shots in the dark regarding fp16 with learned variance
 for deepspeed issue

---
 dalle2_pytorch/dalle2_pytorch.py | 6 +++---
 dalle2_pytorch/version.py        | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 4301a54..878ab8a 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -337,7 +337,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
     # attempting to correct nan gradients when learned variance is turned on
     # in the setting of deepspeed fp16
 
-    eps = 1e-12 if x.dtype == torch.float32 else 1e-5
+    eps = 1e-12 if x.dtype == torch.float32 else 1e-3
 
     centered_x = x - means
     inv_stdv = torch.exp(-log_scales)
@@ -345,8 +345,8 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
     cdf_plus = approx_standard_normal_cdf(plus_in)
     min_in = inv_stdv * (centered_x - 1. / 255.)
     cdf_min = approx_standard_normal_cdf(min_in)
-    log_cdf_plus = log(cdf_plus)
-    log_one_minus_cdf_min = log(1. - cdf_min)
+    log_cdf_plus = log(cdf_plus, eps = eps)
+    log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
     cdf_delta = cdf_plus - cdf_min
 
     log_probs = torch.where(x < -thres,
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index bdc0bed..8e73b46 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.16.13'
+__version__ = '0.16.14'
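
Context for the changed lines, as a minimal sketch rather than code copied from the patch: the patched calls route through a `log` helper that takes an `eps` keyword, and the body shown below (clamping before `torch.log`) is an assumption for illustration of how that helper is typically written in this codebase.

```python
import torch

# assumed shape of the stability helper the patched lines call into:
# clamp the argument away from zero so torch.log never produces -inf
def log(t, eps = 1e-12):
    return torch.log(t.clamp(min = eps))

# the patched eps selection: a looser clamp when the tensor is not fp32 (e.g. fp16)
x = torch.rand(8).half()
eps = 1e-12 if x.dtype == torch.float32 else 1e-3
print(log(x, eps = eps))
```

One plausible reading of the eps bump from 1e-5 to 1e-3, in line with the commit's "shots in the dark" framing: the backward of `log` is `1 / t`, and fp16 cannot represent values above 65504, so a clamp at 1e-5 still permits gradients near 1e5 (which overflow to inf and then surface as NaN), while a clamp at 1e-3 caps them around 1e3, which fp16 can hold.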