mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
add ability for DALL-E2 to return PIL images with return_pil_images = True on forward, for those who have no clue about deep learning
This commit is contained in:
@@ -7,6 +7,7 @@ from contextlib import contextmanager
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
import torchvision.transforms as T
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -1516,12 +1517,15 @@ class DALLE2(nn.Module):
|
||||
self.prior_num_samples = prior_num_samples
|
||||
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
|
||||
|
||||
self.to_pil = T.ToPILImage()
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def forward(
|
||||
self,
|
||||
text,
|
||||
cond_scale = 1.
|
||||
cond_scale = 1.,
|
||||
return_pil_images = False
|
||||
):
|
||||
device = next(self.parameters()).device
|
||||
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
|
||||
@@ -1535,7 +1539,11 @@ class DALLE2(nn.Module):
|
||||
text_cond = text if self.decoder_need_text_cond else None
|
||||
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
|
||||
|
||||
if return_pil_images:
|
||||
images = list(map(self.to_pil, images.unbind(dim = 0)))
|
||||
|
||||
if one_text:
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
|
||||
Reference in New Issue
Block a user