@@ -335,6 +335,10 @@ def approx_standard_normal_cdf(x):
 def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
     assert x.shape == means.shape == log_scales.shape
 
+    # attempting to correct nan gradients when learned variance is turned on
+    # in the setting of deepspeed fp16
+    eps = 1e-12 if x.dtype == torch.float32 else 1e-5
+
     centered_x = x - means
     inv_stdv = torch.exp(-log_scales)
     plus_in = inv_stdv * (centered_x + 1. / 255.)
@@ -349,7 +353,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
         log_cdf_plus,
         torch.where(x > thres,
             log_one_minus_cdf_min,
-            log(cdf_delta)))
+            log(cdf_delta, eps = eps)))
 
     return log_probs
 
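The two hunks above hinge on the clamped-log helper accepting a larger eps under fp16. A minimal sketch of the failure mode, assuming the helper has the usual `log(t, eps) = torch.log(t.clamp(min = eps))` shape (its real definition sits elsewhere in this file and may differ in detail):

```python
import torch

# assumed shape of the clamped-log helper that the hunk now calls with eps = eps;
# the actual definition lives elsewhere in the file
def log(t, eps = 1e-12):
    return torch.log(t.clamp(min = eps))

# under fp16, cdf_plus and cdf_min can round to the same value, so cdf_delta becomes 0
cdf_delta = torch.tensor([0.0], dtype = torch.float16)

print(log(cdf_delta))               # -inf: 1e-12 underflows to 0 in half precision
print(log(cdf_delta, eps = 1e-5))   # ~-11.5: 1e-5 is representable, so the log stays finite
```

Keeping the log finite here is what stops the nan gradients mentioned in the comment above.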
@@ -704,7 +708,7 @@ class Attention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         sim = sim * self.pb_relax_alpha
 
-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
         # aggregate values
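This hunk (and the identical CrossAttention hunk further down) drops the explicit float32 softmax. A small sketch of the dtype behaviour in plain PyTorch, outside this repo: the `dtype = torch.float32` argument upcasts the attention matrix, which then no longer matches the fp16 tensors the deepspeed fp16 engine hands to the rest of the layer.

```python
import torch

# attention logits roughly as they would look under fp16 training
sim = torch.randn(1, 8, 16, 16).half()

attn_old = sim.softmax(dim = -1, dtype = torch.float32)  # upcasts and returns float32
attn_new = sim.softmax(dim = -1)                          # stays in the input dtype

print(attn_old.dtype, attn_new.dtype)  # torch.float32 torch.float16
```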
@@ -1127,11 +1131,12 @@ class SinusoidalPosEmb(nn.Module):
         self.dim = dim
 
     def forward(self, x):
+        dtype, device = x.dtype, x.device
         half_dim = self.dim // 2
         emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
+        emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
         emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
-        return torch.cat((emb.sin(), emb.cos()), dim = -1)
+        return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
 
 class Block(nn.Module):
     def __init__(
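A quick sketch, again plain PyTorch rather than anything from this repo, of the silent upcast the SinusoidalPosEmb hunk removes: frequencies built in float32 promote a half-precision timestep tensor back to float32, whereas threading `dtype` through `arange` keeps the embedding in the input dtype.

```python
import math
import torch

t = torch.ones(4).half()                                              # timesteps under fp16 training
freqs = torch.exp(torch.arange(8).float() * -(math.log(10000) / 7))   # float32 frequencies (old path)

emb_old = t[:, None] * freqs[None, :]          # fp16 * fp32 promotes to float32
emb_new = t[:, None] * freqs[None, :].half()   # everything kept in half, as the hunk now does

print(emb_old.dtype, emb_new.dtype)  # torch.float32 torch.float16
```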
@@ -1272,7 +1277,7 @@ class CrossAttention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         sim = sim * self.pb_relax_alpha
 
-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1626,6 +1631,7 @@ class Unet(nn.Module):
 
         # time conditioning
 
+        time = time.type_as(x)
         time_hiddens = self.to_time_hiddens(time)
 
         time_tokens = self.to_time_tokens(time_hiddens)
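The Unet hunk only casts the incoming timesteps to the dtype of the image tensor, so the time-conditioning path sees fp16 inputs when the rest of the network runs in fp16. A trivial check of `type_as`:

```python
import torch

x = torch.zeros(2, 3, 64, 64).half()   # unet input under fp16 training
time = torch.tensor([3., 11.])          # timesteps typically arrive as float32

print(time.dtype, time.type_as(x).dtype)  # torch.float32 torch.float16
```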