mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
60 lines
1.3 KiB
Python
60 lines
1.3 KiB
Python
from pathlib import Path
|
|
|
|
import torch
|
|
from torch.utils import data
|
|
from torchvision import transforms, utils
|
|
|
|
from PIL import Image
|
|
|
|
# helper functions
|
|
|
|
def cycle(dl):
    """Yield items from ``dl`` forever, restarting iteration after each pass.

    Every pass re-iterates ``dl`` (``for ... in dl`` calls ``iter(dl)`` again),
    so a DataLoader built with ``shuffle=True`` draws a fresh order each epoch.

    Args:
        dl: a re-iterable source of items, e.g. a DataLoader or a list.

    Yields:
        The items of ``dl``, repeated indefinitely.

    Raises:
        ValueError: if a complete pass over ``dl`` produces no items. Without
            this check an empty source (or an already-exhausted one-shot
            iterator) would make ``next()`` loop forever without yielding.
    """
    while True:
        produced_any = False
        for batch in dl:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError(
                'cycle(): a full pass over the iterable produced no items '
                '(empty dataset, or a one-shot iterator that is already exhausted)'
            )
|
|
|
|
# dataset and dataloader
|
|
|
|
class Dataset(data.Dataset):
    """Image-folder dataset that returns transformed image tensors.

    Recursively collects every file under ``folder`` whose name ends in
    ``.<ext>`` for one of ``exts``. Indexing loads the file with PIL and applies
    resize -> random horizontal flip -> center crop -> ``ToTensor``.

    Args:
        folder: root directory to search (anything whose string form is a path).
        image_size: passed to both ``transforms.Resize`` and
            ``transforms.CenterCrop``.
        exts: file extensions to include, without the leading dot. Matching is
            done with ``Path.glob``, which is case-sensitive on case-sensitive
            filesystems, so e.g. ``'JPG'`` must be listed separately there.
        convert_image_to: optional PIL mode (e.g. ``'RGB'``) that every image is
            converted to before the transform. Useful when a folder mixes RGB,
            RGBA, grayscale or palette images, whose tensors would otherwise
            have different channel counts and fail to batch together.
            ``None`` (the default) keeps each image's native mode.
    """

    def __init__(
        self,
        folder,
        image_size,
        exts = ('jpg', 'jpeg', 'png'),
        convert_image_to = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.convert_image_to = convert_image_to

        root = Path(f'{folder}')
        # sorted so that a given index refers to the same file across runs;
        # the order Path.glob yields entries in is filesystem-dependent
        self.paths = sorted(p for ext in exts for p in root.glob(f'**/*.{ext}'))

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(image_size),
            transforms.ToTensor()
        ])

    def __len__(self):
        """Return the number of image files found under ``folder``."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the image at ``index`` and return it after ``self.transform``."""
        path = self.paths[index]
        # Image.open reads lazily and holds the file handle open; the with block
        # closes it once the transform has produced a tensor, instead of leaving
        # handles open until garbage collection
        with Image.open(path) as img:
            if self.convert_image_to is not None:
                img = img.convert(self.convert_image_to)
            return self.transform(img)
|
|
|
|
def get_images_dataloader(
    folder,
    *,
    batch_size,
    image_size,
    shuffle = True,
    cycle_dl = True,
    pin_memory = True,
    num_workers = 0
):
    """Build a DataLoader over the images found in ``folder``.

    Args:
        folder: root directory scanned recursively by ``Dataset``.
        batch_size: number of images per batch.
        image_size: target size passed to ``Dataset``.
        shuffle: whether the DataLoader shuffles the dataset.
        cycle_dl: if true, wrap the loader with ``cycle`` so that it yields
            batches indefinitely.
        pin_memory: forwarded to ``DataLoader``.
        num_workers: number of DataLoader worker processes used to decode and
            transform images; ``0`` (the default, same as DataLoader's) loads
            in the main process.

    Returns:
        The endless ``cycle(dl)`` generator when ``cycle_dl`` is true,
        otherwise the ``DataLoader`` itself.
    """
    ds = Dataset(folder, image_size)
    dl = data.DataLoader(
        ds,
        batch_size = batch_size,
        shuffle = shuffle,
        pin_memory = pin_memory,
        num_workers = num_workers
    )

    return cycle(dl) if cycle_dl else dl
|