Update decoder dataloader (#105)

* Updated the decoder dataloader
Removed unnecessary logging for required packages
Transferred to using index width instead of shard width
Added the ability to select extra keys to return from the webdataset

* Added README for decoder loader
This commit is contained in:
Aidan Dempster
2022-05-20 19:38:55 -04:00
committed by GitHub
parent db0642c4cd
commit c85e0d5c35
2 changed files with 120 additions and 15 deletions

View File

@@ -3,6 +3,7 @@ import webdataset as wds
import torch
import numpy as np
import fsspec
import shutil
def get_shard(filename):
"""
@@ -20,7 +21,7 @@ def get_example_file(fs, path, file_format):
"""
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handlers.reraise_exception):
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
previous_tar_url = None
current_embeddings = None
@@ -50,8 +51,12 @@ def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handler
previous_tar_url = tar_url
current_embeddings = load_corresponding_embeds(tar_url)
embedding_index = int(key[shard_width:])
sample["npy"] = current_embeddings[embedding_index]
embedding_index = int(key[-index_width:])
embedding = current_embeddings[embedding_index]
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
if torch.count_nonzero(embedding) == 0:
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
sample["npy"] = embedding
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
@@ -60,6 +65,28 @@ def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handler
break
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
"""Finds if the is a corresponding embedding for the tarfile at { url: [URL] }"""
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
embedding_files = embeddings_fs.ls(embeddings_path)
get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member
get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
for tarfile in tarfiles:
try:
webdataset_shard = get_tar_shard(tarfile["url"])
# If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one
if webdataset_shard in embedding_shards:
yield tarfile
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
def verify_keys(samples, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
@@ -86,7 +113,9 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
self,
urls,
embedding_folder_url=None,
shard_width=None,
index_width=None,
img_preproc=None,
extra_keys=[],
handler=wds.handlers.reraise_exception,
resample=False,
shuffle_shards=True
@@ -97,13 +126,31 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
:param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
:param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard with this 4 and the last three digits are the index.
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
:param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
:param handler: A webdataset handler.
:param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
:param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.
"""
super().__init__()
keys = ["jpg", "npy"] + extra_keys
self.key_map = {key: i for i, key in enumerate(keys)}
self.resampling = resample
self.img_preproc = img_preproc
# If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up
if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
# Then this has an s3 link for the webdataset and we need extra packages
if shutil.which("s3cmd") is None:
raise RuntimeError("s3cmd is required for s3 webdataset")
if "s3:" in embedding_folder_url:
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
try:
import s3fs
except ImportError:
raise RuntimeError("s3fs is required to load embeddings from s3")
# Add the shardList and randomize or resample if requested
if resample:
assert not shuffle_shards, "Cannot both resample and shuffle"
@@ -112,28 +159,43 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
self.append(wds.SimpleShardList(urls))
if shuffle_shards:
self.append(wds.filters.shuffle(1000))
if embedding_folder_url is not None:
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
self.append(wds.split_by_node)
self.append(wds.split_by_worker)
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("torchrgb"))
self.append(wds.decode("pilrgb", handler=handler))
if embedding_folder_url is not None:
assert shard_width is not None, "Reading embeddings separately requires shard length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, shard_width=shard_width, handler=handler))
# Then we are loading embeddings for a remote source
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
self.append(verify_keys)
self.append(wds.to_tuple("jpg", "npy"))
# Apply preprocessing
self.append(wds.map(self.preproc))
self.append(wds.to_tuple(*keys))
def preproc(self, sample):
"""Applies the preprocessing for images"""
if self.img_preproc is not None:
sample["jpg"] = self.img_preproc(sample["jpg"])
return sample
def create_image_embedding_dataloader(
tar_url,
num_workers,
batch_size,
embeddings_url=None,
shard_width=None,
index_width=None,
shuffle_num = None,
shuffle_shards = True,
resample_shards = False,
handler=wds.handlers.warn_and_continue
img_preproc=None,
extra_keys=[],
handler=wds.handlers.reraise_exception#warn_and_continue
):
"""
Convenience function to create an image embedding dataseta and dataloader in one line
@@ -143,8 +205,8 @@ def create_image_embedding_dataloader(
:param batch_size: The batch size to use for the dataloader
:param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index.
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
:param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
:param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
:param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
@@ -153,9 +215,11 @@ def create_image_embedding_dataloader(
ds = ImageEmbeddingDataset(
tar_url,
embeddings_url,
shard_width=shard_width,
index_width=index_width,
shuffle_shards=shuffle_shards,
resample=resample_shards,
extra_keys=extra_keys,
img_preproc=img_preproc,
handler=handler
)
if shuffle_num is not None and shuffle_num > 0: