From d991b8c39c2123953b98d8d58b43e0483583ca24 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 1 May 2022 12:01:01 -0700
Subject: [PATCH] just clip the diffusion prior network parameters

---
 train_diffusion_prior.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index 78e98ab..a45199d 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -82,7 +82,7 @@ def train(image_embed_dim,
         os.makedirs(save_path)
 
     ### Training code ###
-    optimizer = get_optimizer(diffusion_prior.parameters(), wd=weight_decay, lr=learning_rate)
+    optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
     epochs = num_epochs
 
     step = 0
@@ -178,13 +178,12 @@ def main():
         })
     print("wandb logging setup done!")
     # Obtain the utilized device.
-    if torch.cuda.is_available():
+
+    has_cuda = torch.cuda.is_available()
+    if has_cuda:
         device = torch.device("cuda:0")
         torch.cuda.set_device(device)
-        has_cuda = True
-    else:
-        device = torch.device("cpu")
-        has_cuda = False
+
     # Training loop
     train(args.image_embed_dim,
          args.image_embed_url,
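
Notes (not part of the patch): the optimizer hunk is the substantive change. In DALLE2-pytorch, DiffusionPrior registers the prior network as its .net submodule alongside other components, such as an attached CLIP model whose weights are not meant to be trained here. Passing diffusion_prior.net.parameters() rather than diffusion_prior.parameters() therefore restricts the optimizer, and with it weight decay and any gradient clipping applied over the optimizer's parameter groups, to the prior network itself. The sketch below illustrates the difference with toy nn.Module stand-ins; the Toy* classes and the lr/weight_decay values are placeholders for illustration, not the real DALLE2-pytorch classes or training hyperparameters.

    import torch
    from torch import nn

    # Hypothetical stand-ins for DiffusionPriorNetwork / DiffusionPrior,
    # for illustration only.
    class ToyPriorNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(512, 512)

    class ToyDiffusionPrior(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = ToyPriorNetwork()     # the trainable prior network
            self.clip = nn.Linear(512, 512)  # stand-in for an attached, frozen CLIP
            for p in self.clip.parameters():
                p.requires_grad = False

    diffusion_prior = ToyDiffusionPrior()

    # Before the patch: every registered parameter, the frozen CLIP included.
    print(sum(p.numel() for p in diffusion_prior.parameters()))      # 525312

    # After the patch: only the prior network's parameters, so updates,
    # weight decay, and clipping over its param groups never touch CLIP.
    optimizer = torch.optim.AdamW(diffusion_prior.net.parameters(),
                                  lr=1e-4, weight_decay=1e-2)  # placeholder values
    print(sum(p.numel() for p in diffusion_prior.net.parameters()))  # 262656

One caveat visible in the second hunk: the refactor drops the CPU fallback, so when torch.cuda.is_available() returns False, device is never assigned before the training loop is entered. As written, the patched script assumes a CUDA machine.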