import os
import webdataset as wds
import torch
import numpy as np
import fsspec
import shutil

def get_shard(filename):
    """
    Filenames with shards in them have a consistent structure that we can take advantage of
    Standard structure: path/to/file/prefix_string_00001.ext
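
    A doctest-style sketch of the expected behavior, using the standard structure above:
        >>> get_shard("path/to/file/prefix_string_00001.ext")
        '00001'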
    """
    try:
        return filename.split("_")[-1].split(".")[0]
    except ValueError:
        raise RuntimeError(f"Could not find shard for filename {filename}")

def get_example_file(fs, path, file_format):
    """
    Given a file system, a path, and a file format, return an example file
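
    A hypothetical usage sketch (the directory and returned file name are assumptions):
        fs, path = fsspec.core.url_to_fs("/data/embeddings")
        get_example_file(fs, path, "npy")  # -> "/data/embeddings/img_emb_0000.npy"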
    """
    return fs.glob(os.path.join(path, f"*.{file_format}"))[0]

def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
    """Given a datum of {"__key__": str, "__url__": str, ...}, adds the corresponding embedding and yields it"""
    previous_tar_url = None
    current_embeddings = None
    # Get a reference to an abstract file system where the embeddings are stored
    embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
    example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy")
    example_embedding_shard = get_shard(example_embedding_file)
    emb_shard_width = len(example_embedding_shard)
    # Easier to get the basename without the shard once than to search for the correct file every time
    embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_"
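    # e.g. a hypothetical "/data/embeddings/img_emb_0001.npy" yields the basename "/data/embeddings/img_emb_"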

    def load_corresponding_embeds(tar_url):
        """Finds and reads the npy file that contains embeddings for the given webdataset tar"""
        shard = int(tar_url.split("/")[-1].split(".")[0])
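        # e.g. a tar_url ending in ".../0001.tar" gives shard 1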
        embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'
        with embeddings_fs.open(embedding_url) as f:
            data = np.load(f)
        return torch.from_numpy(data)

    for sample in samples:
        try:
            tar_url = sample["__url__"]
            key = sample["__key__"]
            if tar_url != previous_tar_url:
                # If the tar changed, we need to download new embeddings
                # This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.
                previous_tar_url = tar_url
                current_embeddings = load_corresponding_embeds(tar_url)

            embedding_index = int(key[-index_width:])
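            # e.g. with index_width=3, key "0003039" gives embedding_index 39 (index "039" within shard "0003")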
            embedding = current_embeddings[embedding_index]
            # If the embedding is all zeros, no image with this index existed, so it is not a valid embedding and we raise
            if torch.count_nonzero(embedding) == 0:
                raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
            sample["npy"] = embedding
            yield sample
        except Exception as exn:  # From wds implementation
            if handler(exn):
                continue
            else:
                break

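# wds.filters.pipelinefilter curries embedding_inserter: calling insert_embedding(embeddings_url=..., index_width=...) returns a pipeline stage that consumes and yields samples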
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)

def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
    """Finds if there is a corresponding embedding for the tarfile at { url: [URL] }"""
    embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
    embedding_files = embeddings_fs.ls(embeddings_path)
    get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
    embedding_shards = {get_embedding_shard(filename) for filename in embedding_files}  # Sets have O(1) membership checks

    get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
    for tarfile in tarfiles:
        try:
            webdataset_shard = get_tar_shard(tarfile["url"])
            # If this shard has an associated embeddings file, we pass it through. Otherwise we skip it and move on to the next shard
            if webdataset_shard in embedding_shards:
                yield tarfile
        except Exception as exn:  # From wds implementation
            if handler(exn):
                continue
            else:
                break

skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)

def verify_keys(samples, handler=wds.handlers.reraise_exception):
    """
    Requires that both the image and embedding are present in the sample
    This is important since a user may forget that their webdataset has no embeddings and neglect to add them using the embedding_folder_url parameter.
    """
    for sample in samples:
        try:
            assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
            assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
            yield sample
        except Exception as exn:  # From wds implementation
            if handler(exn):
                continue
            else:
                break

class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
    """
    A fluid interface wrapper for DataPipeline that returns image-embedding pairs
    Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted from the alternate source.
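
    A minimal usage sketch (paths and index_width are hypothetical):
        dataset = ImageEmbeddingDataset(
            "/data/webdataset/{0000..0999}.tar",
            embedding_folder_url="/data/embeddings",
            index_width=4,
        )
        for img, emb in dataset:
            ...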
    """

    def __init__(
        self,
        urls,
        embedding_folder_url=None,
        index_width=None,
        img_preproc=None,
        extra_keys=[],
        handler=wds.handlers.reraise_exception,
        resample=False,
        shuffle_shards=True
    ):
        """
        Modeled directly off of the WebDataset constructor

        :param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
        :param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
            Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
        :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
            For example, if a file in webdataset shard 3 is named 0003039.jpg, the shard takes the first 4 digits and the index_width is 3.
        :param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to a torch tensor.
        :param extra_keys: Additional webdataset keys to decode and return alongside the image and embedding.
        :param handler: A webdataset handler.
        :param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
        :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
        """
        super().__init__()
        keys = ["jpg", "npy"] + extra_keys
        self.key_map = {key: i for i, key in enumerate(keys)}
        self.resampling = resample
        self.img_preproc = img_preproc
        # If loading from s3, check that the external tools and packages the pipeline depends on are installed
        if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
            # Then this has an s3 link for the webdataset and we need extra packages
            if shutil.which("s3cmd") is None:
                raise RuntimeError("s3cmd is required for s3 webdataset")
        if embedding_folder_url is not None and "s3:" in embedding_folder_url:
            # Then the embeddings are being loaded from s3 and fsspec requires s3fs
            try:
                import s3fs
            except ImportError:
                raise RuntimeError("s3fs is required to load embeddings from s3")

        # Add the shard list and randomize or resample if requested
        if resample:
            assert not shuffle_shards, "Cannot both resample and shuffle"
            self.append(wds.ResampledShards(urls))
        else:
            self.append(wds.SimpleShardList(urls))
            if shuffle_shards:
                self.append(wds.filters.shuffle(1000))

        if embedding_folder_url is not None:
            # There may be webdataset shards that do not have an embedding shard associated with them. If we do not skip these, they would cause issues.
            self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
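
        # Split shards first across distributed nodes, then across dataloader workers, so each worker reads a disjoint subset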
        self.append(wds.split_by_node)
        self.append(wds.split_by_worker)

        self.append(wds.tarfile_to_samples(handler=handler))
        self.append(wds.decode("pilrgb", handler=handler))
        if embedding_folder_url is not None:
            # Then we are loading embeddings from a remote source
            assert index_width is not None, "Reading embeddings separately requires index_width to be given"
            self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
        self.append(verify_keys)
        # Apply preprocessing
        self.append(wds.map(self.preproc))
        self.append(wds.to_tuple(*keys))
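        # After to_tuple, each sample is a tuple ordered as keys = ["jpg", "npy"] + extra_keys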

    def preproc(self, sample):
        """Applies the preprocessing for images"""
        if self.img_preproc is not None:
            sample["jpg"] = self.img_preproc(sample["jpg"])
        return sample

def create_image_embedding_dataloader(
    tar_url,
    num_workers,
    batch_size,
    embeddings_url=None,
    index_width=None,
    shuffle_num=None,
    shuffle_shards=True,
    resample_shards=False,
    img_preproc=None,
    extra_keys=[],
    handler=wds.handlers.reraise_exception  # or wds.handlers.warn_and_continue
):
    """
    Convenience function to create an image embedding dataset and dataloader in one line

    :param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
    :param num_workers: The number of workers to use for the dataloader
    :param batch_size: The batch size to use for the dataloader
    :param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
        Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
    :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
        For example, if a file in webdataset shard 3 is named 0003039.jpg, the shard takes the first 4 digits and the index_width is 3.
    :param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
    :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
    :param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
    :param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to a torch tensor.
    :param extra_keys: Additional webdataset keys to decode and return alongside the image and embedding.
    :param handler: A webdataset handler.
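
    A minimal usage sketch (paths and sizes are hypothetical):
        dataloader = create_image_embedding_dataloader(
            tar_url="/data/webdataset/{0000..0999}.tar",
            num_workers=4,
            batch_size=32,
            embeddings_url="/data/embeddings",
            index_width=4,
        )
        for imgs, embs in dataloader:
            ...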
    """
    ds = ImageEmbeddingDataset(
        tar_url,
        embeddings_url,
        index_width=index_width,
        shuffle_shards=shuffle_shards,
        resample=resample_shards,
        extra_keys=extra_keys,
        img_preproc=img_preproc,
        handler=handler
    )
    if shuffle_num is not None and shuffle_num > 0:
        ds.shuffle(shuffle_num)
    return wds.WebLoader(
        ds,
        num_workers=num_workers,
        batch_size=batch_size,
        prefetch_factor=2,  # This might be good to have high so the next npy file is prefetched
        pin_memory=True,
        shuffle=False
    )