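"""Configurable data module for Stability-AI's generative-models repo.

Wraps the dataset/loader factories from the ``stable-datasets`` submodule
(imported as ``sdata``) in a ``pytorch_lightning.LightningDataModule``.
"""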
from typing import Optional

import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule

try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError:
    print("#" * 100)
    print("Datasets not yet available.")
    print("To enable them, add stable-datasets as a submodule:")
    print("run ``git submodule update --init --recursive``")
    print("and then ``pip install -e stable-datasets/`` from the root of this repo.")
    print("#" * 100)
    exit(1)


class StableDataModuleFromConfig(LightningDataModule):
    """Builds train/val/test datapipelines and loaders from OmegaConf configs.

    Each config must provide a ``datapipeline`` field (kwargs for the sdata
    dataset factory) and a ``loader`` field (kwargs for ``create_loader``).
    """

    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
        dummy: bool = False,
    ):
        super().__init__()
        self.train_config = train
        assert (
            "datapipeline" in self.train_config and "loader" in self.train_config
        ), "train config requires the fields `datapipeline` and `loader`"

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                    "datapipeline" in self.val_config and "loader" in self.val_config
                ), "validation config requires the fields `datapipeline` and `loader`"
            else:
                # Fall back to the training pipeline when no validation
                # config is given.
                print(
                    "Warning: no validation datapipeline defined, reusing the one from training"
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                "datapipeline" in self.test_config and "loader" in self.test_config
            ), "test config requires the fields `datapipeline` and `loader`"

        self.dummy = dummy
        if self.dummy:
            print("#" * 100)
            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
            print("#" * 100)

    def setup(self, stage: str) -> None:
        # Called by Lightning once per stage; builds the actual datapipelines.
        print("Preparing datasets")
        data_fn = create_dummy_dataset if self.dummy else create_dataset

        self.train_datapipeline = data_fn(**self.train_config.datapipeline)
        if self.val_config:
            self.val_datapipeline = data_fn(**self.val_config.datapipeline)
        if self.test_config:
            self.test_datapipeline = data_fn(**self.test_config.datapipeline)

    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
        return create_loader(self.train_datapipeline, **self.train_config.loader)

    def val_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.val_datapipeline, **self.val_config.loader)

    def test_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.test_datapipeline, **self.test_config.loader)
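

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it only
    # illustrates the config shape the asserts in __init__ expect. The
    # `datapipeline`/`loader` contents below are hypothetical placeholders;
    # the real fields depend on the sdata factory signatures.
    from omegaconf import OmegaConf

    train_cfg = OmegaConf.create(
        {
            "datapipeline": {},  # kwargs for create_dataset / create_dummy_dataset
            "loader": {},  # kwargs for create_loader
        }
    )
    module = StableDataModuleFromConfig(train=train_cfg, dummy=True)
    module.setup("fit")
    train_loader = module.train_dataloader()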