Add data flexibility to decoder trainer (#165)

* Added the ability to train decoder with text embeddings

* Added the ability to train using on the fly generated embeddings with clip

* Clip now generates embeddings for whatever is not precomputed
This commit is contained in:
Aidan Dempster
2022-06-25 22:05:20 -04:00
committed by GitHub
parent c453f468b1
commit f5760bdb92
4 changed files with 228 additions and 59 deletions

View File

@@ -21,7 +21,7 @@ def get_example_file(fs, path, file_format):
"""
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
previous_tar_url = None
current_embeddings = None
@@ -56,7 +56,7 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
if torch.count_nonzero(embedding) == 0:
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
sample["npy"] = embedding
sample[sample_key] = embedding
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
@@ -84,18 +84,20 @@ def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.re
continue
else:
break
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
def verify_keys(samples, handler=wds.handlers.reraise_exception):
def join_embeddings(samples, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
"""
for sample in samples:
try:
assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
sample['emb'] = {}
if 'text_emb' in sample:
sample['emb']['text'] = sample['text_emb']
if 'img_emb' in sample:
sample['emb']['img'] = sample['img_emb']
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
@@ -103,6 +105,23 @@ def verify_keys(samples, handler=wds.handlers.reraise_exception):
else:
break
def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
"""
for sample in samples:
try:
for key in required_keys:
assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
key_verifier = wds.filters.pipelinefilter(verify_keys)
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
"""
A fluid interface wrapper for DataPipline that returns image embedding pairs
@@ -112,7 +131,8 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
def __init__(
self,
urls,
embedding_folder_url=None,
img_embedding_folder_url=None,
text_embedding_folder_url=None,
index_width=None,
img_preproc=None,
extra_keys=[],
@@ -136,7 +156,12 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
"""
super().__init__()
keys = ["jpg", "npy"] + extra_keys
keys = ["jpg", "emb"] + extra_keys
# if img_embedding_folder_url is not None:
# keys.append("img_emb")
# if text_embedding_folder_url is not None:
# keys.append("text_emb")
# keys.extend(extra_keys)
self.key_map = {key: i for i, key in enumerate(keys)}
self.resampling = resample
self.img_preproc = img_preproc
@@ -145,7 +170,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
# Then this has an s3 link for the webdataset and we need extra packages
if shutil.which("s3cmd") is None:
raise RuntimeError("s3cmd is required for s3 webdataset")
if "s3:" in embedding_folder_url:
if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
try:
import s3fs
@@ -160,17 +185,24 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
if shuffle_shards:
self.append(wds.filters.shuffle(1000))
if embedding_folder_url is not None:
if img_embedding_folder_url is not None:
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
if text_embedding_folder_url is not None:
self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("pilrgb", handler=handler))
if embedding_folder_url is not None:
# Then we are loading embeddings for a remote source
if img_embedding_folder_url is not None:
# Then we are loading image embeddings for a remote source
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
self.append(verify_keys)
self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
if text_embedding_folder_url is not None:
# Then we are loading image embeddings for a remote source
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
self.append(join_embeddings)
self.append(key_verifier(required_keys=keys, handler=handler))
# Apply preprocessing
self.append(wds.map(self.preproc))
self.append(wds.to_tuple(*keys))
@@ -185,7 +217,8 @@ def create_image_embedding_dataloader(
tar_url,
num_workers,
batch_size,
embeddings_url=None,
img_embeddings_url=None,
text_embeddings_url=None,
index_width=None,
shuffle_num = None,
shuffle_shards = True,
@@ -211,7 +244,8 @@ def create_image_embedding_dataloader(
"""
ds = ImageEmbeddingDataset(
tar_url,
embeddings_url,
img_embedding_folder_url=img_embeddings_url,
text_embedding_folder_url=text_embeddings_url,
index_width=index_width,
shuffle_shards=shuffle_shards,
resample=resample_shards,
@@ -228,4 +262,4 @@ def create_image_embedding_dataloader(
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
pin_memory=True,
shuffle=False
)
)

View File

@@ -13,7 +13,7 @@ from dalle2_pytorch.dalle2_pytorch import (
Decoder,
DiffusionPrior,
DiffusionPriorNetwork,
XClipAdapter,
XClipAdapter
)
# helper functions
@@ -170,6 +170,8 @@ class DecoderConfig(BaseModel):
unets: ListOrTuple(UnetConfig)
image_size: int = None
image_sizes: ListOrTuple(int) = None
condition_on_text_encodings: bool = False
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
@@ -180,9 +182,16 @@ class DecoderConfig(BaseModel):
def create(self):
decoder_kwargs = self.dict()
unet_configs = decoder_kwargs.pop('unets')
unets = [Unet(**config) for config in unet_configs]
return Decoder(unets, **decoder_kwargs)
has_clip = exists(decoder_kwargs.pop('clip'))
clip = None
if has_clip:
clip = self.clip.create()
return Decoder(unets, clip=clip, **decoder_kwargs)
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
@@ -194,8 +203,9 @@ class DecoderConfig(BaseModel):
extra = "allow"
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings
webdataset_base_url: str # path to a webdataset with jpg images
img_embeddings_url: Optional[str] # path to .npy files with embeddings
text_embeddings_url: Optional[str] # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
@@ -268,3 +278,26 @@ class TrainDecoderConfig(BaseModel):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
@root_validator
def check_has_embeddings(cls, values):
# Makes sure that enough information is provided to get the embeddings specified for training
data_config, decoder_config = values.get('data'), values.get('decoder')
if data_config is None or decoder_config is None:
# Then something else errored and we should just pass through
return values
using_text_embeddings = decoder_config.condition_on_text_encodings
using_clip = exists(decoder_config.clip)
img_emb_url = data_config.img_embeddings_url
text_emb_url = data_config.text_embeddings_url
if using_text_embeddings:
# Then we need some way to get the embeddings
assert using_clip or text_emb_url is not None, 'If condition_on_text_encodings is true, either clip or text_embeddings_url must be provided'
if using_clip:
if using_text_embeddings:
assert text_emb_url is None or img_emb_url is None, 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
else:
assert img_emb_url is None, 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
if text_emb_url:
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
return values

View File

@@ -578,6 +578,18 @@ class DecoderTrainer(nn.Module):
return output
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_text(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def embed_image(self, *args, **kwargs):
return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
@cast_torch_tensor
def forward(
self,

View File

@@ -6,6 +6,7 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
from clip import tokenize
import torchvision
import torch
@@ -33,7 +34,8 @@ def exists(val):
def create_dataloaders(
available_shards,
webdataset_base_url,
embeddings_url,
img_embeddings_url=None,
text_embeddings_url=None,
shard_width=6,
num_workers=4,
batch_size=32,
@@ -63,14 +65,15 @@ def create_dataloaders(
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
tar_url=tar_urls,
num_workers=num_workers,
batch_size=batch_size if not for_sampling else n_sample_images,
embeddings_url=embeddings_url,
img_embeddings_url=img_embeddings_url,
text_embeddings_url=text_embeddings_url,
index_width=index_width,
shuffle_num = None,
extra_keys= ["txt"] if with_text else [],
extra_keys= ["txt"],
shuffle_shards = shuffle,
resample_shards = resample,
img_preproc=img_preproc,
@@ -79,8 +82,8 @@ def create_dataloaders(
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
val_dataloader = create_dataloader(val_urls, shuffle=False)
test_dataloader = create_dataloader(test_urls, shuffle=False)
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
return {
"train": train_dataloader,
@@ -104,42 +107,65 @@ def get_example_data(dataloader, device, n=5):
Samples the dataloader and returns a zipped list of examples
"""
images = []
embeddings = []
img_embeddings = []
text_embeddings = []
captions = []
dataset_keys = get_dataset_keys(dataloader)
has_caption = "txt" in dataset_keys
for data in dataloader:
if has_caption:
img, emb, txt = data
for img, emb, txt in dataloader:
img_emb, text_emb = emb.get('img'), emb.get('text')
if img_emb is not None:
img_emb = img_emb.to(device=device, dtype=torch.float)
img_embeddings.extend(list(img_emb))
else:
img, emb = data
txt = [""] * emb.shape[0]
# Then we add None img.shape[0] times
img_embeddings.extend([None]*img.shape[0])
if text_emb is not None:
text_emb = text_emb.to(device=device, dtype=torch.float)
text_embeddings.extend(list(text_emb))
else:
# Then we add None img.shape[0] times
text_embeddings.extend([None]*img.shape[0])
img = img.to(device=device, dtype=torch.float)
emb = emb.to(device=device, dtype=torch.float)
images.extend(list(img))
embeddings.extend(list(emb))
captions.extend(list(txt))
if len(images) >= n:
break
return list(zip(images[:n], embeddings[:n], captions[:n]))
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, text_prepend=""):
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
"""
real_images, embeddings, txts = zip(*example_data)
embeddings_tensor = torch.stack(embeddings)
samples = trainer.sample(embeddings_tensor)
real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
sample_params = {}
if img_embeddings[0] is None:
# Generate image embeddings from clip
imgs_tensor = torch.stack(real_images)
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
sample_params["image_embed"] = img_embeddings
else:
# Then we are using precomputed image embeddings
img_embeddings = torch.stack(img_embeddings)
sample_params["image_embed"] = img_embeddings
if condition_on_text_encodings:
if text_embeddings[0] is None:
# Generate text embeddings from text
tokenized_texts = tokenize(txts, truncate=True)
sample_params["text"] = tokenized_texts
else:
# Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)
sample_params["text_encodings"] = text_embeddings
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, text_prepend=""):
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
@@ -151,7 +177,7 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
@@ -161,7 +187,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -250,6 +276,7 @@ def train(
save_latest=True,
save_best=True,
unet_training_mask=None,
condition_on_text_encodings=False,
**kwargs
):
"""
@@ -307,14 +334,22 @@ def train(
last_snapshot = sample
if next_task == 'train':
for i, (img, emb) in enumerate(dataloaders["train"]):
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
total_samples = all_samples.sum().item()
sample += total_samples
samples_seen += total_samples
img, emb = send_to_device((img, emb))
img_emb = emb.get('img')
has_img_embedding = img_emb is not None
if has_img_embedding:
img_emb, = send_to_device((img_emb,))
text_emb = emb.get('text')
has_text_embedding = text_emb is not None
if has_text_embedding:
text_emb, = send_to_device((text_emb,))
img, = send_to_device((img,))
trainer.train()
for unet in range(1, trainer.num_unets+1):
@@ -322,7 +357,20 @@ def train(
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
forward_params = {}
if has_img_embedding:
forward_params['image_embed'] = img_emb
else:
# Forward pass automatically generates embedding
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -366,7 +414,7 @@ def train(
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples:
@@ -389,14 +437,35 @@ def train(
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
val_sample += total_samples
img, emb = send_to_device((img, emb))
img_emb = emb.get('img')
has_img_embedding = img_emb is not None
if has_img_embedding:
img_emb, = send_to_device((img_emb,))
text_emb = emb.get('text')
has_text_embedding = text_emb is not None
if has_text_embedding:
text_emb, = send_to_device((text_emb,))
img, = send_to_device((img,))
for unet in range(1, len(decoder.unets)+1):
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
# No need to evaluate an unchanging unet
continue
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
forward_params = {}
if has_img_embedding:
forward_params['image_embed'] = img_emb.float()
else:
# Forward pass automatically generates embedding
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb.float()
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
forward_params['text'] = tokenized_texts
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
average_val_loss_tensor[0, unet-1] += loss
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -423,7 +492,7 @@ def train(
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
if is_master:
tracker.log(evaluation, step=step(), verbose=True)
next_task = 'sample'
@@ -434,8 +503,8 @@ def train(
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
@@ -525,14 +594,35 @@ def initialize_training(config, config_path):
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
has_img_embeddings = config.data.img_embeddings_url is not None
has_text_embeddings = config.data.text_embeddings_url is not None
conditioning_on_text = config.decoder.condition_on_text_encodings
has_clip_model = config.decoder.clip is not None
data_source_string = ""
if has_img_embeddings:
data_source_string += "precomputed image embeddings"
elif has_clip_model:
data_source_string += "clip image embeddings generation"
else:
raise ValueError("No image embeddings source specified")
if conditioning_on_text:
if has_text_embeddings:
data_source_string += " and precomputed text embeddings"
elif has_clip_model:
data_source_string += " and clip text encoding generation"
else:
raise ValueError("No text embeddings source specified")
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {num_parameters}")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,
load_config=config.load,
evaluate_config=config.evaluate,
condition_on_text_encodings=config.decoder.condition_on_text_encodings,
**config.train.dict(),
)