mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Add static graph param (#226)
* Add static graph param * use static graph param
This commit is contained in:
@@ -556,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
torch.manual_seed(config.seed)
|
||||
|
||||
# Set up accelerator for configurable distributed training
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
|
||||
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user