mixed precision for training diffusion prior + save optimizer and scaler states

Author: Phil Wang
Date: 2022-05-02 09:31:04 -07:00
parent 1924c7cc3d
commit 7ee0ecc388


@@ -62,7 +62,8 @@ def train(image_embed_dim,
           device,
           learning_rate=0.001,
           max_grad_norm=0.5,
-          weight_decay=0.01):
+          weight_decay=0.01,
+          amp=False):
     # DiffusionPriorNetwork
     prior_network = DiffusionPriorNetwork(
@@ -92,6 +93,7 @@ def train(image_embed_dim,
     os.makedirs(save_path)
     ### Training code ###
+    scaler = GradScaler(enabled=amp)
     optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
     epochs = num_epochs
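A note on the new scaler: because it is constructed with enabled=amp, the same training step works with or without mixed precision. Per the torch.cuda.amp docs, a disabled GradScaler passes tensors through unscaled and scaler.step(optimizer) reduces to optimizer.step(). A tiny standalone sketch of that behavior (not part of the script):

import torch
from torch.cuda.amp import GradScaler

# With enabled=False the scaler is a no-op wrapper, so the fp32 path and the
# mixed-precision path can share one code path in the training loop.
scaler = GradScaler(enabled=False)
loss = torch.tensor(1.0, requires_grad=True)
scaled = scaler.scale(loss)   # returned unchanged when the scaler is disabled
print(scaled is loss)         # True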
@@ -108,23 +110,33 @@ def train(image_embed_dim,
             text_reader(batch_size=batch_size, start=0, end=train_set_size)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-            optimizer.zero_grad()
-            loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
-            loss.backward()
+            with autocast(enabled=amp):
+                loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
+                scaler.scale(loss).backward()
             # Samples per second
             step+=1
             samples_per_sec = batch_size*step/(time.time()-t)
             # Save checkpoint every save_interval minutes
             if(int(time.time()-t) >= 60*save_interval):
                 t = time.time()
-                save_model(save_path,diffusion_prior.state_dict())
+                save_model(
+                    save_path,
+                    dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
             # Log to wandb
             wandb.log({"Training loss": loss.item(),
                        "Steps": step,
                        "Samples per second": samples_per_sec})
+            scaler.unscale_(optimizer)
             nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
-            optimizer.step()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
         ### Evaluate model(validation run) ###
         start = train_set_size
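Taken together, this hunk follows the standard torch.cuda.amp recipe: forward under autocast, backward through the scaler, unscale before gradient clipping, then scaler.step/scaler.update and a trailing zero_grad. A condensed, self-contained sketch of that pattern (the amp_train_step name, argument list, and model call are illustrative; the sketch uses torch.nn.utils.clip_grad_norm_, which is where PyTorch's clipping helper lives):

from torch import nn
from torch.cuda.amp import autocast

def amp_train_step(model, optimizer, scaler, text_embed, image_embed,
                   max_grad_norm=0.5, amp=True):
    # Forward pass runs in reduced precision only when amp is enabled.
    with autocast(enabled=amp):
        loss = model(text_embed=text_embed, image_embed=image_embed)
    # Scale the loss so small fp16 gradients do not underflow.
    scaler.scale(loss).backward()
    # Gradients must be unscaled before clipping against a true norm.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # scaler.step skips the optimizer update if any gradient overflowed;
    # update() then adjusts the scale factor for the next iteration.
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    return loss.item()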
@@ -171,12 +183,15 @@ def main():
     parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
     parser.add_argument("--dp-loss-type", type=str, default="l2")
     parser.add_argument("--clip", type=str, default=None)
+    parser.add_argument("--amp", type=bool, default=False)
     # Model checkpointing interval(minutes)
     parser.add_argument("--save-interval", type=int, default=30)
     parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
     args = parser.parse_args()
     print("Setting up wandb logging... Please wait...")
     wandb.init(
         entity=args.wandb_entity,
         project=args.wandb_project,
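One caveat with the new flag: argparse's type=bool calls bool() on the raw string, so any non-empty value (including --amp False) parses as True; only omitting the flag yields False. A purely illustrative sketch of an on/off switch that avoids the pitfall (not part of this commit):

import argparse

parser = argparse.ArgumentParser()
# store_true makes --amp a switch: absent -> False, present -> True.
parser.add_argument("--amp", action="store_true")
args = parser.parse_args(["--amp"])
print(args.amp)  # True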
@@ -186,6 +201,7 @@ def main():
             "dataset": args.wandb_dataset,
             "epochs": args.num_epochs,
         })
     print("wandb logging setup done!")
     # Obtain the utilized device.
@@ -216,7 +232,8 @@ def main():
           device,
           args.learning_rate,
           args.max_grad_norm,
-          args.weight_decay)
+          args.weight_decay,
+          args.amp)
 if __name__ == "__main__":
     main()
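Since the checkpoint now bundles the model, optimizer, and scaler state dicts, resuming training restores all three. A minimal sketch of the loading side, assuming save_model ultimately writes the dict with torch.save to a file under save_path (the helper's implementation is not shown in this diff, and load_training_state is a hypothetical name):

import torch

def load_training_state(checkpoint_path, diffusion_prior, optimizer, scaler, device):
    # Assumes the checkpoint is the dict saved above:
    # {"model": ..., "optimizer": ..., "scaler": ...}
    ckpt = torch.load(checkpoint_path, map_location=device)
    diffusion_prior.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scaler.load_state_dict(ckpt["scaler"])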