mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-18 15:44:23 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d25976f33 | ||
|
|
0b28ee0d01 | ||
|
|
45262a4bb7 | ||
|
|
13a58a78c4 |
12
README.md
12
README.md
@@ -523,6 +523,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||||
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||||
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||||
|
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
||||||
- [ ] spend one day cleaning up tech debt in decoder
|
- [ ] spend one day cleaning up tech debt in decoder
|
||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||||
@@ -531,7 +532,6 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||||
- [ ] bring in tools to train vqgan-vae
|
- [ ] bring in tools to train vqgan-vae
|
||||||
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||||
- [ ] experiment with https://arxiv.org/abs/2112.11435 as upsampler, test in https://github.com/lucidrains/lightweight-gan first
|
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|
||||||
@@ -577,14 +577,4 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@article{Arar2021LearnedQF,
|
|
||||||
title = {Learned Queries for Efficient Local Attention},
|
|
||||||
author = {Moab Arar and Ariel Shamir and Amit H. Bermano},
|
|
||||||
journal = {ArXiv},
|
|
||||||
year = {2021},
|
|
||||||
volume = {abs/2112.11435}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||||
|
|||||||
@@ -693,7 +693,7 @@ class DiffusionPrior(nn.Module):
|
|||||||
# decoder
|
# decoder
|
||||||
|
|
||||||
def Upsample(dim):
|
def Upsample(dim):
|
||||||
return QueryAttnUpsample(dim)
|
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||||
|
|
||||||
def Downsample(dim):
|
def Downsample(dim):
|
||||||
return nn.Conv2d(dim, dim, 4, 2, 1)
|
return nn.Conv2d(dim, dim, 4, 2, 1)
|
||||||
@@ -1117,7 +1117,7 @@ class Decoder(nn.Module):
|
|||||||
unet,
|
unet,
|
||||||
*,
|
*,
|
||||||
clip,
|
clip,
|
||||||
vae = None,
|
vae = tuple(),
|
||||||
timesteps = 1000,
|
timesteps = 1000,
|
||||||
cond_drop_prob = 0.2,
|
cond_drop_prob = 0.2,
|
||||||
loss_type = 'l1',
|
loss_type = 'l1',
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
# exponential moving average wrapper
|
||||||
|
|
||||||
|
class EMA(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
beta = 0.99,
|
||||||
|
ema_update_after_step = 1000,
|
||||||
|
ema_update_every = 10,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.beta = beta
|
||||||
|
self.online_model = model
|
||||||
|
self.ema_model = copy.deepcopy(model)
|
||||||
|
|
||||||
|
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
|
||||||
|
self.ema_update_every = ema_update_every
|
||||||
|
|
||||||
|
self.register_buffer('initted', torch.Tensor([False]))
|
||||||
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
self.step += 1
|
||||||
|
|
||||||
|
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.initted:
|
||||||
|
self.ema_model.state_dict(self.online_model.state_dict())
|
||||||
|
self.initted.data.copy_(torch.Tensor([True]))
|
||||||
|
|
||||||
|
self.update_moving_average(self.ema_model, self.online_model)
|
||||||
|
|
||||||
|
def update_moving_average(ma_model, current_model):
|
||||||
|
def calculate_ema(beta, old, new):
|
||||||
|
if not exists(old):
|
||||||
|
return new
|
||||||
|
return old * beta + (1 - beta) * new
|
||||||
|
|
||||||
|
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||||
|
old_weight, up_weight = ma_params.data, current_params.data
|
||||||
|
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
||||||
|
|
||||||
|
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
||||||
|
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
||||||
|
ma_buffer.copy_(new_buffer_value)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.ema_model(*args, **kwargs)
|
||||||
|
|||||||
@@ -378,7 +378,7 @@ class VQGanVAE(nn.Module):
|
|||||||
|
|
||||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||||
prepend(self.decoders, nn.Sequential(QueryAttnUpsample(dim_out), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
|
||||||
|
|
||||||
if layer_use_attn:
|
if layer_use_attn:
|
||||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||||
|
|||||||
Reference in New Issue
Block a user