mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc954ee788 | ||
|
|
c1db2753f5 |
@@ -830,6 +830,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoders[-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
@@ -419,6 +423,10 @@ class ConvNextEncDec(nn.Module):
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoders[-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
@@ -606,6 +614,10 @@ class ViTEncDec(nn.Module):
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // self.patch_size
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoder[-3][-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
@@ -843,7 +855,7 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
# calculate adaptive weight
|
||||
|
||||
last_dec_layer = self.decoders[-1].weight
|
||||
last_dec_layer = self.enc_dec.last_dec_layer
|
||||
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
|
||||
|
||||
Reference in New Issue
Block a user