mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 12:04:24 +01:00
Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
723bf0abba | ||
|
|
d88c7ba56c | ||
|
|
3676a8ce78 | ||
|
|
da8e99ada0 | ||
|
|
6afb886cf4 | ||
|
|
c7fe4f2f44 | ||
|
|
a2ee3fa3cc | ||
|
|
a58a370d75 | ||
|
|
1662bbf226 | ||
|
|
5be1f57448 | ||
|
|
c52ce58e10 | ||
|
|
a34f60962a | ||
|
|
0b40cbaa54 | ||
|
|
f141144a6d | ||
|
|
f988207718 | ||
|
|
b2073219f0 | ||
|
|
cc0f7a935c | ||
|
|
95a512cb65 | ||
|
|
972ee973bc | ||
|
|
79e2a3bc77 | ||
|
|
544cdd0b29 | ||
|
|
349aaca56f | ||
|
|
3ee3c56d2a | ||
|
|
cd26c6b17d | ||
|
|
775abc4df6 | ||
|
|
11b1d533a0 | ||
|
|
e76e89f9eb | ||
|
|
bb3ff0ac67 | ||
|
|
1ec4dbe64f |
2
.github/FUNDING.yml
vendored
2
.github/FUNDING.yml
vendored
@@ -1 +1 @@
|
||||
github: [lucidrains]
|
||||
github: [nousr, Veldrovive, lucidrains]
|
||||
|
||||
@@ -421,7 +421,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
|
||||
|
||||
## Training on Preprocessed CLIP Embeddings
|
||||
|
||||
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
|
||||
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`
|
||||
|
||||
Working example below
|
||||
|
||||
@@ -1048,11 +1048,10 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
|
||||
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
|
||||
- [x] allow for unet to be able to condition non-cross attention style as well
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [x] speed up inference, read up on papers (ddim)
|
||||
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -129,6 +129,7 @@ class AdapterConfig(BaseModel):
|
||||
class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim: int
|
||||
depth: int
|
||||
max_text_len: int = None
|
||||
num_timesteps: int = None
|
||||
num_time_embeds: int = 1
|
||||
num_image_embeds: int = 1
|
||||
@@ -136,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
ff_mult: int = 4
|
||||
norm_in: bool = False
|
||||
norm_out: bool = True
|
||||
attn_dropout: float = 0.
|
||||
ff_dropout: float = 0.
|
||||
@@ -223,6 +225,7 @@ class UnetConfig(BaseModel):
|
||||
self_attn: ListOrTuple(int)
|
||||
attn_dim_head: int = 32
|
||||
attn_heads: int = 16
|
||||
init_cross_embed: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
@@ -673,8 +673,14 @@ class DecoderTrainer(nn.Module):
|
||||
def sample(self, *args, **kwargs):
|
||||
distributed = self.accelerator.num_processes > 1
|
||||
base_decoder = self.accelerator.unwrap_model(self.decoder)
|
||||
|
||||
was_training = base_decoder.training
|
||||
base_decoder.eval()
|
||||
|
||||
if kwargs.pop('use_non_ema', False) or not self.use_ema:
|
||||
return base_decoder.sample(*args, **kwargs, distributed = distributed)
|
||||
out = base_decoder.sample(*args, **kwargs, distributed = distributed)
|
||||
base_decoder.train(was_training)
|
||||
return out
|
||||
|
||||
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
|
||||
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
@@ -687,6 +693,7 @@ class DecoderTrainer(nn.Module):
|
||||
for ema in self.ema_unets:
|
||||
ema.restore_ema_model_device()
|
||||
|
||||
base_decoder.train(was_training)
|
||||
return output
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.21.1'
|
||||
__version__ = '0.26.0'
|
||||
|
||||
@@ -323,7 +323,7 @@ def train(
|
||||
last_snapshot = sample
|
||||
|
||||
if next_task == 'train':
|
||||
for i, (img, emb, txt) in enumerate(trainer.train_loader):
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
|
||||
# We want to count the total number of samples across all processes
|
||||
sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||
@@ -358,6 +358,7 @@ def train(
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img, **forward_params, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
@@ -416,7 +417,7 @@ def train(
|
||||
timer = Timer()
|
||||
accelerator.wait_for_everyone()
|
||||
i = 0
|
||||
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
|
||||
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
|
||||
val_sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||
total_samples = all_samples.sum().item()
|
||||
|
||||
@@ -126,9 +126,9 @@ def report_cosine_sims(
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
|
||||
text_embedding, text_encodings = trainer.embed_text(text_data)
|
||||
text_cond = dict(
|
||||
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
|
||||
text_embed=text_embedding, text_encodings=text_encodings
|
||||
)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
@@ -146,15 +146,12 @@ def report_cosine_sims(
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
text_mask_shuffled = text_mask[rolled_idx]
|
||||
else:
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(
|
||||
text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled,
|
||||
mask=text_mask_shuffled,
|
||||
text_encodings=text_encodings_shuffled
|
||||
)
|
||||
|
||||
# prepare the text embedding
|
||||
|
||||
Reference in New Issue
Block a user