From 9f37705d87e30cb4cfada944a3fb72aa49e603c5 Mon Sep 17 00:00:00 2001
From: Romain Beaumont
Date: Tue, 25 Oct 2022 19:31:29 +0200
Subject: [PATCH] Add static graph param (#226)

* Add static graph param

* use static graph param
---
 dalle2_pytorch/train_configs.py | 1 +
 train_decoder.py                | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index 6c5bee8..cc133fa 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -307,6 +307,7 @@ class DecoderTrainConfig(BaseModel):
     wd: SingularOrIterable[float] = 0.01
     warmup_steps: Optional[SingularOrIterable[int]] = None
     find_unused_parameters: bool = True
+    static_graph: bool = True
     max_grad_norm: SingularOrIterable[float] = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6  # The number of example images to produce when sampling the train and test dataset

diff --git a/train_decoder.py b/train_decoder.py
index a13bfaf..e2212c4 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -556,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     torch.manual_seed(config.seed)

     # Set up accelerator for configurable distributed training
-    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
+    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
     init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
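
A minimal standalone sketch (not part of the patch) of how the new static_graph option reaches PyTorch DDP through Accelerate's kwargs handlers, mirroring the change in train_decoder.py. The literal True values below stand in for config.train.find_unused_parameters and config.train.static_graph, which are normally read from the decoder train config.

    from datetime import timedelta

    from accelerate import Accelerator
    from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs

    # Stand-ins for the values normally taken from config.train
    find_unused_parameters = True
    static_graph = True

    # static_graph=True is forwarded to torch's DistributedDataParallel and tells it
    # that the set of used parameters does not change across iterations, which lets
    # DDP apply extra optimizations; find_unused_parameters keeps the previous behavior.
    ddp_kwargs = DistributedDataParallelKwargs(
        find_unused_parameters=find_unused_parameters,
        static_graph=static_graph,
    )
    init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60 * 60))
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])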