mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 12:04:24 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6a59c7093d | ||
|
|
a6cdbe0b9c | ||
|
|
e928ae5c34 |
@@ -337,7 +337,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||
|
||||
# attempting to correct nan gradients when learned variance is turned on
|
||||
# in the setting of deepspeed fp16
|
||||
eps = 1e-12 if x.dtype == torch.float32 else 1e-5
|
||||
eps = 1e-12 if x.dtype == torch.float32 else 1e-3
|
||||
|
||||
centered_x = x - means
|
||||
inv_stdv = torch.exp(-log_scales)
|
||||
@@ -345,8 +345,8 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||
cdf_plus = approx_standard_normal_cdf(plus_in)
|
||||
min_in = inv_stdv * (centered_x - 1. / 255.)
|
||||
cdf_min = approx_standard_normal_cdf(min_in)
|
||||
log_cdf_plus = log(cdf_plus)
|
||||
log_one_minus_cdf_min = log(1. - cdf_min)
|
||||
log_cdf_plus = log(cdf_plus, eps = eps)
|
||||
log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
|
||||
log_probs = torch.where(x < -thres,
|
||||
|
||||
@@ -173,14 +173,26 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
|
||||
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
# verbosity
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
# assign some helpful member vars
|
||||
|
||||
self.accelerator = accelerator
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
self.text_conditioned = diffusion_prior.condition_on_text_encodings
|
||||
|
||||
# setting the device
|
||||
|
||||
if not exists(accelerator) and not exists(device):
|
||||
diffusion_prior_device = next(diffusion_prior.parameters()).device
|
||||
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
|
||||
self.device = diffusion_prior_device
|
||||
else:
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
|
||||
# save model
|
||||
|
||||
self.diffusion_prior = diffusion_prior
|
||||
@@ -214,13 +226,9 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
# verbosity
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
# track steps internally
|
||||
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
self.register_buffer('step', torch.tensor([0], device = self.device))
|
||||
|
||||
# accelerator wrappers
|
||||
|
||||
@@ -465,7 +473,7 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
|
||||
|
||||
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
optimizers = []
|
||||
schedulers = []
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.16.10'
|
||||
__version__ = '0.16.14'
|
||||
|
||||
Reference in New Issue
Block a user