mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
take a stab at fixing generate_grid_samples when real images have a greater image size than generated
This commit is contained in:
@@ -4,6 +4,7 @@ from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
|||||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||||
|
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch
|
import torch
|
||||||
@@ -136,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
|
|||||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||||
"""
|
"""
|
||||||
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
||||||
|
|
||||||
|
real_image_size = real_images[0].shape[-1]
|
||||||
|
generated_image_size = generated_images[0].shape[-1]
|
||||||
|
|
||||||
|
# training images may be larger than the generated one
|
||||||
|
if real_image_size > generated_image_size:
|
||||||
|
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
|
||||||
|
|
||||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||||
return grid_images, captions
|
return grid_images, captions
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user