diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index 7e1f10b..f2f8ad2 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -42,6 +42,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t wandb.log({f'{phase} {loss_type}': avg_loss}) def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device): + diffusion_prior.eval() cos = nn.CosineSimilarity(dim=1, eps=1e-6) @@ -170,10 +171,12 @@ def train(image_embed_dim, eval_start = train_set_size for _ in range(epochs): - diffusion_prior.train() for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size), text_reader(batch_size=batch_size, start=0, end=train_set_size)): + + diffusion_prior.train() + emb_images_tensor = torch.tensor(emb_images[0]).to(device) emb_text_tensor = torch.tensor(emb_text[0]).to(device)