diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 02734ea..ad39dfc 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1,7 +1,7 @@ import math from tqdm import tqdm from inspect import isfunction -from functools import partial +from functools import partial, wraps from contextlib import contextmanager from collections import namedtuple from pathlib import Path @@ -45,6 +45,14 @@ def exists(val): def identity(t, *args, **kwargs): return t +def maybe(fn): + @wraps(fn) + def inner(x): + if not exists(x): + return x + return fn(x) + return inner + def default(val, d): if exists(val): return val @@ -1844,6 +1852,8 @@ class Decoder(BaseGaussianDiffusion): b = shape[0] img = torch.randn(shape, device = device) + lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img) + for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): img = self.p_sample( unet, @@ -1868,9 +1878,7 @@ class Decoder(BaseGaussianDiffusion): # normalize to [-1, 1] x_start = normalize_neg_one_to_one(x_start) - - if exists(lowres_cond_img): - lowres_cond_img = normalize_neg_one_to_one(lowres_cond_img) + lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img) # get x_t diff --git a/setup.py b/setup.py index a078830..4c9a312 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.2.16', + version = '0.2.17', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',