diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index 5ccb6d3..ee66659 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -7,6 +7,9 @@ from torch import nn from embedding_reader import EmbeddingReader from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork from dalle2_pytorch.optimizer import get_optimizer +from dalle2_pytorch.optimizer import get_optimizer +from torch.cuda.amp import autocast,GradScaler + import time from tqdm import tqdm @@ -136,7 +139,7 @@ def train(image_embed_dim, "Samples per second": samples_per_sec}) scaler.unscale_(optimizer) - nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm) + nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm) scaler.step(optimizer) scaler.update()