Compare commits

..

1 Commits

Author SHA1 Message Date
Phil Wang
cf95d37e98 set ability to do warmup steps for each unet during training 2022-07-05 16:20:49 -07:00
3 changed files with 4 additions and 4 deletions

View File

@@ -704,7 +704,7 @@ class Attention(nn.Module):
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.dropout(attn)
# aggregate values
@@ -1272,7 +1272,7 @@ class CrossAttention(nn.Module):
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

View File

@@ -550,7 +550,7 @@ class DecoderTrainer(nn.Module):
if only_model:
return loaded_obj
for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
for ind, last_step in zip(range(0, self.num_unets), self.steps.cpu().unbind()):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)

View File

@@ -1 +1 @@
__version__ = '0.16.7'
__version__ = '0.16.5'