diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index 6c5bee8..cc133fa 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -307,6 +307,7 @@ class DecoderTrainConfig(BaseModel):
     wd: SingularOrIterable[float] = 0.01
     warmup_steps: Optional[SingularOrIterable[int]] = None
     find_unused_parameters: bool = True
+    static_graph: bool = True
     max_grad_norm: SingularOrIterable[float] = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
diff --git a/train_decoder.py b/train_decoder.py
index a13bfaf..e2212c4 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -556,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     torch.manual_seed(config.seed)

     # Set up accelerator for configurable distributed training
-    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
+    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
     init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])