diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index 78e98ab..a45199d 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -82,7 +82,7 @@ def train(image_embed_dim, os.makedirs(save_path) ### Training code ### - optimizer = get_optimizer(diffusion_prior.parameters(), wd=weight_decay, lr=learning_rate) + optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate) epochs = num_epochs step = 0 @@ -178,13 +178,12 @@ def main(): }) print("wandb logging setup done!") # Obtain the utilized device. - if torch.cuda.is_available(): + + has_cuda = torch.cuda.is_available() + if has_cuda: device = torch.device("cuda:0") torch.cuda.set_device(device) - has_cuda = True - else: - device = torch.device("cpu") - has_cuda = False + # Training loop train(args.image_embed_dim, args.image_embed_url,