mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
just clip the diffusion prior network parameters
This commit is contained in:
@@ -82,7 +82,7 @@ def train(image_embed_dim,
|
||||
os.makedirs(save_path)
|
||||
|
||||
### Training code ###
|
||||
optimizer = get_optimizer(diffusion_prior.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
epochs = num_epochs
|
||||
|
||||
step = 0
|
||||
@@ -178,13 +178,12 @@ def main():
|
||||
})
|
||||
print("wandb logging setup done!")
|
||||
# Obtain the utilized device.
|
||||
if torch.cuda.is_available():
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
has_cuda = True
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
has_cuda = False
|
||||
|
||||
# Training loop
|
||||
train(args.image_embed_dim,
|
||||
args.image_embed_url,
|
||||
|
||||
Reference in New Issue
Block a user