Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
Address not calculating the average eval / test loss when training the diffusion prior (https://github.com/lucidrains/DALLE2-pytorch/issues/49)
@@ -17,14 +17,24 @@ os.environ["WANDB_SILENT"] = "true"
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
     model.eval()
     with torch.no_grad():
-        for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
+        total_loss = 0.
+        total_samples = 0.
+
+        for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
                 text_reader(batch_size=batch_size, start=start, end=end)):
 
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
+
+            batches = emb_images_tensor.shape[0]
+
             loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
-            # Log to wandb
-            wandb.log({f'{phase} {loss_type}': loss})
+
+            total_loss += loss.item() * batches
+            total_samples += batches
+
+        avg_loss = (total_loss / total_samples)
+        wandb.log({f'{phase} {loss_type}': avg_loss})
 
 def save_model(save_path,state_dict):
     # Saving State Dict
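For reference, the sample-weighted averaging pattern introduced by this commit can be written as a standalone helper. This is a minimal sketch, not the repository's API: `average_eval_loss`, the `batches` iterable of `(text_embed, image_embed)` tensor pairs, and the `device` argument are hypothetical names, and it assumes the model returns a scalar loss tensor when called as `model(text_embed=..., image_embed=...)`, as the diffusion prior does in the diff above.

import torch

def average_eval_loss(model, batches, device="cpu"):
    # Hypothetical helper illustrating the fix above: weight each batch's loss
    # by its batch size, accumulate, and divide by the total sample count so a
    # single average is reported instead of one value per batch.
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for text_embed, image_embed in batches:
            text_embed = text_embed.to(device)
            image_embed = image_embed.to(device)
            n = image_embed.shape[0]  # samples in this batch
            loss = model(text_embed=text_embed, image_embed=image_embed)
            total_loss += loss.item() * n
            total_samples += n
    return total_loss / total_samples

Weighting by the per-batch sample count matters because the final batch produced by the embedding readers can be smaller than `batch_size`; an unweighted mean of per-batch losses would give that partial batch the same influence as a full one, which slightly skews the reported eval / test loss.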