From 8c2015fd39a2078631c7256efdf7f0084812513e Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 27 Apr 2022 19:57:27 -0700 Subject: [PATCH] add ability for DALL-E2 to return PIL images with `return_pil_images = True` on forward, for those who have no clue about deep learning --- dalle2_pytorch/dalle2_pytorch.py | 12 +++++++++++- setup.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 00d5a6d..58ec153 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -7,6 +7,7 @@ from contextlib import contextmanager import torch import torch.nn.functional as F from torch import nn, einsum +import torchvision.transforms as T from einops import rearrange, repeat from einops.layers.torch import Rearrange @@ -1516,12 +1517,15 @@ class DALLE2(nn.Module): self.prior_num_samples = prior_num_samples self.decoder_need_text_cond = self.decoder.condition_on_text_encodings + self.to_pil = T.ToPILImage() + @torch.no_grad() @eval_decorator def forward( self, text, - cond_scale = 1. + cond_scale = 1., + return_pil_images = False ): device = next(self.parameters()).device one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1) @@ -1535,7 +1539,13 @@ class DALLE2(nn.Module): text_cond = text if self.decoder_need_text_cond else None images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale) + if return_pil_images: + # do some magic - if the user passed in a string text, or a list of strings + # assume they do not know anything about tensors and return PIL Image(s) + images = list(map(self.to_pil, images.unbind(dim = 0))) + if one_text: return images[0] return images + diff --git a/setup.py b/setup.py index 54ba910..03dc703 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.58', + version = '0.0.59', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',