make sure cpu-only still works

This commit is contained in:
Phil Wang
2022-04-27 08:02:10 -07:00
parent 2705e7c9b0
commit fa3bb6ba5c
2 changed files with 9 additions and 2 deletions

View File

@@ -35,6 +35,10 @@ def default(val, d):
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
@contextmanager
def null_context(*args, **kwargs):
yield
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
@@ -1382,7 +1386,10 @@ class Decoder(BaseGaussianDiffusion):
img = None
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
with self.one_unet_in_gpu(unet = unet):
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
with context:
lowres_cond_img = None
shape = (batch_size, channel, image_size, image_size)