from pathlib import Path import torch from torch.utils import data from torchvision import transforms, utils from PIL import Image # helpers functions def cycle(dl): while True: for data in dl: yield data # dataset and dataloader class Dataset(data.Dataset): def __init__( self, folder, image_size, exts = ['jpg', 'jpeg', 'png'] ): super().__init__() self.folder = folder self.image_size = image_size self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] self.transform = transforms.Compose([ transforms.Resize(image_size), transforms.RandomHorizontalFlip(), transforms.CenterCrop(image_size), transforms.ToTensor() ]) def __len__(self): return len(self.paths) def __getitem__(self, index): path = self.paths[index] img = Image.open(path) return self.transform(img) def get_images_dataloader( folder, *, batch_size, image_size, shuffle = True, cycle_dl = True, pin_memory = True ): ds = Dataset(folder, image_size) dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory) if cycle_dl: dl = cycle(dl) return dl