mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 01:34:19 +01:00
more shots in the dark regarding fp16 with learned variance for deepspeed issue
This commit is contained in:
@@ -337,7 +337,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
|||||||
|
|
||||||
# attempting to correct nan gradients when learned variance is turned on
|
# attempting to correct nan gradients when learned variance is turned on
|
||||||
# in the setting of deepspeed fp16
|
# in the setting of deepspeed fp16
|
||||||
eps = 1e-12 if x.dtype == torch.float32 else 1e-5
|
eps = 1e-12 if x.dtype == torch.float32 else 1e-3
|
||||||
|
|
||||||
centered_x = x - means
|
centered_x = x - means
|
||||||
inv_stdv = torch.exp(-log_scales)
|
inv_stdv = torch.exp(-log_scales)
|
||||||
@@ -345,8 +345,8 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
|||||||
cdf_plus = approx_standard_normal_cdf(plus_in)
|
cdf_plus = approx_standard_normal_cdf(plus_in)
|
||||||
min_in = inv_stdv * (centered_x - 1. / 255.)
|
min_in = inv_stdv * (centered_x - 1. / 255.)
|
||||||
cdf_min = approx_standard_normal_cdf(min_in)
|
cdf_min = approx_standard_normal_cdf(min_in)
|
||||||
log_cdf_plus = log(cdf_plus)
|
log_cdf_plus = log(cdf_plus, eps = eps)
|
||||||
log_one_minus_cdf_min = log(1. - cdf_min)
|
log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
|
||||||
cdf_delta = cdf_plus - cdf_min
|
cdf_delta = cdf_plus - cdf_min
|
||||||
|
|
||||||
log_probs = torch.where(x < -thres,
|
log_probs = torch.where(x < -thres,
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.16.13'
|
__version__ = '0.16.14'
|
||||||
|
|||||||
Reference in New Issue
Block a user