From d0cdeb324781c61b4d5227ea6e64b99568c7e6cb Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 27 Apr 2022 19:58:06 -0700
Subject: [PATCH] add ability for DALL-E2 to return PIL images with
 `return_pil_images = True` on forward, for those who have no clue about deep
 learning

---
 dalle2_pytorch/dalle2_pytorch.py | 10 +++++++++-
 setup.py                         |  2 +-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 00d5a6d..2a0477c 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -7,6 +7,7 @@ from contextlib import contextmanager
 import torch
 import torch.nn.functional as F
 from torch import nn, einsum
+import torchvision.transforms as T
 
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
@@ -1516,12 +1517,15 @@ class DALLE2(nn.Module):
         self.prior_num_samples = prior_num_samples
         self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
 
+        self.to_pil = T.ToPILImage()
+
     @torch.no_grad()
     @eval_decorator
     def forward(
         self,
         text,
-        cond_scale = 1.
+        cond_scale = 1.,
+        return_pil_images = False
     ):
         device = next(self.parameters()).device
         one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
@@ -1535,7 +1539,11 @@ class DALLE2(nn.Module):
         text_cond = text if self.decoder_need_text_cond else None
         images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
 
+        if return_pil_images:
+            images = list(map(self.to_pil, images.unbind(dim = 0)))
+
         if one_text:
             return images[0]
 
         return images
+
diff --git a/setup.py b/setup.py
index 54ba910..03dc703 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.58',
+  version = '0.0.59',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',