Compare commits

...

2 Commits

Author SHA1 Message Date
Phil Wang
461347c171 fix vqgan-vae for latent diffusion 2022-04-22 11:38:57 -07:00
Phil Wang
46cef31c86 optional projection out for prior network causal transformer 2022-04-22 11:16:30 -07:00
3 changed files with 10 additions and 7 deletions

View File

@@ -350,7 +350,8 @@ class CausalTransformer(nn.Module):
ff_mult = 4,
norm_out = False,
attn_dropout = 0.,
- ff_dropout = 0.
+ ff_dropout = 0.,
+ final_proj = True
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads)
@@ -363,6 +364,7 @@ class CausalTransformer(nn.Module):
]))
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+ self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
def forward(
self,
@@ -377,7 +379,8 @@ class CausalTransformer(nn.Module):
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = ff(x) + x
- return self.norm(x)
+ out = self.norm(x)
+ return self.project_out(out)
class DiffusionPriorNetwork(nn.Module):
def __init__(

View File

@@ -435,12 +435,12 @@ class VQGanVAE(nn.Module):
return fmap
def decode(self, fmap):
- fmap = self.vq(fmap)
+ fmap, indices, commit_loss = self.vq(fmap)
for dec in self.decoders:
fmap = dec(fmap)
- return fmap
+ return fmap, indices, commit_loss
def forward(
self,
@@ -453,9 +453,9 @@ class VQGanVAE(nn.Module):
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
- fmap, indices, commit_loss = self.encode(img)
+ fmap = self.encode(img)
- fmap = self.decode(fmap)
+ fmap, indices, commit_loss = self.decode(fmap)
if not return_loss and not return_discr_loss:
return fmap

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
- version = '0.0.34',
+ version = '0.0.36',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',