mixed precision for training diffusion prior + save optimizer and scaler states

This commit is contained in:
Phil Wang
2022-05-02 09:31:04 -07:00
parent 1924c7cc3d
commit 7ee0ecc388

View File

@@ -62,7 +62,8 @@ def train(image_embed_dim,
device,
learning_rate=0.001,
max_grad_norm=0.5,
weight_decay=0.01):
weight_decay=0.01,
amp=False):
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork(
@@ -92,6 +93,7 @@ def train(image_embed_dim,
os.makedirs(save_path)
### Training code ###
scaler = GradScaler(enabled=amp)
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
epochs = num_epochs
@@ -108,23 +110,33 @@ def train(image_embed_dim,
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
optimizer.zero_grad()
with autocast(enabled=amp):
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
loss.backward()
scaler.scale(loss).backward()
# Samples per second
step+=1
samples_per_sec = batch_size*step/(time.time()-t)
# Save checkpoint every save_interval minutes
if(int(time.time()-t) >= 60*save_interval):
t = time.time()
save_model(save_path,diffusion_prior.state_dict())
save_model(
save_path,
dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
# Log to wandb
wandb.log({"Training loss": loss.item(),
"Steps": step,
"Samples per second": samples_per_sec})
scaler.unscale_(optimizer)
nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
optimizer.step()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
### Evaluate model(validation run) ###
start = train_set_size
@@ -171,12 +183,15 @@ def main():
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
parser.add_argument("--dp-loss-type", type=str, default="l2")
parser.add_argument("--clip", type=str, default=None)
parser.add_argument("--amp", type=bool, default=False)
# Model checkpointing interval(minutes)
parser.add_argument("--save-interval", type=int, default=30)
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
args = parser.parse_args()
print("Setting up wandb logging... Please wait...")
wandb.init(
entity=args.wandb_entity,
project=args.wandb_project,
@@ -186,6 +201,7 @@ def main():
"dataset": args.wandb_dataset,
"epochs": args.num_epochs,
})
print("wandb logging setup done!")
# Obtain the utilized device.
@@ -216,7 +232,8 @@ def main():
device,
args.learning_rate,
args.max_grad_norm,
args.weight_decay)
args.weight_decay,
args.amp)
if __name__ == "__main__":
main()