mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
attempting to fix issue with deepspeed fp16 seeing overflowing gradient
This commit is contained in:
@@ -335,6 +335,10 @@ def approx_standard_normal_cdf(x):
|
||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||
assert x.shape == means.shape == log_scales.shape
|
||||
|
||||
# attempting to correct nan gradients when learned variance is turned on
|
||||
# in the setting of deepspeed fp16
|
||||
eps = 1e-12 if x.dtype == torch.float32 else 1e-5
|
||||
|
||||
centered_x = x - means
|
||||
inv_stdv = torch.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 1. / 255.)
|
||||
@@ -349,7 +353,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||
log_cdf_plus,
|
||||
torch.where(x > thres,
|
||||
log_one_minus_cdf_min,
|
||||
log(cdf_delta)))
|
||||
log(cdf_delta, eps = eps)))
|
||||
|
||||
return log_probs
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.16.9'
|
||||
__version__ = '0.16.10'
|
||||
|
||||
Reference in New Issue
Block a user