mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
1 Commits
v0.25.2
...
fix_resizi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3a1dea7d97 |
2
.github/FUNDING.yml
vendored
2
.github/FUNDING.yml
vendored
@@ -1 +1 @@
|
||||
github: [nousr, Veldrovive, lucidrains]
|
||||
github: [lucidrains]
|
||||
|
||||
15
README.md
15
README.md
@@ -45,7 +45,6 @@ This library would not have gotten to this working state without the help of
|
||||
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
|
||||
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
|
||||
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
|
||||
- <a href="https://github.com/malumadev">MalumaDev</a> for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts
|
||||
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
||||
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
|
||||
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
|
||||
@@ -356,8 +355,7 @@ prior_network = DiffusionPriorNetwork(
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
timesteps = 1000,
|
||||
sample_timesteps = 64,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
@@ -421,7 +419,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
|
||||
|
||||
## Training on Preprocessed CLIP Embeddings
|
||||
|
||||
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`
|
||||
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
|
||||
|
||||
Working example below
|
||||
|
||||
@@ -585,7 +583,6 @@ unet1 = Unet(
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
text_embed_dim = 512,
|
||||
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
|
||||
).cuda()
|
||||
|
||||
@@ -601,8 +598,7 @@ decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 1000,
|
||||
sample_timesteps = (250, 27),
|
||||
timesteps = 100,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5
|
||||
).cuda()
|
||||
@@ -1048,10 +1044,11 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
|
||||
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
|
||||
- [x] allow for unet to be able to condition non-cross attention style as well
|
||||
- [x] speed up inference, read up on papers (ddim)
|
||||
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -129,7 +129,6 @@ class AdapterConfig(BaseModel):
|
||||
class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim: int
|
||||
depth: int
|
||||
max_text_len: int = None
|
||||
num_timesteps: int = None
|
||||
num_time_embeds: int = 1
|
||||
num_image_embeds: int = 1
|
||||
@@ -137,7 +136,6 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
ff_mult: int = 4
|
||||
norm_in: bool = False
|
||||
norm_out: bool = True
|
||||
attn_dropout: float = 0.
|
||||
ff_dropout: float = 0.
|
||||
@@ -156,7 +154,6 @@ class DiffusionPriorConfig(BaseModel):
|
||||
image_size: int
|
||||
image_channels: int = 3
|
||||
timesteps: int = 1000
|
||||
sample_timesteps: Optional[int] = None
|
||||
cond_drop_prob: float = 0.
|
||||
loss_type: str = 'l2'
|
||||
predict_x_start: bool = True
|
||||
@@ -225,7 +222,6 @@ class UnetConfig(BaseModel):
|
||||
self_attn: ListOrTuple(int)
|
||||
attn_dim_head: int = 32
|
||||
attn_heads: int = 16
|
||||
init_cross_embed: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
@@ -237,7 +233,6 @@ class DecoderConfig(BaseModel):
|
||||
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
sample_timesteps: Optional[SingularOrIterable(int)] = None
|
||||
loss_type: str = 'l2'
|
||||
beta_schedule: ListOrTuple(str) = 'cosine'
|
||||
learned_variance: bool = True
|
||||
|
||||
@@ -536,19 +536,11 @@ class DecoderTrainer(nn.Module):
|
||||
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
|
||||
clip = decoder.clip
|
||||
clip.to(precision_type)
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
# prepare dataloaders
|
||||
|
||||
train_loader = val_loader = None
|
||||
if exists(dataloaders):
|
||||
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
|
||||
decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
|
||||
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.decoder = decoder
|
||||
|
||||
# store optimizers
|
||||
|
||||
@@ -673,14 +665,8 @@ class DecoderTrainer(nn.Module):
|
||||
def sample(self, *args, **kwargs):
|
||||
distributed = self.accelerator.num_processes > 1
|
||||
base_decoder = self.accelerator.unwrap_model(self.decoder)
|
||||
|
||||
was_training = base_decoder.training
|
||||
base_decoder.eval()
|
||||
|
||||
if kwargs.pop('use_non_ema', False) or not self.use_ema:
|
||||
out = base_decoder.sample(*args, **kwargs, distributed = distributed)
|
||||
base_decoder.train(was_training)
|
||||
return out
|
||||
return base_decoder.sample(*args, **kwargs, distributed = distributed)
|
||||
|
||||
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
|
||||
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
@@ -693,7 +679,6 @@ class DecoderTrainer(nn.Module):
|
||||
for ema in self.ema_unets:
|
||||
ema.restore_ema_model_device()
|
||||
|
||||
base_decoder.train(was_training)
|
||||
return output
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.25.2'
|
||||
__version__ = '0.18.0'
|
||||
|
||||
@@ -323,7 +323,7 @@ def train(
|
||||
last_snapshot = sample
|
||||
|
||||
if next_task == 'train':
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
|
||||
for i, (img, emb, txt) in enumerate(trainer.train_loader):
|
||||
# We want to count the total number of samples across all processes
|
||||
sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||
@@ -358,7 +358,6 @@ def train(
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img, **forward_params, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
@@ -417,7 +416,7 @@ def train(
|
||||
timer = Timer()
|
||||
accelerator.wait_for_everyone()
|
||||
i = 0
|
||||
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
|
||||
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
|
||||
val_sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||
total_samples = all_samples.sum().item()
|
||||
@@ -558,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
|
||||
# Create the decoder model and print basic info
|
||||
decoder = config.decoder.create()
|
||||
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
|
||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||
|
||||
# Create and initialize the tracker if we are the master
|
||||
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
|
||||
@@ -587,10 +586,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
accelerator.print(print_ribbon("Loaded Config", repeat=40))
|
||||
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
|
||||
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
|
||||
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
|
||||
for i, unet in enumerate(decoder.unets):
|
||||
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
|
||||
|
||||
accelerator.print(f"Number of parameters: {num_parameters}")
|
||||
train(dataloaders, decoder, accelerator,
|
||||
tracker=tracker,
|
||||
inference_device=accelerator.device,
|
||||
|
||||
@@ -126,9 +126,9 @@ def report_cosine_sims(
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings = trainer.embed_text(text_data)
|
||||
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
|
||||
text_cond = dict(
|
||||
text_embed=text_embedding, text_encodings=text_encodings
|
||||
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
|
||||
)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
@@ -146,12 +146,15 @@ def report_cosine_sims(
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
text_mask_shuffled = text_mask[rolled_idx]
|
||||
else:
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(
|
||||
text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled
|
||||
text_encodings=text_encodings_shuffled,
|
||||
mask=text_mask_shuffled,
|
||||
)
|
||||
|
||||
# prepare the text embedding
|
||||
|
||||
Reference in New Issue
Block a user