"""Lightning data module that builds dataset pipelines from OmegaConf configs.

Each split config must provide a ``datapipeline`` section (forwarded to
``create_dataset``) and a ``loader`` section (forwarded to ``create_loader``).
"""

import logging
from typing import Optional

import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule

logger = logging.getLogger(__name__)

try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError as e:
    raise NotImplementedError(
        "Datasets not yet available. "
        "To enable them, add stable-datasets as a submodule: "
        "run ``git submodule update --init --recursive`` "
        "and then ``pip install -e stable-datasets/`` from the root of this repo."
    ) from e


class StableDataModuleFromConfig(LightningDataModule):
    """Constructs train/validation/test dataloaders from ``DictConfig`` sections."""

    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
        dummy: bool = False,
    ):
        super().__init__()
        self.train_config = train
        assert (
            "datapipeline" in self.train_config and "loader" in self.train_config
        ), "train config requires the fields `datapipeline` and `loader`"

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                    "datapipeline" in self.val_config and "loader" in self.val_config
                ), "validation config requires the fields `datapipeline` and `loader`"
            else:
                # Fall back to the training pipeline when no validation split is given.
                logger.warning(
                    "No validation datapipeline defined; reusing the training one."
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                "datapipeline" in self.test_config and "loader" in self.test_config
            ), "test config requires the fields `datapipeline` and `loader`"

        self.dummy = dummy
        if self.dummy:
            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")

    def setup(self, stage: str) -> None:
        logger.debug("Preparing datasets")
        # The dummy dataset swaps in synthetic data for debugging runs.
        data_fn = create_dummy_dataset if self.dummy else create_dataset

        self.train_datapipeline = data_fn(**self.train_config.datapipeline)
        if self.val_config:
            self.val_datapipeline = data_fn(**self.val_config.datapipeline)
        if self.test_config:
            self.test_datapipeline = data_fn(**self.test_config.datapipeline)

    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
        return create_loader(self.train_datapipeline, **self.train_config.loader)

    def val_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.val_datapipeline, **self.val_config.loader)

    def test_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.test_datapipeline, **self.test_config.loader)
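

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): shows how the data module is wired up from
# an OmegaConf config. The `datapipeline` and `loader` fields below (`urls`,
# `batch_size`, `num_workers`) are hypothetical placeholders -- the real
# schema is defined by `create_dataset` and `create_loader` in the
# stable-datasets submodule, so consult that package for the accepted keys.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {
            "train": {
                # Hypothetical fields; replace with the schema sdata expects.
                "datapipeline": {"urls": ["data/train-{00000..00099}.tar"]},
                "loader": {"batch_size": 4, "num_workers": 2},
            }
        }
    )
    datamodule = StableDataModuleFromConfig(train=config.train)
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()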