Mirror of https://github.com/lucidrains/DALLE2-pytorch.git
Add data flexibility to decoder trainer (#165)
* Added the ability to train decoder with text embeddings
* Added the ability to train using on the fly generated embeddings with clip
* Clip now generates embeddings for whatever is not precomputed
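In practice this lets a training run mix precomputed embeddings with on-the-fly clip generation. A minimal sketch of the relevant fragment of a `TrainDecoderConfig` (field names are the ones introduced in this diff; all other required sections such as the unets, tracker, and train settings are elided, and the paths are invented):

```python
partial_config = {
    "decoder": {
        "condition_on_text_encodings": True,
        "clip": {},                 # AdapterConfig for the clip model; contents elided here
        # unets, image_sizes, ... elided
    },
    "data": {
        "webdataset_base_url": "s3://bucket/dataset/{}.tar",  # invented path
        "img_embeddings_url": "s3://bucket/img_emb/",         # precomputed image embeddings
        "text_embeddings_url": None,                          # text encodings come from clip instead
    },
}
```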
@@ -21,7 +21,7 @@ def get_example_file(fs, path, file_format):
     """
     return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
 
-def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
+def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
     """Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
     previous_tar_url = None
     current_embeddings = None
@@ -56,7 +56,7 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler
             # We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
             if torch.count_nonzero(embedding) == 0:
                 raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
-            sample["npy"] = embedding
+            sample[sample_key] = embedding
             yield sample
         except Exception as exn: # From wds implementation
             if handler(exn):
@@ -84,18 +84,20 @@ def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.re
                 continue
             else:
                 break
 
 skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
 
-def verify_keys(samples, handler=wds.handlers.reraise_exception):
+def join_embeddings(samples, handler=wds.handlers.reraise_exception):
     """
-    Requires that both the image and embedding are present in the sample
-    This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
+    Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
+    either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
     """
     for sample in samples:
         try:
-            assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
-            assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
+            sample['emb'] = {}
+            if 'text_emb' in sample:
+                sample['emb']['text'] = sample['text_emb']
+            if 'img_emb' in sample:
+                sample['emb']['img'] = sample['img_emb']
             yield sample
         except Exception as exn: # From wds implementation
             if handler(exn):
@@ -103,6 +105,23 @@ def verify_keys(samples, handler=wds.handlers.reraise_exception):
             else:
                 break
 
+def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
+    """
+    Requires that both the image and embedding are present in the sample
+    This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
+    """
+    for sample in samples:
+        try:
+            for key in required_keys:
+                assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
+            yield sample
+        except Exception as exn: # From wds implementation
+            if handler(exn):
+                continue
+            else:
+                break
+key_verifier = wds.filters.pipelinefilter(verify_keys)
+
 class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
     """
     A fluid interface wrapper for DataPipline that returns image embedding pairs
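For orientation, a toy sketch of what the new `join_embeddings` stage does to one webdataset sample (the key value and the 768-dim embedding size are invented for the example):

```python
import torch

sample = {
    "__key__": "0000000012",        # shard id + zero-padded sample index
    "jpg": "<decoded image>",        # placeholder for the decoded image
    "img_emb": torch.randn(768),     # added by insert_embedding(..., sample_key='img_emb')
    "text_emb": torch.randn(768),    # added by insert_embedding(..., sample_key='text_emb')
}

# join_embeddings collapses whichever of img_emb / text_emb exist into a single "emb" dict:
sample["emb"] = {}
if "text_emb" in sample:
    sample["emb"]["text"] = sample["text_emb"]
if "img_emb" in sample:
    sample["emb"]["img"] = sample["img_emb"]

# key_verifier(required_keys=["jpg", "emb"]) then asserts both keys are present downstream.
```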
@@ -112,7 +131,8 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
     def __init__(
             self,
             urls,
-            embedding_folder_url=None,
+            img_embedding_folder_url=None,
+            text_embedding_folder_url=None,
             index_width=None,
             img_preproc=None,
             extra_keys=[],
@@ -136,7 +156,12 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
 
         """
         super().__init__()
-        keys = ["jpg", "npy"] + extra_keys
+        keys = ["jpg", "emb"] + extra_keys
+        # if img_embedding_folder_url is not None:
+        #     keys.append("img_emb")
+        # if text_embedding_folder_url is not None:
+        #     keys.append("text_emb")
+        # keys.extend(extra_keys)
         self.key_map = {key: i for i, key in enumerate(keys)}
         self.resampling = resample
         self.img_preproc = img_preproc
@@ -145,7 +170,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
             # Then this has an s3 link for the webdataset and we need extra packages
             if shutil.which("s3cmd") is None:
                 raise RuntimeError("s3cmd is required for s3 webdataset")
-        if "s3:" in embedding_folder_url:
+        if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
             # Then the embeddings are being loaded from s3 and fsspec requires s3fs
             try:
                 import s3fs
@@ -160,17 +185,24 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
         if shuffle_shards:
             self.append(wds.filters.shuffle(1000))
 
-        if embedding_folder_url is not None:
+        if img_embedding_folder_url is not None:
             # There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
-            self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
+            self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
+        if text_embedding_folder_url is not None:
+            self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
 
         self.append(wds.tarfile_to_samples(handler=handler))
         self.append(wds.decode("pilrgb", handler=handler))
-        if embedding_folder_url is not None:
-            # Then we are loading embeddings for a remote source
+        if img_embedding_folder_url is not None:
+            # Then we are loading image embeddings for a remote source
             assert index_width is not None, "Reading embeddings separately requires index width length to be given"
-            self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
-        self.append(verify_keys)
+            self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
+        if text_embedding_folder_url is not None:
+            # Then we are loading image embeddings for a remote source
+            assert index_width is not None, "Reading embeddings separately requires index width length to be given"
+            self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
+        self.append(join_embeddings)
+        self.append(key_verifier(required_keys=keys, handler=handler))
         # Apply preprocessing
         self.append(wds.map(self.preproc))
         self.append(wds.to_tuple(*keys))
@@ -185,7 +217,8 @@ def create_image_embedding_dataloader(
     tar_url,
     num_workers,
     batch_size,
-    embeddings_url=None,
+    img_embeddings_url=None,
+    text_embeddings_url=None,
     index_width=None,
     shuffle_num = None,
     shuffle_shards = True,
@@ -211,7 +244,8 @@ def create_image_embedding_dataloader(
     """
     ds = ImageEmbeddingDataset(
         tar_url,
-        embeddings_url,
+        img_embedding_folder_url=img_embeddings_url,
+        text_embedding_folder_url=text_embeddings_url,
         index_width=index_width,
         shuffle_shards=shuffle_shards,
         resample=resample_shards,
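A hedged usage sketch of the updated dataloader factory with both embedding sources precomputed (the import path, shard pattern, paths, and index_width are assumptions for illustration):

```python
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader  # import path assumed

dataloader = create_image_embedding_dataloader(
    tar_url="/data/webdataset/{00000..00099}.tar",  # hypothetical shard pattern
    num_workers=4,
    batch_size=32,
    img_embeddings_url="/data/img_embeddings/",      # precomputed image embeddings, or None to let clip generate them
    text_embeddings_url="/data/text_embeddings/",    # precomputed text embeddings, or None
    index_width=4,                                   # digits of the per-shard sample index
    extra_keys=["txt"],
    shuffle_shards=True,
    resample_shards=False,
)

for img, emb, txt in dataloader:
    # emb is now a dict like {"img": ..., "text": ...}; either key may be absent
    break
```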
@@ -13,7 +13,7 @@ from dalle2_pytorch.dalle2_pytorch import (
     Decoder,
     DiffusionPrior,
     DiffusionPriorNetwork,
-    XClipAdapter,
+    XClipAdapter
 )
 
 # helper functions
@@ -170,6 +170,8 @@ class DecoderConfig(BaseModel):
     unets: ListOrTuple(UnetConfig)
     image_size: int = None
     image_sizes: ListOrTuple(int) = None
+    condition_on_text_encodings: bool = False
+    clip: Optional[AdapterConfig]  # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
     loss_type: str = 'l2'
@@ -180,9 +182,16 @@ class DecoderConfig(BaseModel):
 
     def create(self):
         decoder_kwargs = self.dict()
 
         unet_configs = decoder_kwargs.pop('unets')
         unets = [Unet(**config) for config in unet_configs]
-        return Decoder(unets, **decoder_kwargs)
+
+        has_clip = exists(decoder_kwargs.pop('clip'))
+        clip = None
+        if has_clip:
+            clip = self.clip.create()
+
+        return Decoder(unets, clip=clip, **decoder_kwargs)
 
     @validator('image_sizes')
     def check_image_sizes(cls, image_sizes, values):
@@ -194,8 +203,9 @@ class DecoderConfig(BaseModel):
         extra = "allow"
 
 class DecoderDataConfig(BaseModel):
     webdataset_base_url: str               # path to a webdataset with jpg images
-    embeddings_url: str                    # path to .npy files with embeddings
+    img_embeddings_url: Optional[str]      # path to .npy files with embeddings
+    text_embeddings_url: Optional[str]     # path to .npy files with embeddings
     num_workers: int = 4
     batch_size: int = 64
     start_shard: int = 0
@@ -268,3 +278,26 @@ class TrainDecoderConfig(BaseModel):
         with open(json_path) as f:
             config = json.load(f)
         return cls(**config)
+
+    @root_validator
+    def check_has_embeddings(cls, values):
+        # Makes sure that enough information is provided to get the embeddings specified for training
+        data_config, decoder_config = values.get('data'), values.get('decoder')
+        if data_config is None or decoder_config is None:
+            # Then something else errored and we should just pass through
+            return values
+        using_text_embeddings = decoder_config.condition_on_text_encodings
+        using_clip = exists(decoder_config.clip)
+        img_emb_url = data_config.img_embeddings_url
+        text_emb_url = data_config.text_embeddings_url
+        if using_text_embeddings:
+            # Then we need some way to get the embeddings
+            assert using_clip or text_emb_url is not None, 'If condition_on_text_encodings is true, either clip or text_embeddings_url must be provided'
+        if using_clip:
+            if using_text_embeddings:
+                assert text_emb_url is None or img_emb_url is None, 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
+            else:
+                assert img_emb_url is None, 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
+        if text_emb_url:
+            assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
+        return values
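Taken together, the asserts only admit configurations in which each embedding has exactly one source. An informal summary, derived from the checks above rather than an exhaustive truth table:

```python
# clip set + condition_on_text_encodings=False -> img_embeddings_url must be unset (clip makes them)
# clip set + condition_on_text_encodings=True  -> at most one of img/text embeddings_url may be set
# condition_on_text_encodings=True, no clip    -> text_embeddings_url is required
# text_embeddings_url set                       -> condition_on_text_encodings must be True
```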
@@ -578,6 +578,18 @@ class DecoderTrainer(nn.Module):
 
         return output
 
+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_text(self, *args, **kwargs):
+        return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
+
+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_image(self, *args, **kwargs):
+        return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
+
     @cast_torch_tensor
     def forward(
         self,
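The two new helpers expose the wrapped decoder's clip so callers can produce embeddings on the fly. A small sketch, assuming `trainer` is a `DecoderTrainer` whose decoder was built with a clip model; the toy batch shape and the tuple-unpacking for `embed_text` mirror the `embed_image` usage in `generate_samples` further down and are assumptions:

```python
import torch
from clip import tokenize  # as imported by train_decoder.py in this PR

images = torch.randn(4, 3, 256, 256)                    # toy image batch
tokens = tokenize(["a photo of a dog"] * 4, truncate=True)

image_embed, *_ = trainer.embed_image(images)           # clip image embeddings, computed without grad
text_embed, *_ = trainer.embed_text(tokens)             # clip text embeddings, computed without grad
```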
train_decoder.py (158 changed lines)
@@ -6,6 +6,7 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
 from dalle2_pytorch.train_configs import TrainDecoderConfig
 from dalle2_pytorch.utils import Timer, print_ribbon
 from dalle2_pytorch.dalle2_pytorch import resize_image_to
+from clip import tokenize
 
 import torchvision
 import torch
@@ -33,7 +34,8 @@ def exists(val):
 def create_dataloaders(
     available_shards,
     webdataset_base_url,
-    embeddings_url,
+    img_embeddings_url=None,
+    text_embeddings_url=None,
     shard_width=6,
     num_workers=4,
     batch_size=32,
@@ -63,14 +65,15 @@ def create_dataloaders(
     test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
     val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
 
-    create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
+    create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
         tar_url=tar_urls,
         num_workers=num_workers,
         batch_size=batch_size if not for_sampling else n_sample_images,
-        embeddings_url=embeddings_url,
+        img_embeddings_url=img_embeddings_url,
+        text_embeddings_url=text_embeddings_url,
         index_width=index_width,
         shuffle_num = None,
-        extra_keys= ["txt"] if with_text else [],
+        extra_keys= ["txt"],
         shuffle_shards = shuffle,
         resample_shards = resample,
         img_preproc=img_preproc,
@@ -79,8 +82,8 @@ def create_dataloaders(
 
     train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
     train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
-    val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
-    test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
+    val_dataloader = create_dataloader(val_urls, shuffle=False)
+    test_dataloader = create_dataloader(test_urls, shuffle=False)
     test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
     return {
         "train": train_dataloader,
@@ -104,42 +107,65 @@ def get_example_data(dataloader, device, n=5):
     Samples the dataloader and returns a zipped list of examples
     """
     images = []
-    embeddings = []
+    img_embeddings = []
+    text_embeddings = []
     captions = []
-    dataset_keys = get_dataset_keys(dataloader)
-    has_caption = "txt" in dataset_keys
-    for data in dataloader:
-        if has_caption:
-            img, emb, txt = data
-        else:
-            img, emb = data
-            txt = [""] * emb.shape[0]
+    for img, emb, txt in dataloader:
+        img_emb, text_emb = emb.get('img'), emb.get('text')
+        if img_emb is not None:
+            img_emb = img_emb.to(device=device, dtype=torch.float)
+            img_embeddings.extend(list(img_emb))
+        else:
+            # Then we add None img.shape[0] times
+            img_embeddings.extend([None]*img.shape[0])
+        if text_emb is not None:
+            text_emb = text_emb.to(device=device, dtype=torch.float)
+            text_embeddings.extend(list(text_emb))
+        else:
+            # Then we add None img.shape[0] times
+            text_embeddings.extend([None]*img.shape[0])
         img = img.to(device=device, dtype=torch.float)
-        emb = emb.to(device=device, dtype=torch.float)
         images.extend(list(img))
-        embeddings.extend(list(emb))
         captions.extend(list(txt))
         if len(images) >= n:
             break
-    return list(zip(images[:n], embeddings[:n], captions[:n]))
+    return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
 
-def generate_samples(trainer, example_data, text_prepend=""):
+def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
     """
-    real_images, embeddings, txts = zip(*example_data)
-    embeddings_tensor = torch.stack(embeddings)
-    samples = trainer.sample(embeddings_tensor)
+    real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
+    sample_params = {}
+    if img_embeddings[0] is None:
+        # Generate image embeddings from clip
+        imgs_tensor = torch.stack(real_images)
+        img_embeddings, *_ = trainer.embed_image(imgs_tensor)
+        sample_params["image_embed"] = img_embeddings
+    else:
+        # Then we are using precomputed image embeddings
+        img_embeddings = torch.stack(img_embeddings)
+        sample_params["image_embed"] = img_embeddings
+    if condition_on_text_encodings:
+        if text_embeddings[0] is None:
+            # Generate text embeddings from text
+            tokenized_texts = tokenize(txts, truncate=True)
+            sample_params["text"] = tokenized_texts
+        else:
+            # Then we are using precomputed text embeddings
+            text_embeddings = torch.stack(text_embeddings)
+            sample_params["text_encodings"] = text_embeddings
+    samples = trainer.sample(**sample_params)
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
     return real_images, generated_images, captions
 
-def generate_grid_samples(trainer, examples, text_prepend=""):
+def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
-    real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
 
     real_image_size = real_images[0].shape[-1]
     generated_image_size = generated_images[0].shape[-1]
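Condensed, the two paths `generate_samples` now takes when calling `trainer.sample` (keyword names as used in the hunk above; the text kwargs are only passed when `condition_on_text_encodings` is set, and the tensors are assumed to already exist):

```python
# Precomputed-embedding path:
samples = trainer.sample(image_embed=torch.stack(img_embeddings),
                         text_encodings=torch.stack(text_embeddings))

# On-the-fly path: the decoder's clip encodes the images and captions itself:
samples = trainer.sample(image_embed=trainer.embed_image(torch.stack(real_images))[0],
                         text=tokenize(txts, truncate=True))
```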
@@ -151,7 +177,7 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
 
-def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
@@ -161,7 +187,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
     if len(examples) == 0:
         print("No data to evaluate. Check that your dataloader has shards.")
         return metrics
-    real_images, generated_images, captions = generate_samples(trainer, examples)
+    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
     # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -250,6 +276,7 @@ def train(
     save_latest=True,
     save_best=True,
     unet_training_mask=None,
+    condition_on_text_encodings=False,
     **kwargs
 ):
     """
@@ -307,14 +334,22 @@ def train(
         last_snapshot = sample
 
         if next_task == 'train':
-            for i, (img, emb) in enumerate(dataloaders["train"]):
+            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
                 # We want to count the total number of samples across all processes
                 sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
                 total_samples = all_samples.sum().item()
                 sample += total_samples
                 samples_seen += total_samples
-                img, emb = send_to_device((img, emb))
+                img_emb = emb.get('img')
+                has_img_embedding = img_emb is not None
+                if has_img_embedding:
+                    img_emb, = send_to_device((img_emb,))
+                text_emb = emb.get('text')
+                has_text_embedding = text_emb is not None
+                if has_text_embedding:
+                    text_emb, = send_to_device((text_emb,))
+                img, = send_to_device((img,))
 
                 trainer.train()
                 for unet in range(1, trainer.num_unets+1):
@@ -322,7 +357,20 @@ def train(
                     if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
                         continue
 
-                    loss = trainer.forward(img, image_embed=emb, unet_number=unet)
+                    forward_params = {}
+                    if has_img_embedding:
+                        forward_params['image_embed'] = img_emb
+                    else:
+                        # Forward pass automatically generates embedding
+                        pass
+                    if condition_on_text_encodings:
+                        if has_text_embedding:
+                            forward_params['text_encodings'] = text_emb
+                        else:
+                            # Then we need to pass the text instead
+                            tokenized_texts = tokenize(txt, truncate=True)
+                            forward_params['text'] = tokenized_texts
+                    loss = trainer.forward(img, **forward_params, unet_number=unet)
                     trainer.update(unet_number=unet)
                     unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
 
@@ -366,7 +414,7 @@ def train(
                     save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
                     if exists(n_sample_images) and n_sample_images > 0:
                         trainer.eval()
-                        train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
+                        train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
                         tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
 
                 if epoch_samples is not None and sample >= epoch_samples:
@@ -389,14 +437,35 @@ def train(
                 all_samples = accelerator.gather(val_sample_length_tensor)
                 total_samples = all_samples.sum().item()
                 val_sample += total_samples
-                img, emb = send_to_device((img, emb))
+                img_emb = emb.get('img')
+                has_img_embedding = img_emb is not None
+                if has_img_embedding:
+                    img_emb, = send_to_device((img_emb,))
+                text_emb = emb.get('text')
+                has_text_embedding = text_emb is not None
+                if has_text_embedding:
+                    text_emb, = send_to_device((text_emb,))
+                img, = send_to_device((img,))
 
                 for unet in range(1, len(decoder.unets)+1):
                     if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
                         # No need to evaluate an unchanging unet
                         continue
 
-                    loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
+                    forward_params = {}
+                    if has_img_embedding:
+                        forward_params['image_embed'] = img_emb.float()
+                    else:
+                        # Forward pass automatically generates embedding
+                        pass
+                    if condition_on_text_encodings:
+                        if has_text_embedding:
+                            forward_params['text_encodings'] = text_emb.float()
+                        else:
+                            # Then we need to pass the text instead
+                            tokenized_texts = tokenize(txt, truncate=True)
+                            forward_params['text'] = tokenized_texts
+                    loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
                     average_val_loss_tensor[0, unet-1] += loss
 
                 if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
@@ -423,7 +492,7 @@ def train(
         if next_task == 'eval':
             if exists(evaluate_config):
                 accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
+                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
                 if is_master:
                     tracker.log(evaluation, step=step(), verbose=True)
             next_task = 'sample'
@@ -434,8 +503,8 @@ def train(
             # Generate examples and save the model if we are the master
             # Generate sample images
             print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-            test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
-            train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
+            test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
+            train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
             tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
             tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
 
@@ -525,14 +594,35 @@ def initialize_training(config, config_path):
     # Create and initialize the tracker if we are the master
     tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
 
+    has_img_embeddings = config.data.img_embeddings_url is not None
+    has_text_embeddings = config.data.text_embeddings_url is not None
+    conditioning_on_text = config.decoder.condition_on_text_encodings
+    has_clip_model = config.decoder.clip is not None
+    data_source_string = ""
+    if has_img_embeddings:
+        data_source_string += "precomputed image embeddings"
+    elif has_clip_model:
+        data_source_string += "clip image embeddings generation"
+    else:
+        raise ValueError("No image embeddings source specified")
+    if conditioning_on_text:
+        if has_text_embeddings:
+            data_source_string += " and precomputed text embeddings"
+        elif has_clip_model:
+            data_source_string += " and clip text encoding generation"
+        else:
+            raise ValueError("No text embeddings source specified")
+
     accelerator.print(print_ribbon("Loaded Config", repeat=40))
     accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
+    accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
     accelerator.print(f"Number of parameters: {num_parameters}")
     train(dataloaders, decoder, accelerator,
         tracker=tracker,
         inference_device=accelerator.device,
         load_config=config.load,
         evaluate_config=config.evaluate,
+        condition_on_text_encodings=config.decoder.condition_on_text_encodings,
         **config.train.dict(),
     )
 