mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 18:24:28 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f4edff214 |
@@ -64,6 +64,22 @@ class DecoderDataConfig(BaseModel):
|
|||||||
resample_train: bool = False
|
resample_train: bool = False
|
||||||
preprocessing: Dict[str, Any] = {'ToTensor': True}
|
preprocessing: Dict[str, Any] = {'ToTensor': True}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def img_preproc(self):
|
||||||
|
def _get_transformation(transformation_name, **kwargs):
|
||||||
|
if transformation_name == "RandomResizedCrop":
|
||||||
|
return T.RandomResizedCrop(**kwargs)
|
||||||
|
elif transformation_name == "RandomHorizontalFlip":
|
||||||
|
return T.RandomHorizontalFlip()
|
||||||
|
elif transformation_name == "ToTensor":
|
||||||
|
return T.ToTensor()
|
||||||
|
|
||||||
|
transforms = []
|
||||||
|
for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
|
||||||
|
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
|
||||||
|
transforms.append(_get_transformation(transform_name, **transform_kwargs))
|
||||||
|
return T.Compose(transforms)
|
||||||
|
|
||||||
class DecoderTrainConfig(BaseModel):
|
class DecoderTrainConfig(BaseModel):
|
||||||
epochs: int = 20
|
epochs: int = 20
|
||||||
lr: float = 1e-4
|
lr: float = 1e-4
|
||||||
@@ -117,19 +133,3 @@ class TrainDecoderConfig(BaseModel):
|
|||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
return cls(**config)
|
return cls(**config)
|
||||||
|
|
||||||
@property
|
|
||||||
def img_preproc(self):
|
|
||||||
def _get_transformation(transformation_name, **kwargs):
|
|
||||||
if transformation_name == "RandomResizedCrop":
|
|
||||||
return T.RandomResizedCrop(**kwargs)
|
|
||||||
elif transformation_name == "RandomHorizontalFlip":
|
|
||||||
return T.RandomHorizontalFlip()
|
|
||||||
elif transformation_name == "ToTensor":
|
|
||||||
return T.ToTensor()
|
|
||||||
|
|
||||||
transforms = []
|
|
||||||
for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
|
|
||||||
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
|
|
||||||
transforms.append(_get_transformation(transform_name, **transform_kwargs))
|
|
||||||
return T.Compose(transforms)
|
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.4.6',
|
version = '0.4.7',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
@@ -420,7 +420,7 @@ def initialize_training(config):
|
|||||||
|
|
||||||
dataloaders = create_dataloaders (
|
dataloaders = create_dataloaders (
|
||||||
available_shards=all_shards,
|
available_shards=all_shards,
|
||||||
img_preproc = config.img_preproc,
|
img_preproc = config.data.img_preproc,
|
||||||
train_prop = config.data.splits.train,
|
train_prop = config.data.splits.train,
|
||||||
val_prop = config.data.splits.val,
|
val_prop = config.data.splits.val,
|
||||||
test_prop = config.data.splits.test,
|
test_prop = config.data.splits.test,
|
||||||
|
|||||||
Reference in New Issue
Block a user