mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Update train_diffusion_prior.py (#53)
This commit is contained in:
@@ -7,6 +7,9 @@ from torch import nn
|
||||
from embedding_reader import EmbeddingReader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from torch.cuda.amp import autocast,GradScaler
|
||||
|
||||
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
@@ -136,7 +139,7 @@ def train(image_embed_dim,
|
||||
"Samples per second": samples_per_sec})
|
||||
|
||||
scaler.unscale_(optimizer)
|
||||
nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
|
||||
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
Reference in New Issue
Block a user