generative-models/sgm/data/dataset.py

import logging
from typing import Optional

import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule

logger = logging.getLogger(__name__)

try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError as e:
    raise NotImplementedError(
        "Datasets not yet available. "
        "To enable them, add stable-datasets as a submodule: "
        "run ``git submodule update --init --recursive`` "
        "and then ``pip install -e stable-datasets/`` from the root of this repo."
    ) from e


class StableDataModuleFromConfig(LightningDataModule):
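    """Lightning ``DataModule`` that builds train/validation/test dataloaders
    from ``sdata`` (stable-datasets) configs.

    Each stage config must provide a ``datapipeline`` section (passed to
    ``create_dataset`` or ``create_dummy_dataset``) and a ``loader`` section
    (passed to ``create_loader``).
    """
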
    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
        dummy: bool = False,
    ):
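        """
        Args:
            train: config with the fields ``datapipeline`` and ``loader``.
            validation: optional config in the same format; if omitted, the
                training config is reused for validation.
            test: optional config in the same format.
            skip_val_loader: if True, do not configure a validation loader.
            dummy: if True, swap in dummy datasets (debugging only).
        """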
        super().__init__()
        self.train_config = train
        assert (
            "datapipeline" in self.train_config and "loader" in self.train_config
        ), "train config requires the fields `datapipeline` and `loader`"

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                    "datapipeline" in self.val_config and "loader" in self.val_config
                ), "validation config requires the fields `datapipeline` and `loader`"
            else:
                logger.warning(
                    "No validation datapipeline defined, reusing the one from training"
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                "datapipeline" in self.test_config and "loader" in self.test_config
            ), "test config requires the fields `datapipeline` and `loader`"

        self.dummy = dummy
        if self.dummy:
            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")

    def setup(self, stage: str) -> None:
        logger.debug("Preparing datasets")
        # Swap in the dummy dataset factory when debugging.
        data_fn = create_dummy_dataset if self.dummy else create_dataset

        self.train_datapipeline = data_fn(**self.train_config.datapipeline)
        if self.val_config:
            self.val_datapipeline = data_fn(**self.val_config.datapipeline)
        if self.test_config:
            self.test_datapipeline = data_fn(**self.test_config.datapipeline)

    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
        return create_loader(self.train_datapipeline, **self.train_config.loader)

    def val_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.val_datapipeline, **self.val_config.loader)

    def test_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.test_datapipeline, **self.test_config.loader)
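

# Usage sketch (illustrative, not part of the module): wiring the DataModule
# from an OmegaConf config. The exact keys accepted under ``datapipeline`` and
# ``loader`` are defined by stable-datasets' ``create_dataset``/``create_loader``;
# the placeholder values below are assumptions, not the actual schema.
#
#   from omegaconf import OmegaConf
#
#   config = OmegaConf.create(
#       {
#           "train": {
#               "datapipeline": {...},  # sdata dataset arguments
#               "loader": {...},        # sdata loader arguments (batch size etc.)
#           }
#       }
#   )
#   datamodule = StableDataModuleFromConfig(**config)
#   datamodule.setup("fit")
#   train_loader = datamodule.train_dataloader()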