mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Add data flexibility to decoder trainer (#165)
* Added the ability to train decoder with text embeddings * Added the ability to train using on the fly generated embeddings with clip * Clip now generates embeddings for whatever is not precomputed
This commit is contained in:
160
train_decoder.py
160
train_decoder.py
@@ -6,6 +6,7 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
|
||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||
from clip import tokenize
|
||||
|
||||
import torchvision
|
||||
import torch
|
||||
@@ -33,7 +34,8 @@ def exists(val):
|
||||
def create_dataloaders(
|
||||
available_shards,
|
||||
webdataset_base_url,
|
||||
embeddings_url,
|
||||
img_embeddings_url=None,
|
||||
text_embeddings_url=None,
|
||||
shard_width=6,
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
@@ -63,14 +65,15 @@ def create_dataloaders(
|
||||
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
|
||||
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
|
||||
|
||||
create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
|
||||
create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
|
||||
tar_url=tar_urls,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size if not for_sampling else n_sample_images,
|
||||
embeddings_url=embeddings_url,
|
||||
img_embeddings_url=img_embeddings_url,
|
||||
text_embeddings_url=text_embeddings_url,
|
||||
index_width=index_width,
|
||||
shuffle_num = None,
|
||||
extra_keys= ["txt"] if with_text else [],
|
||||
extra_keys= ["txt"],
|
||||
shuffle_shards = shuffle,
|
||||
resample_shards = resample,
|
||||
img_preproc=img_preproc,
|
||||
@@ -79,8 +82,8 @@ def create_dataloaders(
|
||||
|
||||
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
|
||||
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
|
||||
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
|
||||
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
|
||||
val_dataloader = create_dataloader(val_urls, shuffle=False)
|
||||
test_dataloader = create_dataloader(test_urls, shuffle=False)
|
||||
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
|
||||
return {
|
||||
"train": train_dataloader,
|
||||
@@ -104,42 +107,65 @@ def get_example_data(dataloader, device, n=5):
|
||||
Samples the dataloader and returns a zipped list of examples
|
||||
"""
|
||||
images = []
|
||||
embeddings = []
|
||||
img_embeddings = []
|
||||
text_embeddings = []
|
||||
captions = []
|
||||
dataset_keys = get_dataset_keys(dataloader)
|
||||
has_caption = "txt" in dataset_keys
|
||||
for data in dataloader:
|
||||
if has_caption:
|
||||
img, emb, txt = data
|
||||
for img, emb, txt in dataloader:
|
||||
img_emb, text_emb = emb.get('img'), emb.get('text')
|
||||
if img_emb is not None:
|
||||
img_emb = img_emb.to(device=device, dtype=torch.float)
|
||||
img_embeddings.extend(list(img_emb))
|
||||
else:
|
||||
img, emb = data
|
||||
txt = [""] * emb.shape[0]
|
||||
# Then we add None img.shape[0] times
|
||||
img_embeddings.extend([None]*img.shape[0])
|
||||
if text_emb is not None:
|
||||
text_emb = text_emb.to(device=device, dtype=torch.float)
|
||||
text_embeddings.extend(list(text_emb))
|
||||
else:
|
||||
# Then we add None img.shape[0] times
|
||||
text_embeddings.extend([None]*img.shape[0])
|
||||
img = img.to(device=device, dtype=torch.float)
|
||||
emb = emb.to(device=device, dtype=torch.float)
|
||||
images.extend(list(img))
|
||||
embeddings.extend(list(emb))
|
||||
captions.extend(list(txt))
|
||||
if len(images) >= n:
|
||||
break
|
||||
return list(zip(images[:n], embeddings[:n], captions[:n]))
|
||||
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
|
||||
|
||||
def generate_samples(trainer, example_data, text_prepend=""):
|
||||
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
|
||||
"""
|
||||
Takes example data and generates images from the embeddings
|
||||
Returns three lists: real images, generated images, and captions
|
||||
"""
|
||||
real_images, embeddings, txts = zip(*example_data)
|
||||
embeddings_tensor = torch.stack(embeddings)
|
||||
samples = trainer.sample(embeddings_tensor)
|
||||
real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
|
||||
sample_params = {}
|
||||
if img_embeddings[0] is None:
|
||||
# Generate image embeddings from clip
|
||||
imgs_tensor = torch.stack(real_images)
|
||||
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
|
||||
sample_params["image_embed"] = img_embeddings
|
||||
else:
|
||||
# Then we are using precomputed image embeddings
|
||||
img_embeddings = torch.stack(img_embeddings)
|
||||
sample_params["image_embed"] = img_embeddings
|
||||
if condition_on_text_encodings:
|
||||
if text_embeddings[0] is None:
|
||||
# Generate text embeddings from text
|
||||
tokenized_texts = tokenize(txts, truncate=True)
|
||||
sample_params["text"] = tokenized_texts
|
||||
else:
|
||||
# Then we are using precomputed text embeddings
|
||||
text_embeddings = torch.stack(text_embeddings)
|
||||
sample_params["text_encodings"] = text_embeddings
|
||||
samples = trainer.sample(**sample_params)
|
||||
generated_images = list(samples)
|
||||
captions = [text_prepend + txt for txt in txts]
|
||||
return real_images, generated_images, captions
|
||||
|
||||
def generate_grid_samples(trainer, examples, text_prepend=""):
|
||||
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
|
||||
"""
|
||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||
"""
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
|
||||
|
||||
real_image_size = real_images[0].shape[-1]
|
||||
generated_image_size = generated_images[0].shape[-1]
|
||||
@@ -151,7 +177,7 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
|
||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||
return grid_images, captions
|
||||
|
||||
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
"""
|
||||
Computes evaluation metrics for the decoder
|
||||
"""
|
||||
@@ -161,7 +187,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
|
||||
if len(examples) == 0:
|
||||
print("No data to evaluate. Check that your dataloader has shards.")
|
||||
return metrics
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples)
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
|
||||
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||
@@ -250,6 +276,7 @@ def train(
|
||||
save_latest=True,
|
||||
save_best=True,
|
||||
unet_training_mask=None,
|
||||
condition_on_text_encodings=False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
@@ -307,14 +334,22 @@ def train(
|
||||
last_snapshot = sample
|
||||
|
||||
if next_task == 'train':
|
||||
for i, (img, emb) in enumerate(dataloaders["train"]):
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
|
||||
# We want to count the total number of samples across all processes
|
||||
sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||
total_samples = all_samples.sum().item()
|
||||
sample += total_samples
|
||||
samples_seen += total_samples
|
||||
img, emb = send_to_device((img, emb))
|
||||
img_emb = emb.get('img')
|
||||
has_img_embedding = img_emb is not None
|
||||
if has_img_embedding:
|
||||
img_emb, = send_to_device((img_emb,))
|
||||
text_emb = emb.get('text')
|
||||
has_text_embedding = text_emb is not None
|
||||
if has_text_embedding:
|
||||
text_emb, = send_to_device((text_emb,))
|
||||
img, = send_to_device((img,))
|
||||
|
||||
trainer.train()
|
||||
for unet in range(1, trainer.num_unets+1):
|
||||
@@ -322,7 +357,20 @@ def train(
|
||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
continue
|
||||
|
||||
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
||||
forward_params = {}
|
||||
if has_img_embedding:
|
||||
forward_params['image_embed'] = img_emb
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img, **forward_params, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
|
||||
|
||||
@@ -366,7 +414,7 @@ def train(
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
if epoch_samples is not None and sample >= epoch_samples:
|
||||
@@ -389,14 +437,35 @@ def train(
|
||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||
total_samples = all_samples.sum().item()
|
||||
val_sample += total_samples
|
||||
img, emb = send_to_device((img, emb))
|
||||
img_emb = emb.get('img')
|
||||
has_img_embedding = img_emb is not None
|
||||
if has_img_embedding:
|
||||
img_emb, = send_to_device((img_emb,))
|
||||
text_emb = emb.get('text')
|
||||
has_text_embedding = text_emb is not None
|
||||
if has_text_embedding:
|
||||
text_emb, = send_to_device((text_emb,))
|
||||
img, = send_to_device((img,))
|
||||
|
||||
for unet in range(1, len(decoder.unets)+1):
|
||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
# No need to evaluate an unchanging unet
|
||||
continue
|
||||
|
||||
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
|
||||
|
||||
forward_params = {}
|
||||
if has_img_embedding:
|
||||
forward_params['image_embed'] = img_emb.float()
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb.float()
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
|
||||
average_val_loss_tensor[0, unet-1] += loss
|
||||
|
||||
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
||||
@@ -423,7 +492,7 @@ def train(
|
||||
if next_task == 'eval':
|
||||
if exists(evaluate_config):
|
||||
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
|
||||
if is_master:
|
||||
tracker.log(evaluation, step=step(), verbose=True)
|
||||
next_task = 'sample'
|
||||
@@ -434,8 +503,8 @@ def train(
|
||||
# Generate examples and save the model if we are the master
|
||||
# Generate sample images
|
||||
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
@@ -525,14 +594,35 @@ def initialize_training(config, config_path):
|
||||
# Create and initialize the tracker if we are the master
|
||||
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
|
||||
|
||||
has_img_embeddings = config.data.img_embeddings_url is not None
|
||||
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||
conditioning_on_text = config.decoder.condition_on_text_encodings
|
||||
has_clip_model = config.decoder.clip is not None
|
||||
data_source_string = ""
|
||||
if has_img_embeddings:
|
||||
data_source_string += "precomputed image embeddings"
|
||||
elif has_clip_model:
|
||||
data_source_string += "clip image embeddings generation"
|
||||
else:
|
||||
raise ValueError("No image embeddings source specified")
|
||||
if conditioning_on_text:
|
||||
if has_text_embeddings:
|
||||
data_source_string += " and precomputed text embeddings"
|
||||
elif has_clip_model:
|
||||
data_source_string += " and clip text encoding generation"
|
||||
else:
|
||||
raise ValueError("No text embeddings source specified")
|
||||
|
||||
accelerator.print(print_ribbon("Loaded Config", repeat=40))
|
||||
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
|
||||
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
|
||||
accelerator.print(f"Number of parameters: {num_parameters}")
|
||||
train(dataloaders, decoder, accelerator,
|
||||
tracker=tracker,
|
||||
inference_device=accelerator.device,
|
||||
load_config=config.load,
|
||||
evaluate_config=config.evaluate,
|
||||
condition_on_text_encodings=config.decoder.condition_on_text_encodings,
|
||||
**config.train.dict(),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user