add ability for DALL-E2 to return PIL images with return_pil_images = True on forward, for those who have no clue about deep learning

This commit is contained in:
Phil Wang
2022-04-27 19:58:06 -07:00
parent 8c610aad9a
commit d0cdeb3247
2 changed files with 10 additions and 2 deletions

View File

@@ -7,6 +7,7 @@ from contextlib import contextmanager
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, einsum from torch import nn, einsum
import torchvision.transforms as T
from einops import rearrange, repeat from einops import rearrange, repeat
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
@@ -1516,12 +1517,15 @@ class DALLE2(nn.Module):
self.prior_num_samples = prior_num_samples self.prior_num_samples = prior_num_samples
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
self.to_pil = T.ToPILImage()
@torch.no_grad() @torch.no_grad()
@eval_decorator @eval_decorator
def forward( def forward(
self, self,
text, text,
cond_scale = 1. cond_scale = 1.,
return_pil_images = False
): ):
device = next(self.parameters()).device device = next(self.parameters()).device
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1) one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
@@ -1535,7 +1539,11 @@ class DALLE2(nn.Module):
text_cond = text if self.decoder_need_text_cond else None text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale) images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
if return_pil_images:
images = list(map(self.to_pil, images.unbind(dim = 0)))
if one_text: if one_text:
return images[0] return images[0]
return images return images

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.58', version = '0.0.59',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',