mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-15 16:24:20 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb662a62f3 | ||
|
|
587c8c9b44 | ||
|
|
aa900213e7 |
@@ -223,7 +223,18 @@ class BaseGaussianDiffusion(nn.Module):
|
|||||||
|
|
||||||
timesteps, = betas.shape
|
timesteps, = betas.shape
|
||||||
self.num_timesteps = int(timesteps)
|
self.num_timesteps = int(timesteps)
|
||||||
|
|
||||||
|
if loss_type == 'l1':
|
||||||
|
loss_fn = F.l1_loss
|
||||||
|
elif loss_type == 'l2':
|
||||||
|
loss_fn = F.mse_loss
|
||||||
|
elif loss_type == 'huber':
|
||||||
|
loss_fn = F.smooth_l1_loss
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
self.loss_type = loss_type
|
self.loss_type = loss_type
|
||||||
|
self.loss_fn = loss_fn
|
||||||
|
|
||||||
self.register_buffer('betas', betas)
|
self.register_buffer('betas', betas)
|
||||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||||
@@ -703,29 +714,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def p_losses(self, image_embed, t, text_cond, noise = None):
|
def p_losses(self, image_embed, times, text_cond, noise = None):
|
||||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||||
|
|
||||||
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
|
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
|
||||||
|
|
||||||
x_recon = self.net(
|
pred = self.net(
|
||||||
image_embed_noisy,
|
image_embed_noisy,
|
||||||
t,
|
times,
|
||||||
cond_drop_prob = self.cond_drop_prob,
|
cond_drop_prob = self.cond_drop_prob,
|
||||||
**text_cond
|
**text_cond
|
||||||
)
|
)
|
||||||
|
|
||||||
to_predict = noise if not self.predict_x_start else image_embed
|
target = noise if not self.predict_x_start else image_embed
|
||||||
|
|
||||||
if self.loss_type == 'l1':
|
|
||||||
loss = F.l1_loss(to_predict, x_recon)
|
|
||||||
elif self.loss_type == 'l2':
|
|
||||||
loss = F.mse_loss(to_predict, x_recon)
|
|
||||||
elif self.loss_type == "huber":
|
|
||||||
loss = F.smooth_l1_loss(to_predict, x_recon)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
loss = self.loss_fn(pred, target)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -1066,13 +1069,14 @@ class Unet(nn.Module):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
lowres_cond,
|
lowres_cond,
|
||||||
channels
|
channels,
|
||||||
|
cond_on_image_embeds
|
||||||
):
|
):
|
||||||
if lowres_cond == self.lowres_cond and channels == self.channels:
|
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
|
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
|
||||||
return self.__class__(**updated_kwargs)
|
return self.__class__(**{**self._locals, **updated_kwargs})
|
||||||
|
|
||||||
def forward_with_cond_scale(
|
def forward_with_cond_scale(
|
||||||
self,
|
self,
|
||||||
@@ -1209,7 +1213,7 @@ class LowresConditioner(nn.Module):
|
|||||||
target_image_size = cast_tuple(target_image_size, 2)
|
target_image_size = cast_tuple(target_image_size, 2)
|
||||||
|
|
||||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
# when training, blur the low resolution conditional image
|
# when training, blur the low resolution conditional image
|
||||||
@@ -1279,6 +1283,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
one_unet = one_unet.cast_model_parameters(
|
one_unet = one_unet.cast_model_parameters(
|
||||||
lowres_cond = not is_first,
|
lowres_cond = not is_first,
|
||||||
|
cond_on_image_embeds = is_first,
|
||||||
channels = unet_channels
|
channels = unet_channels
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1386,14 +1391,14 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||||
|
|
||||||
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||||
|
|
||||||
x_recon = unet(
|
pred = unet(
|
||||||
x_noisy,
|
x_noisy,
|
||||||
t,
|
times,
|
||||||
image_embed = image_embed,
|
image_embed = image_embed,
|
||||||
text_encodings = text_encodings,
|
text_encodings = text_encodings,
|
||||||
lowres_cond_img = lowres_cond_img,
|
lowres_cond_img = lowres_cond_img,
|
||||||
@@ -1402,15 +1407,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
target = noise if not predict_x_start else x_start
|
target = noise if not predict_x_start else x_start
|
||||||
|
|
||||||
if self.loss_type == 'l1':
|
loss = self.loss_fn(pred, target)
|
||||||
loss = F.l1_loss(target, x_recon)
|
|
||||||
elif self.loss_type == 'l2':
|
|
||||||
loss = F.mse_loss(target, x_recon)
|
|
||||||
elif self.loss_type == "huber":
|
|
||||||
loss = F.smooth_l1_loss(target, x_recon)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
Reference in New Issue
Block a user