From 72c16b496ef06e1f935aa83594fb6e87fc88ff33 Mon Sep 17 00:00:00 2001 From: Kumar R Date: Tue, 3 May 2022 11:14:57 +0530 Subject: [PATCH] Update train_diffusion_prior.py (#53) --- train_diffusion_prior.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index 5ccb6d3..ee66659 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -7,6 +7,9 @@ from torch import nn from embedding_reader import EmbeddingReader from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork from dalle2_pytorch.optimizer import get_optimizer +from dalle2_pytorch.optimizer import get_optimizer +from torch.cuda.amp import autocast,GradScaler + import time from tqdm import tqdm @@ -136,7 +139,7 @@ def train(image_embed_dim, "Samples per second": samples_per_sec}) scaler.unscale_(optimizer) - nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm) + nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm) scaler.step(optimizer) scaler.update()