Unet parameter count is now shown (#202)

This commit is contained in:
Aidan Dempster
2022-07-10 19:45:59 -04:00
committed by GitHub
parent 7ea314e2f0
commit 1a217e99e3

View File

@@ -557,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Create the decoder model and print basic info # Create the decoder model and print basic info
decoder = config.decoder.create() decoder = config.decoder.create()
num_parameters = sum(p.numel() for p in decoder.parameters()) get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
# Create and initialize the tracker if we are the master # Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0) tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
@@ -586,7 +586,10 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(print_ribbon("Loaded Config", repeat=40)) accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training") accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}") accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {num_parameters}") accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
for i, unet in enumerate(decoder.unets):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
train(dataloaders, decoder, accelerator, train(dataloaders, decoder, accelerator,
tracker=tracker, tracker=tracker,
inference_device=accelerator.device, inference_device=accelerator.device,