Compare commits


5 Commits

Author  SHA1  Message  Date
Phil Wang  1bd8a7835a  attempting to fix issue with deepspeed fp16 seeing overflowing gradient  2022-07-06 08:27:34 -07:00
Phil Wang  f33453df9f  debugging with Aidan  2022-07-05 18:22:43 -07:00
Phil Wang  1e4bb2bafb  cast long as float before deriving sinusoidal pos emb  2022-07-05 18:01:22 -07:00
Phil Wang  ee75515c7d  remove forcing of softmax in f32, in case it is interfering with deepspeed  2022-07-05 16:53:58 -07:00
Phil Wang  ec68243479  set ability to do warmup steps for each unet during training  2022-07-05 16:24:16 -07:00
3 changed files with 13 additions and 7 deletions

View File

@@ -335,6 +335,10 @@ def approx_standard_normal_cdf(x):
 def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
     assert x.shape == means.shape == log_scales.shape

+    # attempting to correct nan gradients when learned variance is turned on
+    # in the setting of deepspeed fp16
+    eps = 1e-12 if x.dtype == torch.float32 else 1e-5
+
     centered_x = x - means
     inv_stdv = torch.exp(-log_scales)
     plus_in = inv_stdv * (centered_x + 1. / 255.)
@@ -349,7 +353,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
         log_cdf_plus,
         torch.where(x > thres,
             log_one_minus_cdf_min,
-            log(cdf_delta)))
+            log(cdf_delta, eps = eps)))

     return log_probs

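
The dtype-dependent eps introduced above exists because the 1e-12 used for float32 inputs is itself not representable in float16, so under deepspeed fp16 it no longer guards against log(0) and the backward pass produces nans. A minimal sketch of the pattern, assuming (this is not shown in the diff) that the repository's log helper clamps its input before taking the log:

import torch

def log(t, eps = 1e-12):
    # hypothetical stand-in for the repo's log helper: clamp first so values
    # that underflow to zero never reach torch.log as exact zeros
    return torch.log(t.clamp(min = eps))

# 1e-12 rounds to 0 in float16, making the clamp a no-op there; choosing 1e-5
# for non-float32 inputs keeps log() finite and the gradients free of nans
cdf_delta = torch.tensor([0.0, 3e-4], dtype = torch.float16)  # made-up example values
eps = 1e-12 if cdf_delta.dtype == torch.float32 else 1e-5
print(log(cdf_delta.float(), eps = eps))  # upcast only so the example runs on any CPU build
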
@@ -704,7 +708,7 @@ class Attention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         sim = sim * self.pb_relax_alpha

-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)

         # aggregate values
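
The commit message only suspects the forced float32 softmax of interfering with deepspeed fp16, so the change simply keeps the attention weights in whatever precision the logits already have (the same change is applied in CrossAttention below). The two calls differ only in output dtype, illustrated here with a made-up logits tensor:

import torch

sim = torch.randn(1, 8, 16, 16, dtype = torch.float16)  # stand-in attention logits

attn_forced = sim.softmax(dim = -1, dtype = torch.float32)  # old: input is upcast, result is float32
attn_native = sim.softmax(dim = -1)                         # new: result stays float16
# (float16 softmax on CPU needs a reasonably recent PyTorch; on GPU it is routine)

print(attn_forced.dtype, attn_native.dtype)  # torch.float32 torch.float16
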
@@ -1127,11 +1131,12 @@ class SinusoidalPosEmb(nn.Module):
         self.dim = dim

     def forward(self, x):
+        dtype, device = x.dtype, x.device
         half_dim = self.dim // 2
         emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
+        emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
         emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
-        return torch.cat((emb.sin(), emb.cos()), dim = -1)
+        return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)

 class Block(nn.Module):
     def __init__(
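
With the timesteps arriving as floats (see the Unet hunk further down) and the arange now created in that same dtype, the embedding tracks the input precision end to end instead of silently coming back as float32. A standalone check, mirroring the patched module above purely for illustration:

import math
import torch
from torch import nn
from einops import rearrange

class SinusoidalPosEmb(nn.Module):
    # copy of the patched module above, kept self-contained for this sketch
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        dtype, device = x.dtype, x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)

pos_emb = SinusoidalPosEmb(16)
times = torch.arange(4, dtype = torch.float16)  # half-precision ops on CPU need a recent PyTorch
print(pos_emb(times).dtype)  # torch.float16; the pre-patch module would return float32 here
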
@@ -1272,7 +1277,7 @@ class CrossAttention(nn.Module):
         sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         sim = sim * self.pb_relax_alpha

-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1)

         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1626,6 +1631,7 @@ class Unet(nn.Module):
         # time conditioning

+        time = time.type_as(x)
         time_hiddens = self.to_time_hiddens(time)
         time_tokens = self.to_time_tokens(time_hiddens)
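
Casting the timesteps with type_as before they reach the time MLP is what lets the dtype-aware positional embedding above see floats rather than int64 indices. A small illustration with stand-in tensors (not the Unet's actual inputs):

import torch

x = torch.randn(4, 3, 64, 64, dtype = torch.float16)  # stand-in for the image batch
time = torch.randint(0, 1000, (4,))                    # sampled timesteps arrive as int64

time = time.type_as(x)
print(time.dtype)  # torch.float16, matching the image batch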

View File

@@ -550,7 +550,7 @@ class DecoderTrainer(nn.Module):
         if only_model:
             return loaded_obj

-        for ind, last_step in zip(range(0, self.num_unets), self.steps.cpu().unbind()):
+        for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
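
self.steps holds one step counter per unet as a tensor; .tolist() hands back plain Python ints, whereas the previous .cpu().unbind() produced 0-dim tensors, an awkward type for step bookkeeping such as the per-unet warmup added in this range. The difference, with a made-up steps tensor:

import torch

num_unets = 3
steps = torch.tensor([120, 0, 0])  # stand-in for self.steps

for ind, last_step in zip(range(0, num_unets), steps.tolist()):
    print(ind, last_step, type(last_step))   # plain ints, e.g. 0 120 <class 'int'>

for ind, last_step in zip(range(0, num_unets), steps.cpu().unbind()):
    print(ind, last_step, type(last_step))   # 0-dim tensors under the old code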

View File

@@ -1 +1 @@
-__version__ = '0.16.5'
+__version__ = '0.16.10'