diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index faa7bc1..bdab5d3 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1,7 +1,6 @@
 import math
 import random
 from tqdm import tqdm
-from inspect import isfunction
 from functools import partial, wraps
 from contextlib import contextmanager
 from collections import namedtuple
@@ -57,7 +56,7 @@ def maybe(fn):
 def default(val, d):
     if exists(val):
         return val
-    return d() if isfunction(d) else d
+    return d() if callable(d) else d
 
 def cast_tuple(val, length = 1):
     if isinstance(val, list):
@@ -314,11 +313,6 @@ def extract(a, t, x_shape):
     out = a.gather(-1, t)
     return out.reshape(b, *((1,) * (len(x_shape) - 1)))
 
-def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
-    noise = lambda: torch.randn(shape, device=device)
-    return repeat_noise() if repeat else noise()
-
 def meanflat(x):
     return x.mean(dim = tuple(range(1, len(x.shape))))
 
@@ -946,10 +940,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
+    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
-        noise = noise_like(x.shape, device, repeat_noise)
+        noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@@ -1956,10 +1950,10 @@ class Decoder(BaseGaussianDiffusion):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
-        noise = noise_like(x.shape, device, repeat_noise)
+        noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
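
Note (not part of the patch): the sketch below illustrates why the two substitutions are drop-in replacements for the call sites touched above. `old_default` and `new_default` are simplified re-typings of the repo's `default` helper written here for comparison only; the `torch.randn_like` equivalence assumes the sampling loops never passed `repeat_noise=True`, which deleting the parameter implies.

```python
from functools import partial
from inspect import isfunction

import torch

def old_default(val, d):
    # re-typed from the pre-patch helper: only plain functions/lambdas are called
    return val if val is not None else (d() if isfunction(d) else d)

def new_default(val, d):
    # post-patch behavior: any callable (including functools.partial) is treated as a factory
    return val if val is not None else (d() if callable(d) else d)

# identical behavior for the cases the old code handled
assert old_default(5, lambda: 3) == new_default(5, lambda: 3) == 5
assert old_default(None, lambda: 3) == new_default(None, lambda: 3) == 3

# the only behavioral difference: partials (and other callable objects) are now evaluated
p = partial(int, "7")
assert old_default(None, p) is p      # isfunction(partial) is False -> returned unevaluated
assert new_default(None, p) == 7      # callable(partial) is True -> called

# torch.randn_like(x) matches the non-repeat branch of the removed noise_like:
# same shape and device as x (and additionally the same dtype)
x = torch.zeros(2, 3)
noise = torch.randn_like(x)
assert noise.shape == x.shape and noise.device == x.device and noise.dtype == x.dtype
```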