Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
Add static graph param (#226)
* Add static graph param
* use static graph param
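Background (not part of the commit message): static_graph=True tells torch's DistributedDataParallel that the model's graph structure and the set of used/unused parameters stay the same across iterations, so DDP can cache its bookkeeping instead of re-scanning for unused parameters every step. A minimal, plain-PyTorch sketch of the flag this option maps onto (illustrative only; assumes torch >= 1.11, where DistributedDataParallel accepts static_graph, and uses a single-process gloo group so it runs on CPU):

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process gloo group so the sketch runs on CPU without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)
ddp_model = DDP(
    model,
    find_unused_parameters=True,  # DecoderTrainConfig default
    static_graph=True,            # the behavior the new config field toggles
)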
@@ -307,6 +307,7 @@ class DecoderTrainConfig(BaseModel):
     wd: SingularOrIterable[float] = 0.01
     warmup_steps: Optional[SingularOrIterable[int]] = None
     find_unused_parameters: bool = True
+    static_graph: bool = True
     max_grad_norm: SingularOrIterable[float] = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
@@ -556,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     torch.manual_seed(config.seed)

     # Set up accelerator for configurable distributed training
-    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
+    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
     init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])

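Usage note (illustrative, not from the commit): because the new field defaults to True, existing configs pick up static-graph DDP automatically, and the flag is threaded through Accelerate rather than set on DDP directly. A self-contained sketch of the resulting wiring, assuming an accelerate release whose DistributedDataParallelKwargs accepts static_graph:

from datetime import timedelta

from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs

# Values mirror the DecoderTrainConfig defaults shown in the first hunk.
ddp_kwargs = DistributedDataParallelKwargs(
    find_unused_parameters=True,  # config.train.find_unused_parameters
    static_graph=True,            # config.train.static_graph (added by this commit)
)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60 * 60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])

A config whose graph genuinely does change between iterations (for example, modules skipped conditionally per step) would set static_graph to false in its train section to opt out.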