From 870aeeca627d4f2f7ff9e7c230cc771c2e555f9d Mon Sep 17 00:00:00 2001 From: Aidan Dempster Date: Fri, 8 Jul 2022 20:11:34 -0400 Subject: [PATCH] Fixed issue where evaluation would error when large image was loaded (#194) --- train_decoder.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/train_decoder.py b/train_decoder.py index 6ab9050..ee5c807 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -132,7 +132,7 @@ def get_example_data(dataloader, device, n=5): break return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n])) -def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""): +def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True): """ Takes example data and generates images from the embeddings Returns three lists: real images, generated images, and captions @@ -160,6 +160,9 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t samples = trainer.sample(**sample_params) generated_images = list(samples) captions = [text_prepend + txt for txt in txts] + if match_image_size: + generated_image_size = generated_images[0].shape[-1] + real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images] return real_images, generated_images, captions def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""): @@ -167,14 +170,6 @@ def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, Generates samples and uses torchvision to put them in a side by side grid for easy viewing """ real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend) - - real_image_size = real_images[0].shape[-1] - generated_image_size = generated_images[0].shape[-1] - - # training images may be larger than the generated one - if real_image_size > generated_image_size: - real_images = [resize_image_to(image, generated_image_size) for image in real_images] - grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)] return grid_images, captions