mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Quality of life improvements for tracker savers (#210)
The default save location is now none so if keys are not specified the corresponding checkpoint type is not saved. Models and checkpoints are now both saved with version number and the config used to create them in order to simplify loading. Documentation was fixed to be in line with current usage.
This commit is contained in:
@@ -513,6 +513,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
|
||||
}
|
||||
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
|
||||
tracker.save_config(config_path, config_name='decoder_config.json')
|
||||
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
|
||||
return tracker
|
||||
|
||||
def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
|
||||
Reference in New Issue
Block a user