# dalle2_pytorch/dataloaders/prior_loader.py
# from a mirror of https://github.com/lucidrains/DALLE2-pytorch.git

from math import ceil

from clip import tokenize
from embedding_reader import EmbeddingReader
from torch import from_numpy
from torch.utils.data import IterableDataset, DataLoader


class PriorEmbeddingDataset(IterableDataset):
    """
    PriorEmbeddingDataset is a wrapper around EmbeddingReader.

    It simplifies the logic needed to yield samples from the
    different EmbeddingReader configurations available.
    """

    def __init__(
        self,
        text_conditioned: bool,
        batch_size: int,
        start: int,
        stop: int,
        image_reader,
        text_reader: EmbeddingReader = None,
    ) -> None:
        super().__init__()

        self.text_conditioned = text_conditioned

        if not self.text_conditioned:
            self.text_reader = text_reader

        self.image_reader = image_reader
        self.start = start
        self.stop = stop
        self.batch_size = batch_size

    def __len__(self):
        return self.stop - self.start

    def __iter__(self):
        # D.R.Y. loader args shared by both reader configurations
        loader_args = dict(
            batch_size=self.batch_size,
            start=self.start,
            end=self.stop,
            show_progress=False,
        )

        # if the data requested is text-conditioned, only load image
        # embeddings; captions arrive as reader metadata
        if self.text_conditioned:
            self.loader = self.image_reader(**loader_args)
        # otherwise, include text embeddings and bypass metadata
        else:
            self.loader = zip(
                self.image_reader(**loader_args), self.text_reader(**loader_args)
            )

        # the dataset acts as its own iterator
        return self

    def __next__(self):
        # get_sample() propagates the underlying reader's StopIteration,
        # which ends iteration naturally
        return self.get_sample()

    def __str__(self):
        return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"

    def get_sample(self):
        """
        Pre-process data from either reader into a common format.
        """
        if self.text_conditioned:
            image_embedding, caption = next(self.loader)

            image_embedding = from_numpy(image_embedding)
            tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)

            return image_embedding, tokenized_caption

        else:
            (image_embedding, _), (text_embedding, _) = next(self.loader)

            image_embedding = from_numpy(image_embedding)
            text_embedding = from_numpy(text_embedding)

            return image_embedding, text_embedding


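# Illustrative usage sketch (added commentary, not part of the upstream file):
# iterating a PriorEmbeddingDataset directly. `image_reader` here is assumed
# to be any EmbeddingReader over image embeddings with caption metadata; in
# text-conditioned mode each step yields (image_embedding, tokenized_caption).
def _example_iterate_prior_dataset(image_reader):
    dataset = PriorEmbeddingDataset(
        text_conditioned=True,
        batch_size=64,
        start=0,
        stop=1024,
        image_reader=image_reader,
    )
    # __iter__() builds the underlying loader; __next__() yields batches
    # until the reader is exhausted
    for image_embedding, tokenized_caption in dataset:
        print(image_embedding.shape, tokenized_caption.shape)
        break

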
# helper functions


def distribute_to_rank(start, stop, rank, world_size):
    """
    Distribute data to each rank given the world size.

    Returns:
        - New start and stop points for this rank.
    """
    num_samples = int(stop - start)

    per_rank = int(ceil(num_samples / world_size))

    assert (
        per_rank > 0
    ), f"Number of samples per rank must be larger than 0, (found: {per_rank})"

    rank_start = start + rank * per_rank
    rank_stop = min(rank_start + per_rank, stop)

    new_length = rank_stop - rank_start

    assert (
        new_length > 0
    ), "Calculated start and stop points result in a length of zero for this rank."

    return rank_start, rank_stop


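# Worked example (added for illustration): splitting 100 samples across a
# world size of 4 gives each rank ceil(100 / 4) = 25 samples, so rank 1 is
# assigned the half-open range [25, 50).
def _example_distribute_to_rank():
    rank_start, rank_stop = distribute_to_rank(start=0, stop=100, rank=1, world_size=4)
    assert (rank_start, rank_stop) == (25, 50)

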
def get_reader(
    text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
):
    """
    Create EmbeddingReader objects from the specified URLs.

    get_reader() always expects a url to image embeddings.

    If text-conditioned, it also expects a meta_url for the captions.
    Otherwise, it needs a txt_url for the matching text embeddings.

    Returns an image_reader object if text-conditioned;
    otherwise returns both an image_reader and a text_reader.
    """

    assert img_url is not None, "Must supply an image url"

    if text_conditioned:
        assert meta_url is not None, "Must supply meta url if text-conditioned"

        image_reader = EmbeddingReader(
            embeddings_folder=img_url,
            file_format="parquet_npy",
            # assumes the caption column exists and is the only one requested
            meta_columns=["caption"],
            metadata_folder=meta_url,
        )

        return image_reader

    # otherwise we will require text embeddings as well and return two readers
    assert (
        txt_url is not None
    ), "Must supply text embedding url if not text-conditioning"

    image_reader = EmbeddingReader(img_url, file_format="npy")
    text_reader = EmbeddingReader(txt_url, file_format="npy")

    return image_reader, text_reader


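# Hedged sketch (added, not in the upstream file): the two get_reader() modes.
# The s3 URLs below are placeholders; point them at real embedding folders.
def _example_get_reader():
    # text-conditioned: one reader joining image embeddings with captions
    conditioned_reader = get_reader(
        text_conditioned=True,
        img_url="s3://bucket/img_emb/",
        meta_url="s3://bucket/metadata/",
    )
    # embedding-only: separate readers for image and text embeddings
    image_reader, text_reader = get_reader(
        text_conditioned=False,
        img_url="s3://bucket/img_emb/",
        txt_url="s3://bucket/text_emb/",
    )
    return conditioned_reader, image_reader, text_reader

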
def make_splits(
    text_conditioned: bool,
    batch_size: int,
    num_data_points: int,
    train_split: float,
    eval_split: float,
    image_reader: EmbeddingReader,
    text_reader: EmbeddingReader = None,
    start=0,
    rank=0,
    world_size=1,
):
    """
    Split an embedding reader object as needed.

    NOTE: make_splits() will infer the test set size from your train and eval splits.

    Input:
        - text_conditioned: whether to prepare text-conditioned training data
        - batch_size: the batch size for a single gpu
        - num_data_points: the total number of data points you wish to train on
        - train_split: the fraction of data you wish to train on
        - eval_split: the fraction of data you wish to validate on
        - image_reader: the image_reader you wish to split
        - text_reader: the text_reader you wish to split (if not text_conditioned)
        - start: the starting point within your dataset
        - rank: the rank of your worker
        - world_size: the total world size of your distributed training run

    Returns:
        - PyTorch DataLoaders that yield tuples of (img, txt) data.
    """

    assert start < image_reader.count, "start position cannot exceed reader count."

    # verify that num_data_points does not exceed what the reader can
    # provide past the starting offset
    if num_data_points > (image_reader.count - start):
        print(
            "Specified count is larger than what's available...defaulting to reader's count."
        )
        num_data_points = image_reader.count - start

    # compute split points
    train_set_size = int(train_split * num_data_points)
    eval_set_size = int(eval_split * num_data_points)
    eval_start = train_set_size
    eval_stop = int(eval_start + eval_set_size)

    assert (
        train_split + eval_split
    ) < 1.0, "Specified train and eval split is too large to infer a test split."

    # distribute each split across ranks
    rank_train_start, rank_train_stop = distribute_to_rank(
        start, train_set_size, rank, world_size
    )
    rank_eval_start, rank_eval_stop = distribute_to_rank(
        train_set_size, eval_stop, rank, world_size
    )
    rank_test_start, rank_test_stop = distribute_to_rank(
        eval_stop, num_data_points, rank, world_size
    )

    # wrap up splits into per-split argument dicts
    train_split_args = dict(
        start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
    )
    eval_split_args = dict(
        start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
    )
    test_split_args = dict(
        start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
    )

    # add the shared reader args to a unified dict
    reader_args = dict(
        text_conditioned=text_conditioned,
        image_reader=image_reader,
    )

    # the embedding-only configuration also needs the matching text reader
    if not text_conditioned:
        reader_args["text_reader"] = text_reader

    train_split_args = dict(**reader_args, **train_split_args)
    eval_split_args = dict(**reader_args, **eval_split_args)
    test_split_args = dict(**reader_args, **test_split_args)

    train = PriorEmbeddingDataset(**train_split_args)
    val = PriorEmbeddingDataset(**eval_split_args)
    test = PriorEmbeddingDataset(**test_split_args)

    # true batch size is specified in the PriorEmbeddingDataset, so the
    # DataLoaders pass batches through unchanged
    train_loader = DataLoader(train, batch_size=None)
    eval_loader = DataLoader(val, batch_size=None)
    test_loader = DataLoader(test, batch_size=None)

    return train_loader, eval_loader, test_loader
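

# End-to-end sketch (added for illustration; the URLs and split sizes are made
# up): build a text-conditioned reader, then split it into per-rank
# train/eval/test DataLoaders for distributed training.
def _example_make_splits(rank=0, world_size=1):
    image_reader = get_reader(
        text_conditioned=True,
        img_url="s3://bucket/img_emb/",
        meta_url="s3://bucket/metadata/",
    )
    train_loader, eval_loader, test_loader = make_splits(
        text_conditioned=True,
        batch_size=64,
        num_data_points=1_000_000,
        train_split=0.8,
        eval_split=0.1,  # the remaining 0.1 is inferred as the test split
        image_reader=image_reader,
        rank=rank,
        world_size=world_size,
    )
    return train_loader, eval_loader, test_loader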