diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 00d5a6d..2a0477c 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -7,6 +7,7 @@ from contextlib import contextmanager
 import torch
 import torch.nn.functional as F
 from torch import nn, einsum
+import torchvision.transforms as T
 
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
@@ -1516,12 +1517,15 @@ class DALLE2(nn.Module):
         self.prior_num_samples = prior_num_samples
         self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
 
+        self.to_pil = T.ToPILImage()
+
     @torch.no_grad()
     @eval_decorator
     def forward(
         self,
         text,
-        cond_scale = 1.
+        cond_scale = 1.,
+        return_pil_images = False
     ):
         device = next(self.parameters()).device
         one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
@@ -1535,7 +1539,11 @@ class DALLE2(nn.Module):
         text_cond = text if self.decoder_need_text_cond else None
         images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
 
+        if return_pil_images:
+            images = list(map(self.to_pil, images.unbind(dim = 0)))
+
         if one_text:
             return images[0]
 
         return images
+
diff --git a/setup.py b/setup.py
index 54ba910..03dc703 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.58',
+  version = '0.0.59',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
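
A minimal usage sketch of the new return_pil_images flag (not part of the patch): it assumes `prior` and `decoder` are already-trained DiffusionPrior and Decoder instances, constructed as described in the repository README.

    from dalle2_pytorch import DALLE2

    dalle2 = DALLE2(
        prior = prior,     # trained DiffusionPrior (assumed available)
        decoder = decoder  # trained Decoder (assumed available)
    )

    # a single string prompt hits the `one_text` branch, so a single
    # PIL.Image.Image is returned instead of a list
    image = dalle2('a shiba inu wearing a beret', return_pil_images = True)
    image.save('./shiba.png')

    # a list of prompts returns a list of PIL images, one per prompt;
    # with return_pil_images = False (the default) the output stays a tensor
    images = dalle2(['cherry blossoms', 'a red barn'], return_pil_images = True)

Internally the flag just unbinds the sampled image tensor along the batch dimension and maps torchvision's ToPILImage over the results, so it only changes the return type, not the sampling itself.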