mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
just clip the diffusion prior network parameters
This commit is contained in:
@@ -82,7 +82,7 @@ def train(image_embed_dim,
|
|||||||
os.makedirs(save_path)
|
os.makedirs(save_path)
|
||||||
|
|
||||||
### Training code ###
|
### Training code ###
|
||||||
optimizer = get_optimizer(diffusion_prior.parameters(), wd=weight_decay, lr=learning_rate)
|
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||||
epochs = num_epochs
|
epochs = num_epochs
|
||||||
|
|
||||||
step = 0
|
step = 0
|
||||||
@@ -178,13 +178,12 @@ def main():
|
|||||||
})
|
})
|
||||||
print("wandb logging setup done!")
|
print("wandb logging setup done!")
|
||||||
# Obtain the utilized device.
|
# Obtain the utilized device.
|
||||||
if torch.cuda.is_available():
|
|
||||||
|
has_cuda = torch.cuda.is_available()
|
||||||
|
if has_cuda:
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
torch.cuda.set_device(device)
|
torch.cuda.set_device(device)
|
||||||
has_cuda = True
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
has_cuda = False
|
|
||||||
# Training loop
|
# Training loop
|
||||||
train(args.image_embed_dim,
|
train(args.image_embed_dim,
|
||||||
args.image_embed_url,
|
args.image_embed_url,
|
||||||
|
|||||||
Reference in New Issue
Block a user