Update train_diffusion_prior.py (#53)

This commit is contained in:
Kumar R
2022-05-03 11:14:57 +05:30
committed by GitHub
parent 81d83dd7f2
commit 72c16b496e

View File

@@ -7,6 +7,9 @@ from torch import nn
from embedding_reader import EmbeddingReader from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.optimizer import get_optimizer from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler
import time import time
from tqdm import tqdm from tqdm import tqdm
@@ -136,7 +139,7 @@ def train(image_embed_dim,
"Samples per second": samples_per_sec}) "Samples per second": samples_per_sec})
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm) nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()