diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 3d87c82..927d38b 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -35,6 +35,10 @@ def default(val, d): def cast_tuple(val, length = 1): return val if isinstance(val, tuple) else ((val,) * length) +@contextmanager +def null_context(*args, **kwargs): + yield + def eval_decorator(fn): def inner(model, *args, **kwargs): was_training = model.training @@ -1382,7 +1386,10 @@ class Decoder(BaseGaussianDiffusion): img = None for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)): - with self.one_unet_in_gpu(unet = unet): + + context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context() + + with context: lowres_cond_img = None shape = (batch_size, channel, image_size, image_size) diff --git a/setup.py b/setup.py index 316375e..ffd1ac5 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.54', + version = '0.0.55', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',