From 1bd8a7835ab66f266daf11cff2aa56ed6a46a001 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 6 Jul 2022 08:27:34 -0700 Subject: [PATCH] attempting to fix issue with deepspeed fp16 seeing overflowing gradient --- dalle2_pytorch/dalle2_pytorch.py | 6 +++++- dalle2_pytorch/version.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 694483d..4301a54 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -335,6 +335,10 @@ def approx_standard_normal_cdf(x): def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999): assert x.shape == means.shape == log_scales.shape + # attempting to correct nan gradients when learned variance is turned on + # in the setting of deepspeed fp16 + eps = 1e-12 if x.dtype == torch.float32 else 1e-5 + centered_x = x - means inv_stdv = torch.exp(-log_scales) plus_in = inv_stdv * (centered_x + 1. / 255.) @@ -349,7 +353,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999): log_cdf_plus, torch.where(x > thres, log_one_minus_cdf_min, - log(cdf_delta))) + log(cdf_delta, eps = eps))) return log_probs diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index bc68d1e..029a258 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.9' +__version__ = '0.16.10'