Compare commits

...

3 Commits

3 changed files with 26 additions and 12 deletions

View File

@@ -335,6 +335,10 @@ def approx_standard_normal_cdf(x):
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
assert x.shape == means.shape == log_scales.shape
# attempting to correct nan gradients when learned variance is turned on
# in the setting of deepspeed fp16
eps = 1e-12 if x.dtype == torch.float32 else 1e-5
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1. / 255.)
@@ -349,7 +353,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
log_cdf_plus,
torch.where(x > thres,
log_one_minus_cdf_min,
log(cdf_delta)))
log(cdf_delta, eps = eps)))
return log_probs
@@ -1127,11 +1131,12 @@ class SinusoidalPosEmb(nn.Module):
self.dim = dim
def forward(self, x):
dtype, device = x.dtype, x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
emb = rearrange(x.type_as(emb), 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
class Block(nn.Module):
def __init__(
@@ -1626,6 +1631,7 @@ class Unet(nn.Module):
# time conditioning
time = time.type_as(x)
time_hiddens = self.to_time_hiddens(time)
time_tokens = self.to_time_tokens(time_hiddens)

View File

@@ -173,14 +173,26 @@ class DiffusionPriorTrainer(nn.Module):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
# verbosity
self.verbose = verbose
# assign some helpful member vars
self.accelerator = accelerator
self.device = accelerator.device if exists(accelerator) else device
self.text_conditioned = diffusion_prior.condition_on_text_encodings
# setting the device
if not exists(accelerator) and not exists(device):
diffusion_prior_device = next(diffusion_prior.parameters()).device
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
self.device = diffusion_prior_device
else:
self.device = accelerator.device if exists(accelerator) else device
# save model
self.diffusion_prior = diffusion_prior
@@ -214,13 +226,9 @@ class DiffusionPriorTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
# verbosity
self.verbose = verbose
# track steps internally
self.register_buffer('step', torch.tensor([0]))
self.register_buffer('step', torch.tensor([0], device = device))
# accelerator wrappers

View File

@@ -1 +1 @@
__version__ = '0.16.8'
__version__ = '0.16.11'