Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
DRY a tiny bit for gaussian diffusion related logic
@@ -643,7 +643,8 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
 - [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
 - [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
-- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in - use inheritance just this once for sharing logic between decoder and prior network ddpms
+- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
+- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
@@ -143,6 +143,92 @@ def sigmoid_beta_schedule(timesteps):
     return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
 
+
+class BaseGaussianDiffusion(nn.Module):
+    def __init__(self, *, beta_schedule, timesteps, loss_type):
+        super().__init__()
+
+        if beta_schedule == "cosine":
+            betas = cosine_beta_schedule(timesteps)
+        elif beta_schedule == "linear":
+            betas = linear_beta_schedule(timesteps)
+        elif beta_schedule == "quadratic":
+            betas = quadratic_beta_schedule(timesteps)
+        elif beta_schedule == "jsd":
+            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
+        elif beta_schedule == "sigmoid":
+            betas = sigmoid_beta_schedule(timesteps)
+        else:
+            raise NotImplementedError()
+
+        alphas = 1. - betas
+        alphas_cumprod = torch.cumprod(alphas, axis = 0)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.loss_type = loss_type
+
+        self.register_buffer('betas', betas)
+        self.register_buffer('alphas_cumprod', alphas_cumprod)
+        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+
+        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+
+        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+        self.register_buffer('posterior_variance', posterior_variance)
+
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+
+        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
+        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
+
+    def q_mean_variance(self, x_start, t):
+        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def sample(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward(self, *args, **kwargs):
+        raise NotImplementedError
+
 # diffusion prior
 
 class LayerNorm(nn.Module):
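As a reference for what the registered buffers above encode, these are the standard DDPM identities (Ho et al. 2020) behind q_sample, predict_start_from_noise and q_posterior, in the usual notation where \(\bar\alpha_t = \prod_{s \le t} \alpha_s\):

x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \quad \text{(q\_sample)}

x_0 = \sqrt{1/\bar\alpha_t}\, x_t - \sqrt{1/\bar\alpha_t - 1}\, \epsilon \quad \text{(predict\_start\_from\_noise)}

q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(\tilde\mu_t,\ \tilde\beta_t I\big), \qquad
\tilde\mu_t = \frac{\beta_t \sqrt{\bar\alpha_{t-1}}}{1 - \bar\alpha_t}\, x_0 + \frac{(1 - \bar\alpha_{t-1}) \sqrt{\alpha_t}}{1 - \bar\alpha_t}\, x_t, \qquad
\tilde\beta_t = \frac{(1 - \bar\alpha_{t-1})\, \beta_t}{1 - \bar\alpha_t} \quad \text{(q\_posterior, i.e. posterior\_mean\_coef1/2 and posterior\_variance)}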
@@ -481,7 +567,7 @@ class DiffusionPriorNetwork(nn.Module):
 
         return pred_image_embed
 
-class DiffusionPrior(nn.Module):
+class DiffusionPrior(BaseGaussianDiffusion):
     def __init__(
         self,
         net,
@@ -497,7 +583,11 @@ class DiffusionPrior(nn.Module):
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
     ):
-        super().__init__()
+        super().__init__(
+            beta_schedule = beta_schedule,
+            timesteps = timesteps,
+            loss_type = loss_type
+        )
 
         if exists(clip):
             assert isinstance(clip, CLIP)
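To make the shape of this refactor concrete, here is a minimal, self-contained sketch of the same pattern: a shared Gaussian-diffusion base that owns the schedule buffers, and a subclass that only wires up its network and loss. The names (SimpleGaussianDiffusion, ToyPrior) and the trimmed-down schedule and extract helpers are illustrative stand-ins, not the repo's actual classes.

import torch
import torch.nn.functional as F
from torch import nn

def linear_beta_schedule(timesteps):
    # simplified stand-in for the repo's beta schedule helpers
    return torch.linspace(1e-4, 0.02, timesteps)

def extract(a, t, x_shape):
    # pick the per-timestep scalar a[t] for each sample and broadcast it over x_shape
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

class SimpleGaussianDiffusion(nn.Module):
    # all schedule bookkeeping lives here, registered as buffers
    def __init__(self, *, timesteps):
        super().__init__()
        betas = linear_beta_schedule(timesteps)
        alphas_cumprod = torch.cumprod(1. - betas, dim = 0)
        self.num_timesteps = timesteps
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))

    def q_sample(self, x_start, t, noise):
        # x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

class ToyPrior(SimpleGaussianDiffusion):
    # the subclass only adds its network; the schedule comes from the base via super().__init__
    def __init__(self, net, *, timesteps = 1000):
        super().__init__(timesteps = timesteps)
        self.net = net

    def forward(self, x_start):
        t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device = x_start.device)
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise)
        return F.mse_loss(self.net(x_noisy), noise)

prior = ToyPrior(nn.Linear(512, 512))
loss = prior(torch.randn(4, 512))
loss.backward()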
@@ -517,53 +607,6 @@ class DiffusionPrior(nn.Module):
         self.predict_x_start = predict_x_start
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
-        if beta_schedule == "cosine":
-            betas = cosine_beta_schedule(timesteps)
-        elif beta_schedule == "linear":
-            betas = linear_beta_schedule(timesteps)
-        elif beta_schedule == "quadratic":
-            betas = quadratic_beta_schedule(timesteps)
-        elif beta_schedule == "jsd":
-            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
-        elif beta_schedule == "sigmoid":
-            betas = sigmoid_beta_schedule(timesteps)
-        else:
-            raise NotImplementedError()
-
-        alphas = 1. - betas
-        alphas_cumprod = torch.cumprod(alphas, axis = 0)
-        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
-
-        timesteps, = betas.shape
-        self.num_timesteps = int(timesteps)
-        self.loss_type = loss_type
-
-        self.register_buffer('betas', betas)
-        self.register_buffer('alphas_cumprod', alphas_cumprod)
-        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
-
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-
-        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
-        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
-        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
-
-        # calculations for posterior q(x_{t-1} | x_t, x_0)
-
-        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
-
-        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
-
-        self.register_buffer('posterior_variance', posterior_variance)
-
-        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
-
-        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
-        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
-        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
-
     @torch.no_grad()
     def get_image_embed(self, image):
         assert exists(self.clip)
@@ -587,27 +630,6 @@ class DiffusionPrior(nn.Module):
 
         return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
 
-    def q_mean_variance(self, x_start, t):
-        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
-        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
-        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
-        return mean, variance, log_variance
-
-    def predict_start_from_noise(self, x_t, t, noise):
-        return (
-            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
-            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
-        )
-
-    def q_posterior(self, x_start, x_t, t):
-        posterior_mean = (
-            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
-            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
     def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
         pred = self.net(x, t, **text_cond)
 
@@ -644,14 +666,6 @@ class DiffusionPrior(nn.Module):
             img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
         return img
 
-    def q_sample(self, x_start, t, noise=None):
-        noise = default(noise, lambda: torch.randn_like(x_start))
-
-        return (
-            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
-            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
-        )
-
     def p_losses(self, image_embed, t, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))
 
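The methods deleted in the two hunks above are essentially the ones now inherited from BaseGaussianDiffusion, so behaviour should be unchanged. As a quick standalone sanity check of the identity they encode (a sketch using a made-up linear schedule, not code from the repo): corrupting x_0 with known noise via the q_sample formula and inverting it with the predict_start_from_noise formula recovers x_0 up to floating-point error.

import torch

timesteps = 1000
betas = torch.linspace(1e-4, 0.02, timesteps)      # made-up linear schedule
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)

x_start = torch.randn(2, 512)
noise = torch.randn_like(x_start)
t = torch.tensor([10, 500])
abar = alphas_cumprod[t].unsqueeze(-1)             # per-sample cumulative alpha, broadcast over features

# forward corruption, as in q_sample: x_t = sqrt(abar) * x_0 + sqrt(1 - abar) * eps
x_t = abar.sqrt() * x_start + (1. - abar).sqrt() * noise

# inversion, as in predict_start_from_noise: x_0 = sqrt(1/abar) * x_t - sqrt(1/abar - 1) * eps
x_recon = (1. / abar).sqrt() * x_t - (1. / abar - 1.).sqrt() * noise

assert torch.allclose(x_recon, x_start, atol = 1e-4)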
@@ -1164,7 +1178,7 @@ class LowresConditioner(nn.Module):
 
         return cond_fmap
 
-class Decoder(nn.Module):
+class Decoder(BaseGaussianDiffusion):
     def __init__(
         self,
         unet,
@@ -1184,7 +1198,12 @@ class Decoder(nn.Module):
         blur_kernel_size = 3, # cascading ddpm - blur kernel size
         condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
     ):
-        super().__init__()
+        super().__init__(
+            beta_schedule = beta_schedule,
+            timesteps = timesteps,
+            loss_type = loss_type
+        )
+
         assert isinstance(clip, CLIP)
         freeze_model_and_make_eval_(clip)
         self.clip = clip
@@ -1248,55 +1267,6 @@ class Decoder(nn.Module):
 
         self.cond_drop_prob = cond_drop_prob
 
-        # noise schedule
-
-        if beta_schedule == "cosine":
-            betas = cosine_beta_schedule(timesteps)
-        elif beta_schedule == "linear":
-            betas = linear_beta_schedule(timesteps)
-        elif beta_schedule == "quadratic":
-            betas = quadratic_beta_schedule(timesteps)
-        elif beta_schedule == "jsd":
-            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
-        elif beta_schedule == "sigmoid":
-            betas = sigmoid_beta_schedule(timesteps)
-        else:
-            raise NotImplementedError()
-
-        alphas = 1. - betas
-        alphas_cumprod = torch.cumprod(alphas, axis = 0)
-        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
-
-        timesteps, = betas.shape
-        self.num_timesteps = int(timesteps)
-        self.loss_type = loss_type
-
-        self.register_buffer('betas', betas)
-        self.register_buffer('alphas_cumprod', alphas_cumprod)
-        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
-
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-
-        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
-        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
-        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
-
-        # calculations for posterior q(x_{t-1} | x_t, x_0)
-
-        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
-
-        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
-
-        self.register_buffer('posterior_variance', posterior_variance)
-
-        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
-
-        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
-        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
-        self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
-
     def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
         index = unet_number - 1
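For readers unfamiliar with register_buffer, which both the block removed above and the base class that replaces it rely on: buffers are saved in the state dict and follow .to() device/dtype moves with the module, but they are not trainable parameters. A tiny illustrative check (ScheduleHolder is a hypothetical stand-in, not a class from the repo):

import torch
from torch import nn

class ScheduleHolder(nn.Module):
    # hypothetical stand-in: registers a noise schedule the same way the base class does
    def __init__(self, timesteps = 10):
        super().__init__()
        betas = torch.linspace(1e-4, 0.02, timesteps)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(1. - betas, dim = 0))

m = ScheduleHolder()
print('alphas_cumprod' in m.state_dict())   # True  - buffers are checkpointed
print(len(list(m.parameters())))            # 0     - but they are not trainable parameters
m = m.to(torch.float16)
print(m.betas.dtype)                        # torch.float16 - buffers follow .to() moves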
@@ -1329,27 +1299,6 @@ class Decoder(nn.Module):
         image_embed = self.clip.to_visual_latent(image_cls)
         return l2norm(image_embed)
 
-    def q_mean_variance(self, x_start, t):
-        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
-        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
-        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
-        return mean, variance, log_variance
-
-    def predict_start_from_noise(self, x_t, t, noise):
-        return (
-            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
-            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
-        )
-
-    def q_posterior(self, x_start, x_t, t):
-        posterior_mean = (
-            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
-            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
     def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
         pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
 
@@ -1394,14 +1343,6 @@ class Decoder(nn.Module):
 
         return img
 
-    def q_sample(self, x_start, t, noise=None):
-        noise = default(noise, lambda: torch.randn_like(x_start))
-
-        return (
-            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
-            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
-        )
-
     def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 