just clip the diffusion prior network parameters

Phil Wang
2022-05-01 12:01:01 -07:00
parent 902693e271
commit d991b8c39c


@@ -82,7 +82,7 @@ def train(image_embed_dim,
         os.makedirs(save_path)
     ### Training code ###
-    optimizer = get_optimizer(diffusion_prior.parameters(), wd=weight_decay, lr=learning_rate)
+    optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
     epochs = num_epochs
     step = 0
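
To make the intent of this hunk concrete: in DALLE2-pytorch, a DiffusionPrior wraps a (frozen) CLIP model together with the trainable prior network exposed as `.net`, so passing `diffusion_prior.parameters()` registers CLIP's weights with the optimizer as well. A minimal sketch of the difference, using hypothetical stand-in modules rather than the library's real classes:

import torch
from torch import nn

# Hypothetical stand-in for DiffusionPrior: a frozen CLIP alongside the
# trainable prior network, mirroring the structure this change relies on.
class ToyDiffusionPrior(nn.Module):
    def __init__(self):
        super().__init__()
        self.clip = nn.Linear(512, 512)   # stand-in for the frozen CLIP weights
        self.clip.requires_grad_(False)
        self.net = nn.Linear(512, 512)    # stand-in for the trainable prior network

prior = ToyDiffusionPrior()

# Before this commit: CLIP's parameters are handed to the optimizer too.
before = sum(p.numel() for p in prior.parameters())
# After: only the prior network's parameters are registered.
after = sum(p.numel() for p in prior.net.parameters())
print(before, after)  # the second count excludes the CLIP stand-in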
@@ -178,13 +178,12 @@ def main():
         })
     print("wandb logging setup done!")
     # Obtain the utilized device.
-    if torch.cuda.is_available():
+    has_cuda = torch.cuda.is_available()
+    if has_cuda:
         device = torch.device("cuda:0")
         torch.cuda.set_device(device)
-        has_cuda = True
     else:
         device = torch.device("cpu")
-        has_cuda = False
     # Training loop
     train(args.image_embed_dim,
           args.image_embed_url,
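
The second hunk is a small refactor: the CUDA check is evaluated once and the flag reused, instead of assigning `has_cuda` separately in each branch. A self-contained sketch of the resulting device selection:

import torch

# Compute availability once, then branch on the cached flag.
has_cuda = torch.cuda.is_available()
if has_cuda:
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

print(f"training on {device} (cuda available: {has_cuda})")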