make sure cpu-only still works

This commit is contained in:
Phil Wang
2022-04-27 08:02:10 -07:00
parent 2705e7c9b0
commit fa3bb6ba5c
2 changed files with 9 additions and 2 deletions

View File

@@ -35,6 +35,10 @@ def default(val, d):
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
@contextmanager
def null_context(*args, **kwargs):
yield
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
@@ -1382,7 +1386,10 @@ class Decoder(BaseGaussianDiffusion):
img = None
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
with self.one_unet_in_gpu(unet = unet):
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
with context:
lowres_cond_img = None
shape = (batch_size, channel, image_size, image_size)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.54',
version = '0.0.55',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',