Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-18 17:34:18 +01:00)
more shots in the dark regarding fp16 with learned variance for deepspeed issue
This commit is contained in:
@@ -337,7 +337,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
 
     # attempting to correct nan gradients when learned variance is turned on
     # in the setting of deepspeed fp16
-    eps = 1e-12 if x.dtype == torch.float32 else 1e-5
+    eps = 1e-12 if x.dtype == torch.float32 else 1e-3
 
     centered_x = x - means
     inv_stdv = torch.exp(-log_scales)
@@ -345,8 +345,8 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
     cdf_plus = approx_standard_normal_cdf(plus_in)
     min_in = inv_stdv * (centered_x - 1. / 255.)
     cdf_min = approx_standard_normal_cdf(min_in)
-    log_cdf_plus = log(cdf_plus)
-    log_one_minus_cdf_min = log(1. - cdf_min)
+    log_cdf_plus = log(cdf_plus, eps = eps)
+    log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
     cdf_delta = cdf_plus - cdf_min
 
     log_probs = torch.where(x < -thres,
@@ -1 +1 @@
-__version__ = '0.16.13'
+__version__ = '0.16.14'
Reference in New Issue
Block a user