diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 7b1ba76..30dc082 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -152,23 +152,24 @@ def do_sample( with autocast(device) as precision_scope: with model.ema_scope(): num_samples = [num_samples] - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - num_samples, - ) - for key in batch: - if isinstance(batch[key], torch.Tensor): - print(key, batch[key].shape) - elif isinstance(batch[key], list): - print(key, [len(l) for l in batch[key]]) - else: - print(key, batch[key]) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) + with ModelOnDevice(model.conditioner, device): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) for k in c: if not k == "crossattn": @@ -292,16 +293,17 @@ def do_img2img( with torch.no_grad(): with autocast(device): with model.ema_scope(): - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [num_samples], - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) + with ModelOnDevice(model.conditioner, device): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [num_samples], + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) for k in c: c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) @@ -311,8 +313,11 @@ def do_img2img( if skip_encode: z = img else: - z = model.encode_first_stage(img) + with ModelOnDevice(model.first_stage_model, device): + z = model.encode_first_stage(img) + noise = torch.randn_like(z) + sigmas = sampler.discretization(sampler.num_steps) sigma = sigmas[0].to(z.device)