mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 20:54:23 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf95d37e98 |
@@ -704,7 +704,7 @@ class Attention(nn.Module):
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
sim = sim * self.pb_relax_alpha
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
@@ -1272,7 +1272,7 @@ class CrossAttention(nn.Module):
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
sim = sim * self.pb_relax_alpha
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
@@ -550,7 +550,7 @@ class DecoderTrainer(nn.Module):
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
|
||||
for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
|
||||
for ind, last_step in zip(range(0, self.num_unets), self.steps.cpu().unbind()):
|
||||
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.16.7'
|
||||
__version__ = '0.16.5'
|
||||
|
||||
Reference in New Issue
Block a user