From 9025345e2984b76f5641b6347e2f32a068121cde Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 2 Jun 2022 11:33:15 -0700
Subject: [PATCH] take a stab at fixing generate_grid_samples when real images have a greater image size than generated

---
 train_decoder.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/train_decoder.py b/train_decoder.py
index a4651d8..057ca6d 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -4,6 +4,7 @@ from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
 from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
 from dalle2_pytorch.train_configs import TrainDecoderConfig
 from dalle2_pytorch.utils import Timer, print_ribbon
+from dalle2_pytorch.dalle2_pytorch import resize_image_to
 
 import torchvision
 import torch
@@ -136,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
     real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
+
+    real_image_size = real_images[0].shape[-1]
+    generated_image_size = generated_images[0].shape[-1]
+
+    # training images may be larger than the generated one
+    if real_image_size > generated_image_size:
+        real_images = [resize_image_to(image, generated_image_size) for image in real_images]
+
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions