attempt to fix overflowing gradients seen when training with deepspeed fp16

This commit is contained in:
Phil Wang
2022-07-06 08:27:34 -07:00
parent f33453df9f
commit 1bd8a7835a
2 changed files with 6 additions and 2 deletions

View File

@@ -335,6 +335,10 @@ def approx_standard_normal_cdf(x):
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
assert x.shape == means.shape == log_scales.shape
# attempt to prevent nan gradients when learned variance is turned on under deepspeed fp16:
# use a larger epsilon for non-float32 inputs, since 1e-12 underflows to zero in half precision
eps = 1e-12 if x.dtype == torch.float32 else 1e-5
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1. / 255.)
@@ -349,7 +353,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
log_cdf_plus,
torch.where(x > thres,
log_one_minus_cdf_min,
log(cdf_delta)))
log(cdf_delta, eps = eps)))
return log_probs