Add static graph param (#226)

* Add static graph param

* use static graph param
This commit is contained in:
Romain Beaumont
2022-10-25 19:31:29 +02:00
committed by GitHub
parent c3df46e374
commit 9f37705d87
2 changed files with 2 additions and 1 deletions

View File

@@ -556,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])