mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
optimize for clarity
This commit is contained in:
@@ -223,7 +223,18 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
|
||||
if loss_type == 'l1':
|
||||
loss_fn = F.l1_loss
|
||||
elif loss_type == 'l2':
|
||||
loss_fn = F.mse_loss
|
||||
elif loss_type == 'huber':
|
||||
loss_fn = F.smooth_l1_loss
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.loss_type = loss_type
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
@@ -703,29 +714,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
||||
return img
|
||||
|
||||
def p_losses(self, image_embed, t, text_cond, noise = None):
|
||||
def p_losses(self, image_embed, times, text_cond, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||
|
||||
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
|
||||
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
|
||||
|
||||
x_recon = self.net(
|
||||
pred = self.net(
|
||||
image_embed_noisy,
|
||||
t,
|
||||
times,
|
||||
cond_drop_prob = self.cond_drop_prob,
|
||||
**text_cond
|
||||
)
|
||||
|
||||
to_predict = noise if not self.predict_x_start else image_embed
|
||||
|
||||
if self.loss_type == 'l1':
|
||||
loss = F.l1_loss(to_predict, x_recon)
|
||||
elif self.loss_type == 'l2':
|
||||
loss = F.mse_loss(to_predict, x_recon)
|
||||
elif self.loss_type == "huber":
|
||||
loss = F.smooth_l1_loss(to_predict, x_recon)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
target = noise if not self.predict_x_start else image_embed
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -1388,14 +1391,14 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
return img
|
||||
|
||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
|
||||
x_recon = unet(
|
||||
pred = unet(
|
||||
x_noisy,
|
||||
t,
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
@@ -1404,15 +1407,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
if self.loss_type == 'l1':
|
||||
loss = F.l1_loss(target, x_recon)
|
||||
elif self.loss_type == 'l2':
|
||||
loss = F.mse_loss(target, x_recon)
|
||||
elif self.loss_type == "huber":
|
||||
loss = F.smooth_l1_loss(target, x_recon)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
Reference in New Issue
Block a user