diff --git a/train_decoder.py b/train_decoder.py index f9d36eb..f06dadc 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -557,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path): # Create the decoder model and print basic info decoder = config.decoder.create() - num_parameters = sum(p.numel() for p in decoder.parameters()) + get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training)) # Create and initialize the tracker if we are the master tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0) @@ -586,7 +586,10 @@ def initialize_training(config: TrainDecoderConfig, config_path): accelerator.print(print_ribbon("Loaded Config", repeat=40)) accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training") accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}") - accelerator.print(f"Number of parameters: {num_parameters}") + accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training") + for i, unet in enumerate(decoder.unets): + accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training") + train(dataloaders, decoder, accelerator, tracker=tracker, inference_device=accelerator.device,