mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
mixed precision for training diffusion prior + save optimizer and scaler states
@@ -62,7 +62,8 @@ def train(image_embed_dim,
           device,
           learning_rate=0.001,
           max_grad_norm=0.5,
-          weight_decay=0.01):
+          weight_decay=0.01,
+          amp=False):
 
     # DiffusionPriorNetwork
     prior_network = DiffusionPriorNetwork(
@@ -92,6 +93,7 @@ def train(image_embed_dim,
         os.makedirs(save_path)
 
     ### Training code ###
+    scaler = GradScaler(enabled=amp)
     optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
     epochs = num_epochs
 
@@ -108,23 +110,33 @@ def train(image_embed_dim,
                 text_reader(batch_size=batch_size, start=0, end=train_set_size)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-            optimizer.zero_grad()
-            loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
-            loss.backward()
+
+            with autocast(enabled=amp):
+                loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
+                scaler.scale(loss).backward()
 
             # Samples per second
             step+=1
             samples_per_sec = batch_size*step/(time.time()-t)
             # Save checkpoint every save_interval minutes
             if(int(time.time()-t) >= 60*save_interval):
                 t = time.time()
-                save_model(save_path,diffusion_prior.state_dict())
+
+                save_model(
+                    save_path,
+                    dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
 
             # Log to wandb
             wandb.log({"Training loss": loss.item(),
                         "Steps": step,
                         "Samples per second": samples_per_sec})
 
+            scaler.unscale_(optimizer)
             nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
-            optimizer.step()
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
 
         ### Evaluate model(validation run) ###
         start = train_set_size
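
For reference, the ordering the new loop follows is the standard torch.cuda.amp recipe: backpropagate the scaled loss, unscale the gradients before clipping so the norm threshold applies to their true magnitudes, then let the scaler step the optimizer (skipping the step if overflow is detected) and update the loss scale. A minimal sketch with a stand-in model, not the script's DiffusionPrior, and using nn.utils.clip_grad_norm_, the standard location of that helper:

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
amp = device == "cuda"                      # autocast/GradScaler become no-ops on CPU

model = nn.Linear(512, 512).to(device)      # stand-in for the diffusion prior
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=amp)
max_grad_norm = 0.5

x = torch.randn(8, 512, device=device)

with autocast(enabled=amp):                 # forward pass in mixed precision
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()               # backward on the scaled loss

scaler.unscale_(optimizer)                  # bring grads back to their true scale...
nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # ...so the clip threshold is meaningful

scaler.step(optimizer)                      # skips the step if inf/NaN grads were found
scaler.update()                             # adjust the loss scale for the next iteration
optimizer.zero_grad()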
@@ -171,12 +183,15 @@ def main():
     parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
     parser.add_argument("--dp-loss-type", type=str, default="l2")
     parser.add_argument("--clip", type=str, default=None)
+    parser.add_argument("--amp", type=bool, default=False)
     # Model checkpointing interval(minutes)
     parser.add_argument("--save-interval", type=int, default=30)
     parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
 
     args = parser.parse_args()
 
     print("Setting up wandb logging... Please wait...")
 
     wandb.init(
         entity=args.wandb_entity,
         project=args.wandb_project,
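
One caveat on the new --amp flag: argparse's type=bool turns any non-empty string into True, so passing --amp False on the command line would still enable mixed precision. A sketch of the usual store_true alternative (illustrative only, not part of this commit):

import argparse

parser = argparse.ArgumentParser()
# Flag is False unless explicitly passed, avoiding the type=bool pitfall
# where "--amp False" would still parse as True.
parser.add_argument("--amp", action="store_true")

print(parser.parse_args([]).amp)         # False
print(parser.parse_args(["--amp"]).amp)  # True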
@@ -186,6 +201,7 @@ def main():
             "dataset": args.wandb_dataset,
             "epochs": args.num_epochs,
     })
 
     print("wandb logging setup done!")
     # Obtain the utilized device.
 
@@ -216,7 +232,8 @@ def main():
           device,
           args.learning_rate,
           args.max_grad_norm,
-          args.weight_decay)
+          args.weight_decay,
+          args.amp)
 
 if __name__ == "__main__":
     main()
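
Since checkpoints now bundle model, optimizer, and scaler states, resuming a run means restoring all three. A hypothetical round trip, assuming the payload built above is ultimately written out with torch.save, and using a stand-in module in place of the DiffusionPrior:

import torch
from torch import nn
from torch.cuda.amp import GradScaler

# Stand-ins for the real diffusion prior, optimizer, and scaler.
model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=False)

# Save: same payload shape as the save_model call in the diff.
torch.save(dict(model=model.state_dict(),
                optimizer=optimizer.state_dict(),
                scaler=scaler.state_dict()),
           "checkpoint.pth")

# Resume: restore all three states before continuing training.
ckpt = torch.load("checkpoint.pth", map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scaler.load_state_dict(ckpt["scaler"])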
|