diff --git a/dalle2_pytorch/cli.py b/dalle2_pytorch/cli.py
index 9e89f0d..aca5267 100644
--- a/dalle2_pytorch/cli.py
+++ b/dalle2_pytorch/cli.py
@@ -1,9 +1,52 @@
 import click
+import torch
+import torchvision.transforms as T
+from functools import reduce
+from pathlib import Path
+
+from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
+
+def safeget(dictionary, keys, default = None):
+    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
+
+def simple_slugify(text, max_length = 255):
+    return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
+
+def get_pkg_version():
+    from pkg_resources import get_distribution
+    return get_distribution('dalle2_pytorch').version
 
 def main():
     pass
 
 @click.command()
+@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
+@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
 @click.argument('text')
-def dream(text):
-    return 'not ready yet'
+def dream(
+    model,
+    cond_scale,
+    text
+):
+    model_path = Path(model)
+    full_model_path = str(model_path.resolve())
+    assert model_path.exists(), f'model not found at {full_model_path}'
+    loaded = torch.load(str(model_path))
+
+    version = safeget(loaded, 'version')
+    print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
+
+    prior_init_params = safeget(loaded, 'init_params.prior')
+    decoder_init_params = safeget(loaded, 'init_params.decoder')
+    model_params = safeget(loaded, 'model_params')
+
+    prior = DiffusionPrior(**prior_init_params)
+    decoder = Decoder(**decoder_init_params)
+
+    dalle2 = DALLE2(prior, decoder)
+    dalle2.load_state_dict(model_params)
+
+    image = dalle2(text, cond_scale = cond_scale)
+
+    pil_image = T.ToPILImage()(image)
+    return pil_image.save(f'./{simple_slugify(text)}.png')
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index cf0870f..35f3d1a 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1451,6 +1451,7 @@ class DALLE2(nn.Module):
         cond_scale = 1.
     ):
         device = next(self.parameters()).device
+        one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
 
         if isinstance(text, str) or is_list_str(text):
             text = [text] if not isinstance(text, (list, tuple)) else text
@@ -1458,4 +1459,8 @@ class DALLE2(nn.Module):
         image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
 
         images = self.decoder.sample(image_embed, cond_scale = cond_scale)
+
+        if one_text:
+            return images[0]
+
         return images