just clip the diffusion prior network parameters

This commit is contained in:
Phil Wang
2022-05-01 12:01:01 -07:00
parent 902693e271
commit d991b8c39c

View File

@@ -82,7 +82,7 @@ def train(image_embed_dim,
os.makedirs(save_path) os.makedirs(save_path)
### Training code ### ### Training code ###
optimizer = get_optimizer(diffusion_prior.parameters(), wd=weight_decay, lr=learning_rate) optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
epochs = num_epochs epochs = num_epochs
step = 0 step = 0
@@ -178,13 +178,12 @@ def main():
}) })
print("wandb logging setup done!") print("wandb logging setup done!")
# Obtain the utilized device. # Obtain the utilized device.
if torch.cuda.is_available():
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.cuda.set_device(device) torch.cuda.set_device(device)
has_cuda = True
else:
device = torch.device("cpu")
has_cuda = False
# Training loop # Training loop
train(args.image_embed_dim, train(args.image_embed_dim,
args.image_embed_url, args.image_embed_url,