Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-12 11:34:29 +01:00
Compare commits
7 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 9440411954 | |
| | 981d407792 | |
| | 7c5477b26d | |
| | be3bb868bf | |
| | 451de34871 | |
| | f22e8c8741 | |
| | 87432e93ad | |
README.md (23 additions)
````diff
 @@ -627,6 +627,18 @@ images = dalle2(
 # save your image (in this example, of size 256x256)
 ```
+
+Alternatively, you can also use <a href="https://github.com/mlfoundations/open_clip">Open Clip</a>
+
+```bash
+$ pip install open-clip-torch
+```
+
+```python
+from dalle2_pytorch import OpenClipAdapter
+
+clip = OpenClipAdapter()
+```
 
 Now you'll just have to worry about training the Prior and the Decoder!
 
 ## Inpainting
````
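For context on the hunk above: swapping in Open Clip only changes how the CLIP adapter is constructed; it is then passed to the prior (and decoder) exactly like `OpenAIClipAdapter` elsewhere in the README. A minimal sketch following the README's existing prior example (hyperparameters are illustrative, and `OpenClipAdapter()` downloads the laion400m_e32 ViT-B/32 weights on first use):

```python
import torch
from dalle2_pytorch import OpenClipAdapter, DiffusionPriorNetwork, DiffusionPrior

clip = OpenClipAdapter()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
)

# toy batch of tokenized text and images, shapes as in the README example
text = torch.randint(0, 49408, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = diffusion_prior(text, images)
loss.backward()
```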
````diff
 @@ -1241,4 +1253,15 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
+
+```bibtex
+@misc{chen2022analog,
+    title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
+    author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
+    year = {2022},
+    eprint = {2208.04202},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
 
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
````
dalle2_pytorch/dalle2_pytorch.py

```diff
 @@ -8,6 +8,7 @@ from pathlib import Path
 
 import torch
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 from torch import nn, einsum
 import torchvision.transforms as T
 
 @@ -108,6 +109,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
         return t
     return (*t, *((fillvalue,) * remain_length))
+
+# checkpointing helper function
+
+def make_checkpointable(fn, **kwargs):
+    if isinstance(fn, nn.ModuleList):
+        return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
+
+    condition = kwargs.pop('condition', None)
+
+    if exists(condition) and not condition(fn):
+        return fn
+
+    @wraps(fn)
+    def inner(*args):
+        input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
+
+        if not input_needs_grad:
+            return fn(*args)
+
+        return checkpoint(fn, *args)
+
+    return inner
 
 # for controlling freezing of CLIP
 
 def set_module_requires_grad_(module, requires_grad):
```
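For orientation, the new `make_checkpointable` helper wraps a module or function with `torch.utils.checkpoint`, but skips checkpointing whenever no input tensor requires gradients, so inference-only calls do not pay the recompute cost. A minimal sketch of that pattern (the `nn.ModuleList` and `condition` handling from the diff are dropped here, and the toy module is made up):

```python
from functools import wraps

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

def make_checkpointable(fn):
    # checkpoint only when at least one tensor argument carries gradients,
    # otherwise fall through to a plain call
    @wraps(fn)
    def inner(*args):
        input_needs_grad = any(isinstance(el, torch.Tensor) and el.requires_grad for el in args)
        if not input_needs_grad:
            return fn(*args)
        return checkpoint(fn, *args, use_reentrant = False)
    return inner

block = make_checkpointable(nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16)))

x = torch.randn(2, 16, requires_grad = True)
block(x).sum().backward()            # activations are recomputed during the backward pass

with torch.no_grad():
    _ = block(torch.randn(2, 16))    # plain forward, checkpointing is skipped
```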
```diff
 @@ -339,6 +362,75 @@ class OpenAIClipAdapter(BaseClipAdapter):
         image_embed = self.clip.encode_image(image)
         return EmbeddedImage(l2norm(image_embed.float()), None)
 
+class OpenClipAdapter(BaseClipAdapter):
+    def __init__(
+        self,
+        name = 'ViT-B/32',
+        pretrained = 'laion400m_e32'
+    ):
+        import open_clip
+        clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
+
+        super().__init__(clip)
+        self.eos_id = 49407
+
+        text_attention_final = self.find_layer('ln_final')
+        self.handle = text_attention_final.register_forward_hook(self._hook)
+        self.clip_normalize = preprocess.transforms[-1]
+        self.cleared = False
+
+    def find_layer(self, layer):
+        modules = dict([*self.clip.named_modules()])
+        return modules.get(layer, None)
+
+    def clear(self):
+        if self.cleared:
+            return
+
+        self.handle()
+
+    def _hook(self, _, inputs, outputs):
+        self.text_encodings = outputs
+
+    @property
+    def dim_latent(self):
+        return 512
+
+    @property
+    def image_size(self):
+        return self.clip.visual.image_size
+
+    @property
+    def image_channels(self):
+        return 3
+
+    @property
+    def max_text_len(self):
+        return self.clip.context_length
+
+    @torch.no_grad()
+    def embed_text(self, text):
+        text = text[..., :self.max_text_len]
+
+        is_eos_id = (text == self.eos_id)
+        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
+        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        assert not self.cleared
+
+        text_embed = self.clip.encode_text(text)
+        text_encodings = self.text_encodings
+        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
+        del self.text_encodings
+        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
+
+    @torch.no_grad()
+    def embed_image(self, image):
+        assert not self.cleared
+        image = self.validate_and_resize_image(image)
+        image = self.clip_normalize(image)
+        image_embed = self.clip.encode_image(image)
+        return EmbeddedImage(l2norm(image_embed.float()), None)
+
 # classifier free guidance functions
 
 def prob_mask_like(shape, prob, device):
```
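The new `OpenClipAdapter` pulls the per-token text encodings out of open_clip by registering a forward hook on the text tower's final LayerNorm (`ln_final`), since `encode_text` only returns the pooled embedding. A stripped-down illustration of that hook pattern on a made-up toy encoder (only `register_forward_hook` is real PyTorch API; everything else here is illustrative):

```python
import torch
from torch import nn

class ToyTextEncoder(nn.Module):
    # stand-in for a CLIP text tower: per-token features -> pooled embedding
    def __init__(self, dim = 32, vocab = 100):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.ln_final = nn.LayerNorm(dim)
        self.to_pooled = nn.Linear(dim, dim)

    def forward(self, tokens):
        x = self.ln_final(self.embed(tokens))
        return self.to_pooled(x.mean(dim = 1))   # only the pooled embedding is returned

captured = {}

def hook(module, inputs, output):
    # stash the per-token encodings that the model would otherwise discard
    captured['text_encodings'] = output

encoder = ToyTextEncoder()
handle = encoder.ln_final.register_forward_hook(hook)

tokens = torch.randint(0, 100, (2, 8))
pooled = encoder(tokens)

print(pooled.shape)                         # torch.Size([2, 32])
print(captured['text_encodings'].shape)     # torch.Size([2, 8, 32])

handle.remove()   # detach the hook once it is no longer needed
```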
```diff
 @@ -778,7 +870,7 @@ class Attention(nn.Module):
 
         # attention
 
-        attn = sim.softmax(dim = -1)
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
         attn = self.dropout(attn)
 
         # aggregate values
```
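The attention change above (repeated for `CrossAttention` further down) pins the softmax reduction to float32, so that half-precision training does not lose accuracy in the normalization step. A minimal sketch of the pattern, with made-up shapes:

```python
import torch

# half-precision attention logits, as they would appear under fp16 / autocast training
sim = torch.randn(2, 8, 128, 128, dtype = torch.float16)

attn = sim.softmax(dim = -1, dtype = torch.float32)   # upcast just for the softmax
attn = attn.to(sim.dtype)                             # cast back before multiplying with values

assert attn.dtype == torch.float16
```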
```diff
 @@ -845,9 +937,12 @@ class DiffusionPriorNetwork(nn.Module):
         num_image_embeds = 1,
         num_text_embeds = 1,
         max_text_len = 256,
+        self_cond = False,
         **kwargs
     ):
         super().__init__()
+        self.dim = dim
+
         self.num_time_embeds = num_time_embeds
         self.num_image_embeds = num_image_embeds
         self.num_text_embeds = num_text_embeds
 
 @@ -875,6 +970,10 @@ class DiffusionPriorNetwork(nn.Module):
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
 
+        # whether to use self conditioning, Hinton's group's new ddpm technique
+
+        self.self_cond = self_cond
+
     def forward_with_cond_scale(
         self,
         *args,
 
 @@ -896,12 +995,19 @@ class DiffusionPriorNetwork(nn.Module):
         *,
         text_embed,
         text_encodings = None,
+        self_cond = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
 
         num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
 
+        # setup self conditioning
+
+        self_cond = None
+        if self.self_cond:
+            self_cond = default(self_cond, lambda: torch.zeros(batch, 1, self.dim, device = device, dtype = dtype))
+
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
 
 @@ -951,13 +1057,16 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
-        attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
+        attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
         mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
 
         time_embed = self.to_time_embeds(diffusion_timesteps)
 
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
 
+        if self.self_cond:
+            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+
         tokens = torch.cat((
             text_encodings,
             text_embed,
```
```diff
 @@ -1059,45 +1168,50 @@ class DiffusionPrior(nn.Module):
     def l2norm_clamp_embed(self, image_embed):
         return l2norm(image_embed) * self.image_embed_scale
 
-    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+    def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
-        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
+        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
 
         if self.predict_x_start:
-            x_recon = pred
+            x_start = pred
         else:
-            x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
+            x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
 
         if clip_denoised and not self.predict_x_start:
-            x_recon.clamp_(-1., 1.)
+            x_start.clamp_(-1., 1.)
 
         if self.predict_x_start and self.sampling_clamp_l2norm:
-            x_recon = l2norm(x_recon) * self.image_embed_scale
+            x_start = l2norm(x_start) * self.image_embed_scale
 
-        model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
-        return model_mean, posterior_variance, posterior_log_variance
+        model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
+    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        return pred, x_start
 
     @torch.no_grad()
     def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
         batch, device = shape[0], self.device
 
         image_embed = torch.randn(shape, device = device)
+        x_start = None # for self-conditioning
 
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
-            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
+
+            self_cond = x_start if self.net.self_cond else None
+            image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
 
         if self.sampling_final_clamp_l2norm and self.predict_x_start:
             image_embed = self.l2norm_clamp_embed(image_embed)
 
 @@ -1115,6 +1229,8 @@ class DiffusionPrior(nn.Module):
 
         image_embed = torch.randn(shape, device = device)
 
+        x_start = None # for self-conditioning
+
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
 @@ -1124,7 +1240,9 @@ class DiffusionPrior(nn.Module):
 
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
 
-            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+            self_cond = x_start if self.net.self_cond else None
+
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
 
             if self.predict_x_start:
                 x_start = pred
 
 @@ -1168,9 +1286,15 @@ class DiffusionPrior(nn.Module):
 
         image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
 
+        self_cond = None
+        if self.net.self_cond and random.random() < 0.5:
+            with torch.no_grad():
+                self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
+
         pred = self.net(
             image_embed_noisy,
             times,
+            self_cond = self_cond,
             cond_drop_prob = self.cond_drop_prob,
             **text_cond
         )
```
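The training hunk above mirrors the Analog Bits recipe: on roughly half of the training steps the network first predicts the clean target without gradients, and that detached estimate is fed back in as extra conditioning; at sampling time the previous step's estimate is reused. A self-contained toy version of the idea (the denoiser and shapes are made up, not the library's classes):

```python
import random
import torch
from torch import nn

class ToyDenoiser(nn.Module):
    # toy stand-in: predicts x_start from (noised input, timestep, previous x_start estimate)
    def __init__(self, dim = 64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim * 2 + 1, dim * 2), nn.SiLU(), nn.Linear(dim * 2, dim))

    def forward(self, x_noisy, t, self_cond = None):
        if self_cond is None:
            self_cond = torch.zeros_like(x_noisy)            # null conditioning on the first pass
        t = t.float()[:, None] / 1000.
        return self.net(torch.cat((x_noisy, self_cond, t), dim = -1))

model = ToyDenoiser()

# training step: half the time, condition on the model's own detached prediction
x_start = torch.randn(4, 64)
t = torch.randint(0, 1000, (4,))
x_noisy = x_start + torch.randn_like(x_start)                # stand-in for q_sample

self_cond = None
if random.random() < 0.5:
    with torch.no_grad():
        self_cond = model(x_noisy, t).detach()

loss = (model(x_noisy, t, self_cond = self_cond) - x_start).pow(2).mean()
loss.backward()

# sampling: thread the previous step's x_start estimate into the next denoising call
with torch.no_grad():
    x, x_start_est = torch.randn(4, 64), None
    for step in reversed(range(50)):
        t = torch.full((4,), step, dtype = torch.long)
        x_start_est = model(x, t, self_cond = x_start_est)
        # ...form the posterior mean from x_start_est and resample x (omitted)
```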
```diff
 @@ -1479,7 +1603,7 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
 
-        attn = sim.softmax(dim = -1)
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
 
 @@ -1490,7 +1614,8 @@ class LinearAttention(nn.Module):
         self,
         dim,
         dim_head = 32,
-        heads = 8
+        heads = 8,
+        **kwargs
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
```
```diff
 @@ -1607,6 +1732,7 @@ class Unet(nn.Module):
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
+        self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
         sparse_attn = False,
         cosine_sim_cross_attn = False,
         cosine_sim_self_attn = False,
 
 @@ -1628,6 +1754,7 @@ class Unet(nn.Module):
         pixel_shuffle_upsample = True,
         final_conv_kernel_size = 1,
         combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
+        checkpoint_during_training = False,
         **kwargs
     ):
         super().__init__()
 
 @@ -1641,12 +1768,21 @@ class Unet(nn.Module):
 
         self.lowres_cond = lowres_cond
 
+        # whether to do self conditioning
+
+        self.self_cond = self_cond
+
         # determine dimensions
 
         self.channels = channels
         self.channels_out = default(channels_out, channels)
 
-        init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
+        # initial number of channels depends on
+        # (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade
+        # (2) self conditioning (bit diffusion paper)
+
+        init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
+
         init_dim = default(init_dim, dim)
 
         self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
 
 @@ -1838,6 +1974,10 @@ class Unet(nn.Module):
 
         zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
 
+        # whether to checkpoint during training
+
+        self.checkpoint_during_training = checkpoint_during_training
+
     # if the current settings for the unet are not correct
     # for cascading DDPM, then reinit the unet with the right settings
     def cast_model_parameters(
 
 @@ -1895,7 +2035,9 @@ class Unet(nn.Module):
         image_cond_drop_prob = 0.,
         text_cond_drop_prob = 0.,
         blur_sigma = None,
-        blur_kernel_size = None
+        blur_kernel_size = None,
+        disable_checkpoint = False,
+        self_cond = None
     ):
         batch_size, device = x.shape[0], x.device
 
 @@ -1903,6 +2045,14 @@ class Unet(nn.Module):
 
         assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
 
+        # concat self conditioning, if needed
+
+        if self.self_cond:
+            self_cond = default(self_cond, lambda: torch.zeros_like(x))
+            x = torch.cat((x, self_cond), dim = 1)
+
+        # concat low resolution conditioning
+
         if exists(lowres_cond_img):
             x = torch.cat((x, lowres_cond_img), dim = 1)
 
```
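Self-conditioning for the decoder U-Net is plumbed in purely through the channel dimension: the previous x_start estimate (or zeros) is concatenated onto the noisy input, which is why `init_channels` becomes `channels * (1 + int(lowres_cond) + int(self_cond))`. A toy shape check, with made-up sizes:

```python
import torch
from torch import nn

channels    = 3
lowres_cond = True     # conditioning on an upsampled low-res image (cascading DDPM)
self_cond   = True     # conditioning on the previous x_start estimate (bit diffusion)

init_channels = channels * (1 + int(lowres_cond) + int(self_cond))   # 9 in this example
init_conv = nn.Conv2d(init_channels, 128, 3, padding = 1)

x           = torch.randn(2, channels, 64, 64)   # noisy image
self_cond_x = torch.zeros_like(x)                # null estimate on the first step
lowres_img  = torch.randn(2, channels, 64, 64)   # upsampled low-res conditioning image

x = torch.cat((x, self_cond_x), dim = 1)
x = torch.cat((x, lowres_img), dim = 1)

print(init_conv(x).shape)   # torch.Size([2, 128, 64, 64])
```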
```diff
 @@ -2017,17 +2167,29 @@ class Unet(nn.Module):
         c = self.norm_cond(c)
         mid_c = self.norm_mid_cond(mid_c)
 
+        # gradient checkpointing
+
+        can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
+        apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
+
+        # make checkpointable modules
+
+        init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
+
+        can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
+        downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
+
         # initial resnet block
 
-        if exists(self.init_resnet_block):
-            x = self.init_resnet_block(x, t)
+        if exists(init_resnet_block):
+            x = init_resnet_block(x, t)
 
         # go through the layers of the unet, down and up
 
         down_hiddens = []
         up_hiddens = []
 
-        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
+        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
             if exists(pre_downsample):
                 x = pre_downsample(x)
 
 @@ -2043,16 +2205,16 @@ class Unet(nn.Module):
             if exists(post_downsample):
                 x = post_downsample(x)
 
-        x = self.mid_block1(x, t, mid_c)
+        x = mid_block1(x, t, mid_c)
 
-        if exists(self.mid_attn):
-            x = self.mid_attn(x)
+        if exists(mid_attn):
+            x = mid_attn(x)
 
-        x = self.mid_block2(x, t, mid_c)
+        x = mid_block2(x, t, mid_c)
 
         connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
 
-        for init_block, resnet_blocks, attn, upsample in self.ups:
+        for init_block, resnet_blocks, attn, upsample in ups:
             x = connect_skip(x)
             x = init_block(x, t, c)
 
 @@ -2069,7 +2231,7 @@ class Unet(nn.Module):
 
         x = torch.cat((x, r), dim = 1)
 
-        x = self.final_resnet_block(x, t)
+        x = final_resnet_block(x, t)
 
         if exists(lowres_cond_img):
             x = torch.cat((x, lowres_cond_img), dim = 1)
```
```diff
 @@ -2460,23 +2622,23 @@ class Decoder(nn.Module):
         x = x.clamp(-s, s) / s
         return x
 
-    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
-        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level))
+        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
 
         if learned_variance:
             pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
 
         if predict_x_start:
-            x_recon = pred
+            x_start = pred
         else:
-            x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
+            x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
 
         if clip_denoised:
-            x_recon = self.dynamic_threshold(x_recon)
+            x_start = self.dynamic_threshold(x_start)
 
-        model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
+        model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
 
         if learned_variance:
             # if learned variance, posterio variance and posterior log variance are predicted by the network
 
 @@ -2492,16 +2654,17 @@ class Decoder(nn.Module):
             posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
             posterior_variance = posterior_log_variance.exp()
 
-        return model_mean, posterior_variance, posterior_log_variance
+        return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        return pred, x_start
 
     @torch.no_grad()
     def p_sample_loop_ddpm(
 
 @@ -2527,6 +2690,8 @@ class Decoder(nn.Module):
         b = shape[0]
         img = torch.randn(shape, device = device)
 
+        x_start = None # for self-conditioning
+
         is_inpaint = exists(inpaint_image)
         resample_times = inpaint_resample_times if is_inpaint else 1
 
 @@ -2554,13 +2719,16 @@ class Decoder(nn.Module):
                     noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
                     img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
 
-                img = self.p_sample(
+                self_cond = x_start if unet.self_cond else None
+
+                img, x_start = self.p_sample(
                     unet,
                     img,
                     times,
                     image_embed = image_embed,
                     text_encodings = text_encodings,
                     cond_scale = cond_scale,
+                    self_cond = self_cond,
                     lowres_cond_img = lowres_cond_img,
                     lowres_noise_level = lowres_noise_level,
                     predict_x_start = predict_x_start,
 
 @@ -2619,6 +2787,8 @@ class Decoder(nn.Module):
 
         img = torch.randn(shape, device = device)
 
+        x_start = None # for self-conditioning
+
         if not is_latent_diffusion:
             lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
 
 @@ -2639,7 +2809,9 @@ class Decoder(nn.Module):
                 noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
                 img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
 
-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+            self_cond = x_start if unet.self_cond else None
+
+            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
 
             if learned_variance:
                 pred, _ = pred.chunk(2, dim = 1)
```
```diff
 @@ -2699,13 +2871,35 @@ class Decoder(nn.Module):
 
         x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
 
-        model_output = unet(
-            x_noisy,
-            times,
+        # unet kwargs
+
+        unet_kwargs = dict(
             image_embed = image_embed,
             text_encodings = text_encodings,
             lowres_cond_img = lowres_cond_img,
             lowres_noise_level = lowres_noise_level,
+        )
+
+        # self conditioning
+
+        self_cond = None
+
+        if unet.self_cond and random.random() < 0.5:
+            with torch.no_grad():
+                self_cond = unet(x_noisy, times, **unet_kwargs)
+
+                if learned_variance:
+                    self_cond, _ = self_cond.chunk(2, dim = 1)
+
+                self_cond = self_cond.detach()
+
+        # forward to get model prediction
+
+        model_output = unet(
+            x_noisy,
+            times,
+            **unet_kwargs,
+            self_cond = self_cond,
             image_cond_drop_prob = self.image_cond_drop_prob,
             text_cond_drop_prob = self.text_cond_drop_prob,
         )
 
 @@ -2736,7 +2930,7 @@ class Decoder(nn.Module):
         # if learning the variance, also include the extra weight kl loss
 
         true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+        model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
 
         # kl loss with detached model predicted mean, for stability reasons as in paper
 
```
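One detail in the decoder's self-conditioning pass above: when the U-Net also learns the variance, its output has twice the channels (an image prediction plus a variance interpolation fraction), so the first, no-grad prediction has to be chunked and only the image half is fed back as conditioning. A toy illustration of that chunking, with made-up shapes:

```python
import torch

channels = 3
# a learned-variance unet predicts 2 * channels maps: the first half is the
# x_start / noise prediction, the second half parameterizes the variance
model_output = torch.randn(2, channels * 2, 64, 64)

self_cond, _ = model_output.chunk(2, dim = 1)   # keep only the image prediction
self_cond = self_cond.detach()                  # no gradients flow through the first pass

print(self_cond.shape)   # torch.Size([2, 3, 64, 64])
```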
dalle2_pytorch/version.py

```diff
 @@ -1 +1 @@
-__version__ = '1.4.3'
+__version__ = '1.6.1'
```