mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-19 03:34:39 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc954ee788 | ||
|
|
c1db2753f5 |
@@ -830,6 +830,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||||
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||||
|
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|
||||||
|
|||||||
@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
|
|||||||
def get_encoded_fmap_size(self, image_size):
|
def get_encoded_fmap_size(self, image_size):
|
||||||
return image_size // (2 ** self.layers)
|
return image_size // (2 ** self.layers)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_dec_layer(self):
|
||||||
|
return self.decoders[-1].weight
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
for enc in self.encoders:
|
for enc in self.encoders:
|
||||||
x = enc(x)
|
x = enc(x)
|
||||||
@@ -419,6 +423,10 @@ class ConvNextEncDec(nn.Module):
|
|||||||
def get_encoded_fmap_size(self, image_size):
|
def get_encoded_fmap_size(self, image_size):
|
||||||
return image_size // (2 ** self.layers)
|
return image_size // (2 ** self.layers)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_dec_layer(self):
|
||||||
|
return self.decoders[-1].weight
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
for enc in self.encoders:
|
for enc in self.encoders:
|
||||||
x = enc(x)
|
x = enc(x)
|
||||||
@@ -606,6 +614,10 @@ class ViTEncDec(nn.Module):
|
|||||||
def get_encoded_fmap_size(self, image_size):
|
def get_encoded_fmap_size(self, image_size):
|
||||||
return image_size // self.patch_size
|
return image_size // self.patch_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_dec_layer(self):
|
||||||
|
return self.decoder[-3][-1].weight
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
return self.encoder(x)
|
return self.encoder(x)
|
||||||
|
|
||||||
@@ -843,7 +855,7 @@ class VQGanVAE(nn.Module):
|
|||||||
|
|
||||||
# calculate adaptive weight
|
# calculate adaptive weight
|
||||||
|
|
||||||
last_dec_layer = self.decoders[-1].weight
|
last_dec_layer = self.enc_dec.last_dec_layer
|
||||||
|
|
||||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
|
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
|
||||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
|
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
|
||||||
|
|||||||
Reference in New Issue
Block a user