mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
make sure cpu-only still works
This commit is contained in:
@@ -35,6 +35,10 @@ def default(val, d):
|
|||||||
def cast_tuple(val, length = 1):
|
def cast_tuple(val, length = 1):
|
||||||
return val if isinstance(val, tuple) else ((val,) * length)
|
return val if isinstance(val, tuple) else ((val,) * length)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def null_context(*args, **kwargs):
|
||||||
|
yield
|
||||||
|
|
||||||
def eval_decorator(fn):
|
def eval_decorator(fn):
|
||||||
def inner(model, *args, **kwargs):
|
def inner(model, *args, **kwargs):
|
||||||
was_training = model.training
|
was_training = model.training
|
||||||
@@ -1382,7 +1386,10 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
img = None
|
img = None
|
||||||
|
|
||||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||||
with self.one_unet_in_gpu(unet = unet):
|
|
||||||
|
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
||||||
|
|
||||||
|
with context:
|
||||||
lowres_cond_img = None
|
lowres_cond_img = None
|
||||||
shape = (batch_size, channel, image_size, image_size)
|
shape = (batch_size, channel, image_size, image_size)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user