mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
Ensure Eval Mode In Metric Functions (#79)
* add eval/train toggles * train/eval flags * shift train toggle Co-authored-by: nousr <z@localhost.com>
This commit is contained in:
@@ -42,6 +42,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
|
|||||||
wandb.log({f'{phase} {loss_type}': avg_loss})
|
wandb.log({f'{phase} {loss_type}': avg_loss})
|
||||||
|
|
||||||
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
|
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
|
||||||
|
diffusion_prior.eval()
|
||||||
|
|
||||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||||
|
|
||||||
@@ -170,10 +171,12 @@ def train(image_embed_dim,
|
|||||||
eval_start = train_set_size
|
eval_start = train_set_size
|
||||||
|
|
||||||
for _ in range(epochs):
|
for _ in range(epochs):
|
||||||
diffusion_prior.train()
|
|
||||||
|
|
||||||
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
||||||
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
||||||
|
|
||||||
|
diffusion_prior.train()
|
||||||
|
|
||||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user