mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 19:24:21 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
461347c171 | ||
|
|
46cef31c86 |
@@ -350,7 +350,8 @@ class CausalTransformer(nn.Module):
|
|||||||
ff_mult = 4,
|
ff_mult = 4,
|
||||||
norm_out = False,
|
norm_out = False,
|
||||||
attn_dropout = 0.,
|
attn_dropout = 0.,
|
||||||
ff_dropout = 0.
|
ff_dropout = 0.,
|
||||||
|
final_proj = True
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||||
@@ -363,6 +364,7 @@ class CausalTransformer(nn.Module):
|
|||||||
]))
|
]))
|
||||||
|
|
||||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||||
|
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -377,7 +379,8 @@ class CausalTransformer(nn.Module):
|
|||||||
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
||||||
x = ff(x) + x
|
x = ff(x) + x
|
||||||
|
|
||||||
return self.norm(x)
|
out = self.norm(x)
|
||||||
|
return self.project_out(out)
|
||||||
|
|
||||||
class DiffusionPriorNetwork(nn.Module):
|
class DiffusionPriorNetwork(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -435,12 +435,12 @@ class VQGanVAE(nn.Module):
|
|||||||
return fmap
|
return fmap
|
||||||
|
|
||||||
def decode(self, fmap):
|
def decode(self, fmap):
|
||||||
fmap = self.vq(fmap)
|
fmap, indices, commit_loss = self.vq(fmap)
|
||||||
|
|
||||||
for dec in self.decoders:
|
for dec in self.decoders:
|
||||||
fmap = dec(fmap)
|
fmap = dec(fmap)
|
||||||
|
|
||||||
return fmap
|
return fmap, indices, commit_loss
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -453,9 +453,9 @@ class VQGanVAE(nn.Module):
|
|||||||
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||||
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
|
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
|
||||||
|
|
||||||
fmap, indices, commit_loss = self.encode(img)
|
fmap = self.encode(img)
|
||||||
|
|
||||||
fmap = self.decode(fmap)
|
fmap, indices, commit_loss = self.decode(fmap)
|
||||||
|
|
||||||
if not return_loss and not return_discr_loss:
|
if not return_loss and not return_discr_loss:
|
||||||
return fmap
|
return fmap
|
||||||
|
|||||||
Reference in New Issue
Block a user