some outlines to the eventual CLI endpoint

This commit is contained in:
Phil Wang
2022-04-24 09:27:15 -07:00
parent d5318aef4f
commit 473808850a
2 changed files with 49 additions and 2 deletions

View File

@@ -1451,6 +1451,7 @@ class DALLE2(nn.Module):
cond_scale = 1.
):
device = next(self.parameters()).device
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
if isinstance(text, str) or is_list_str(text):
text = [text] if not isinstance(text, (list, tuple)) else text
@@ -1458,4 +1459,8 @@ class DALLE2(nn.Module):
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
if one_text:
return images[0]
return images