Compare commits

..

12 Commits
1.2.1 ... 1.6.3

5 changed files with 328 additions and 86 deletions

View File

@@ -396,7 +396,7 @@ decoder = Decoder(
).cuda() ).cuda()
for unet_number in (1, 2): for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward() loss.backward()
# do above for many steps # do above for many steps
@@ -627,6 +627,18 @@ images = dalle2(
# save your image (in this example, of size 256x256) # save your image (in this example, of size 256x256)
``` ```
Alternatively, you can also use <a href="https://github.com/mlfoundations/open_clip">Open Clip</a>
```bash
$ pip install open-clip-torch
```
```python
from dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter()
```
Now you'll just have to worry about training the Prior and the Decoder! Now you'll just have to worry about training the Prior and the Decoder!
## Inpainting ## Inpainting
@@ -861,25 +873,23 @@ unet1 = Unet(
text_embed_dim = 512, text_embed_dim = 512,
cond_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
dim_mults=(1, 2, 4, 8) dim_mults=(1, 2, 4, 8),
cond_on_text_encodings = True,
).cuda() ).cuda()
unet2 = Unet( unet2 = Unet(
dim = 16, dim = 16,
image_embed_dim = 512, image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
dim_mults = (1, 2, 4, 8, 16), dim_mults = (1, 2, 4, 8, 16),
cond_on_text_encodings = True
).cuda() ).cuda()
decoder = Decoder( decoder = Decoder(
unet = (unet1, unet2), unet = (unet1, unet2),
image_sizes = (128, 256), image_sizes = (128, 256),
clip = clip, clip = clip,
timesteps = 1000, timesteps = 1000
condition_on_text_encodings = True
).cuda() ).cuda()
decoder_trainer = DecoderTrainer( decoder_trainer = DecoderTrainer(
@@ -904,8 +914,8 @@ for unet_number in (1, 2):
# after much training # after much training
# you can sample from the exponentially moving averaged unets as so # you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda() mock_image_embed = torch.randn(32, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256) images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (32, 3, 256, 256)
``` ```
### Diffusion Prior Training ### Diffusion Prior Training
@@ -1243,4 +1253,15 @@ For detailed information on training the diffusion prior, please refer to the [d
} }
``` ```
```bibtex
@misc{chen2022analog,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
year = {2022},
eprint = {2208.04202},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a> *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -8,6 +8,7 @@ from pathlib import Path
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch import nn, einsum from torch import nn, einsum
import torchvision.transforms as T import torchvision.transforms as T
@@ -108,6 +109,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
return t return t
return (*t, *((fillvalue,) * remain_length)) return (*t, *((fillvalue,) * remain_length))
# checkpointing helper function
def make_checkpointable(fn, **kwargs):
    """Return a version of `fn` whose calls go through `checkpoint` when gradients are needed.

    - If `fn` is an `nn.ModuleList`, a plain list is returned in which every
      element has been passed through `maybe(make_checkpointable)` with the
      same keyword arguments.
    - If a `condition` keyword argument is supplied and `condition(fn)` is
      falsy, `fn` is returned untouched.
    - Otherwise a wrapper is returned. The wrapper accepts positional
      arguments only; it calls `checkpoint(fn, *args)` when at least one
      argument is a tensor with `requires_grad` set, and `fn(*args)` otherwise.
    """
    if isinstance(fn, nn.ModuleList):
        # kwargs (including any `condition`) are forwarded as-is to each element
        return [maybe(make_checkpointable)(module, **kwargs) for module in fn]

    predicate = kwargs.pop('condition', None)
    if exists(predicate) and not predicate(fn):
        return fn

    @wraps(fn)
    def inner(*args):
        # checkpointing is only worthwhile if some input participates in autograd
        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.requires_grad:
                return checkpoint(fn, *args)
        return fn(*args)

    return inner
# for controlling freezing of CLIP # for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad): def set_module_requires_grad_(module, requires_grad):
@@ -339,6 +362,75 @@ class OpenAIClipAdapter(BaseClipAdapter):
image_embed = self.clip.encode_image(image) image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None) return EmbeddedImage(l2norm(image_embed.float()), None)
class OpenClipAdapter(BaseClipAdapter):
    """CLIP adapter backed by a model loaded through the `open_clip` package.

    A forward hook is registered on the module named `ln_final` so that its
    output can be captured as the per-token text encodings whenever
    `encode_text` runs.
    """

    def __init__(
        self,
        name = 'ViT-B/32',
        pretrained = 'laion400m_e32'
    ):
        # imported lazily so `open_clip` is only required when this adapter is used
        import open_clip
        clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)

        super().__init__(clip)

        # presumably the end-of-text token id of the CLIP tokenizer - TODO confirm
        self.eos_id = 49407

        text_attention_final = self.find_layer('ln_final')
        self.handle = text_attention_final.register_forward_hook(self._hook)

        # last transform of the preprocessing pipeline
        # (presumably the normalization step - verify against open_clip)
        self.clip_normalize = preprocess.transforms[-1]
        self.cleared = False

    def find_layer(self, layer):
        """Return the submodule of the wrapped clip model registered under `layer`, or None."""
        modules = dict(self.clip.named_modules())
        return modules.get(layer, None)

    def clear(self):
        """Detach the forward hook; calling this more than once is a no-op."""
        if self.cleared:
            return

        # the hook handle is not callable - it has to be detached with `.remove()`
        self.handle.remove()
        # record the removal so the guard above and the asserts in embed_* take effect
        self.cleared = True

    def _hook(self, _, inputs, outputs):
        # stash the hooked layer's output for `embed_text` to pick up
        self.text_encodings = outputs

    @property
    def dim_latent(self):
        # NOTE(review): hard-coded, not derived from the loaded model -
        # confirm it holds for values of `name` other than the default
        return 512

    @property
    def image_size(self):
        return self.clip.visual.image_size

    @property
    def image_channels(self):
        return 3

    @property
    def max_text_len(self):
        return self.clip.context_length

    @torch.no_grad()
    def embed_text(self, text):
        """Encode token ids, returning `EmbeddedText(l2-normed embedding, masked per-token encodings)`."""
        text = text[..., :self.max_text_len]

        # True for every position strictly before the first eos token
        is_eos_id = (text == self.eos_id)
        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
        # shift the mask right by one (pad left with True, drop the last column)
        # so that the eos position itself is kept
        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
        assert not self.cleared

        text_embed = self.clip.encode_text(text)

        # populated by `_hook` during the `encode_text` call above
        text_encodings = self.text_encodings
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        del self.text_encodings
        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())

    @torch.no_grad()
    def embed_image(self, image):
        """Resize and normalize `image`, returning `EmbeddedImage(l2-normed embedding, None)`."""
        assert not self.cleared
        image = self.validate_and_resize_image(image)
        image = self.clip_normalize(image)
        image_embed = self.clip.encode_image(image)
        return EmbeddedImage(l2norm(image_embed.float()), None)
# classifier free guidance functions # classifier free guidance functions
def prob_mask_like(shape, prob, device): def prob_mask_like(shape, prob, device):
@@ -547,34 +639,40 @@ class NoiseScheduler(nn.Module):
# diffusion prior # diffusion prior
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5, stable = False): def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
super().__init__() super().__init__()
self.eps = eps self.eps = eps
self.fp16_eps = fp16_eps
self.stable = stable self.stable = stable
self.g = nn.Parameter(torch.ones(dim)) self.g = nn.Parameter(torch.ones(dim))
def forward(self, x): def forward(self, x):
eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
if self.stable: if self.stable:
x = x / x.amax(dim = -1, keepdim = True).detach() x = x / x.amax(dim = -1, keepdim = True).detach()
var = torch.var(x, dim = -1, unbiased = False, keepdim = True) var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True) mean = torch.mean(x, dim = -1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g return (x - mean) * (var + eps).rsqrt() * self.g
class ChanLayerNorm(nn.Module): class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5, stable = False): def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
super().__init__() super().__init__()
self.eps = eps self.eps = eps
self.fp16_eps = fp16_eps
self.stable = stable self.stable = stable
self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x): def forward(self, x):
eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
if self.stable: if self.stable:
x = x / x.amax(dim = 1, keepdim = True).detach() x = x / x.amax(dim = 1, keepdim = True).detach()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True) var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True) mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g return (x - mean) * (var + eps).rsqrt() * self.g
class Residual(nn.Module): class Residual(nn.Module):
def __init__(self, fn): def __init__(self, fn):
@@ -695,11 +793,12 @@ class Attention(nn.Module):
dropout = 0., dropout = 0.,
causal = False, causal = False,
rotary_emb = None, rotary_emb = None,
pb_relax_alpha = 128 cosine_sim = True,
cosine_sim_scale = 16
): ):
super().__init__() super().__init__()
self.pb_relax_alpha = pb_relax_alpha self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1) self.cosine_sim = cosine_sim
self.heads = heads self.heads = heads
inner_dim = dim_head * heads inner_dim = dim_head * heads
@@ -739,6 +838,13 @@ class Attention(nn.Module):
k = torch.cat((nk, k), dim = -2) k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2) v = torch.cat((nv, v), dim = -2)
# whether to use cosine sim
if self.cosine_sim:
q, k = map(l2norm, (q, k))
q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
# calculate query / key similarities # calculate query / key similarities
sim = einsum('b h i d, b j d -> b h i j', q, k) sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -764,10 +870,7 @@ class Attention(nn.Module):
# attention # attention
sim = sim - sim.amax(dim = -1, keepdim = True).detach() attn = sim.softmax(dim = -1, dtype = torch.float32)
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
attn = self.dropout(attn) attn = self.dropout(attn)
# aggregate values # aggregate values
@@ -834,9 +937,12 @@ class DiffusionPriorNetwork(nn.Module):
num_image_embeds = 1, num_image_embeds = 1,
num_text_embeds = 1, num_text_embeds = 1,
max_text_len = 256, max_text_len = 256,
self_cond = False,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
self.dim = dim
self.num_time_embeds = num_time_embeds self.num_time_embeds = num_time_embeds
self.num_image_embeds = num_image_embeds self.num_image_embeds = num_image_embeds
self.num_text_embeds = num_text_embeds self.num_text_embeds = num_text_embeds
@@ -864,6 +970,10 @@ class DiffusionPriorNetwork(nn.Module):
self.max_text_len = max_text_len self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim)) self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
# whether to use self conditioning, Hinton's group's new ddpm technique
self.self_cond = self_cond
def forward_with_cond_scale( def forward_with_cond_scale(
self, self,
*args, *args,
@@ -885,12 +995,19 @@ class DiffusionPriorNetwork(nn.Module):
*, *,
text_embed, text_embed,
text_encodings = None, text_encodings = None,
self_cond = None,
cond_drop_prob = 0. cond_drop_prob = 0.
): ):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
# setup self conditioning
if self.self_cond:
self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
self_cond = rearrange(self_cond, 'b d -> b 1 d')
# in section 2.2, last paragraph # in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction" # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
@@ -940,13 +1057,16 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right # but let's just do it right
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.to_time_embeds(diffusion_timesteps) time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch) learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
if self.self_cond:
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
tokens = torch.cat(( tokens = torch.cat((
text_encodings, text_encodings,
text_embed, text_embed,
@@ -1048,45 +1168,50 @@ class DiffusionPrior(nn.Module):
def l2norm_clamp_embed(self, image_embed): def l2norm_clamp_embed(self, image_embed):
return l2norm(image_embed) * self.image_embed_scale return l2norm(image_embed) * self.image_embed_scale
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.): def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond) pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
if self.predict_x_start: if self.predict_x_start:
x_recon = pred x_start = pred
else: else:
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not self.predict_x_start: if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.) x_start.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm: if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon) * self.image_embed_scale x_start = l2norm(x_start) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t) model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad() @torch.no_grad()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.): def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale) model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
noise = torch.randn_like(x) noise = torch.randn_like(x)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return pred, x_start
@torch.no_grad() @torch.no_grad()
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.): def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
batch, device = shape[0], self.device batch, device = shape[0], self.device
image_embed = torch.randn(shape, device = device) image_embed = torch.randn(shape, device = device)
x_start = None # for self-conditioning
if self.init_image_embed_l2norm: if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps): for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((batch,), i, device = device, dtype = torch.long) times = torch.full((batch,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
self_cond = x_start if self.net.self_cond else None
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
if self.sampling_final_clamp_l2norm and self.predict_x_start: if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed) image_embed = self.l2norm_clamp_embed(image_embed)
@@ -1104,6 +1229,8 @@ class DiffusionPrior(nn.Module):
image_embed = torch.randn(shape, device = device) image_embed = torch.randn(shape, device = device)
x_start = None # for self-conditioning
if self.init_image_embed_l2norm: if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale image_embed = l2norm(image_embed) * self.image_embed_scale
@@ -1113,7 +1240,9 @@ class DiffusionPrior(nn.Module):
time_cond = torch.full((batch,), time, device = device, dtype = torch.long) time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond) self_cond = x_start if self.net.self_cond else None
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
if self.predict_x_start: if self.predict_x_start:
x_start = pred x_start = pred
@@ -1157,9 +1286,15 @@ class DiffusionPrior(nn.Module):
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise) image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
self_cond = None
if self.net.self_cond and random.random() < 0.5:
with torch.no_grad():
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
pred = self.net( pred = self.net(
image_embed_noisy, image_embed_noisy,
times, times,
self_cond = self_cond,
cond_drop_prob = self.cond_drop_prob, cond_drop_prob = self.cond_drop_prob,
**text_cond **text_cond
) )
@@ -1357,7 +1492,8 @@ class ResnetBlock(nn.Module):
*, *,
cond_dim = None, cond_dim = None,
time_cond_dim = None, time_cond_dim = None,
groups = 8 groups = 8,
cosine_sim_cross_attn = False
): ):
super().__init__() super().__init__()
@@ -1377,7 +1513,8 @@ class ResnetBlock(nn.Module):
'b (h w) c', 'b (h w) c',
CrossAttention( CrossAttention(
dim = dim_out, dim = dim_out,
context_dim = cond_dim context_dim = cond_dim,
cosine_sim = cosine_sim_cross_attn
) )
) )
@@ -1412,11 +1549,12 @@ class CrossAttention(nn.Module):
heads = 8, heads = 8,
dropout = 0., dropout = 0.,
norm_context = False, norm_context = False,
pb_relax_alpha = 32 ** 2 cosine_sim = False,
cosine_sim_scale = 16
): ):
super().__init__() super().__init__()
self.pb_relax_alpha = pb_relax_alpha self.cosine_sim = cosine_sim
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1) self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
self.heads = heads self.heads = heads
inner_dim = dim_head * heads inner_dim = dim_head * heads
@@ -1452,7 +1590,10 @@ class CrossAttention(nn.Module):
k = torch.cat((nk, k), dim = -2) k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2) v = torch.cat((nv, v), dim = -2)
q = q * self.scale if self.cosine_sim:
q, k = map(l2norm, (q, k))
q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
sim = einsum('b h i d, b h j d -> b h i j', q, k) sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max max_neg_value = -torch.finfo(sim.dtype).max
@@ -1462,10 +1603,7 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j') mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value) sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True).detach() attn = sim.softmax(dim = -1, dtype = torch.float32)
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v) out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)') out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1476,7 +1614,8 @@ class LinearAttention(nn.Module):
self, self,
dim, dim,
dim_head = 32, dim_head = 32,
heads = 8 heads = 8,
**kwargs
): ):
super().__init__() super().__init__()
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
@@ -1494,6 +1633,7 @@ class LinearAttention(nn.Module):
def forward(self, fmap): def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:] h, x, y = self.heads, *fmap.shape[-2:]
seq_len = x * y
fmap = self.norm(fmap) fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1) q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
@@ -1503,7 +1643,9 @@ class LinearAttention(nn.Module):
k = k.softmax(dim = -2) k = k.softmax(dim = -2)
q = q * self.scale q = q * self.scale
v = v / (x * y) v = l2norm(v)
k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))
context = einsum('b n d, b n e -> b d e', k, v) context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context) out = einsum('b n d, b d e -> b n e', q, context)
@@ -1590,7 +1732,10 @@ class Unet(nn.Module):
attn_heads = 16, attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
sparse_attn = False, sparse_attn = False,
cosine_sim_cross_attn = False,
cosine_sim_self_attn = False,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False, cond_on_text_encodings = False,
max_text_len = 256, max_text_len = 256,
@@ -1609,6 +1754,7 @@ class Unet(nn.Module):
pixel_shuffle_upsample = True, pixel_shuffle_upsample = True,
final_conv_kernel_size = 1, final_conv_kernel_size = 1,
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
checkpoint_during_training = False,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
@@ -1622,12 +1768,21 @@ class Unet(nn.Module):
self.lowres_cond = lowres_cond self.lowres_cond = lowres_cond
# whether to do self conditioning
self.self_cond = self_cond
# determine dimensions # determine dimensions
self.channels = channels self.channels = channels
self.channels_out = default(channels_out, channels) self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis # initial number of channels depends on
# (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade
# (2) self conditioning (bit diffusion paper)
init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
init_dim = default(init_dim, dim) init_dim = default(init_dim, dim)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
@@ -1711,7 +1866,7 @@ class Unet(nn.Module):
# attention related params # attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head) attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)
self_attn = cast_tuple(self_attn, num_stages) self_attn = cast_tuple(self_attn, num_stages)
@@ -1734,9 +1889,13 @@ class Unet(nn.Module):
upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
# prepare resnet klass
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
# give memory efficient unet an initial resnet block # give memory efficient unet an initial resnet block
self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None self.init_resnet_block = resnet_block(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
# layers # layers
@@ -1763,17 +1922,17 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([ self.downs.append(nn.ModuleList([
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None, downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups), resnet_block(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]), nn.ModuleList([resnet_block(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention, attention,
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1) downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
])) ]))
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) self.mid_block1 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = create_self_attn(mid_dim) self.mid_attn = create_self_attn(mid_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) self.mid_block2 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))): for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
is_last = ind >= (len(in_out) - 1) is_last = ind >= (len(in_out) - 1)
@@ -1790,8 +1949,8 @@ class Unet(nn.Module):
upsample_combiner_dims.append(dim_out) upsample_combiner_dims.append(dim_out)
self.ups.append(nn.ModuleList([ self.ups.append(nn.ModuleList([
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups), resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]), nn.ModuleList([resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention, attention,
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity() upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
])) ]))
@@ -1807,7 +1966,7 @@ class Unet(nn.Module):
# a final resnet block # a final resnet block
self.final_resnet_block = ResnetBlock(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) self.final_resnet_block = resnet_block(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
out_dim_in = dim + (channels if lowres_cond else 0) out_dim_in = dim + (channels if lowres_cond else 0)
@@ -1815,6 +1974,10 @@ class Unet(nn.Module):
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
# whether to checkpoint during training
self.checkpoint_during_training = checkpoint_during_training
# if the current settings for the unet are not correct # if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings # for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters( def cast_model_parameters(
@@ -1831,7 +1994,7 @@ class Unet(nn.Module):
channels == self.channels and \ channels == self.channels and \
cond_on_image_embeds == self.cond_on_image_embeds and \ cond_on_image_embeds == self.cond_on_image_embeds and \
cond_on_text_encodings == self.cond_on_text_encodings and \ cond_on_text_encodings == self.cond_on_text_encodings and \
cond_on_lowres_noise == self.cond_on_lowres_noise and \ lowres_noise_cond == self.lowres_noise_cond and \
channels_out == self.channels_out: channels_out == self.channels_out:
return self return self
@@ -1872,7 +2035,9 @@ class Unet(nn.Module):
image_cond_drop_prob = 0., image_cond_drop_prob = 0.,
text_cond_drop_prob = 0., text_cond_drop_prob = 0.,
blur_sigma = None, blur_sigma = None,
blur_kernel_size = None blur_kernel_size = None,
disable_checkpoint = False,
self_cond = None
): ):
batch_size, device = x.shape[0], x.device batch_size, device = x.shape[0], x.device
@@ -1880,6 +2045,14 @@ class Unet(nn.Module):
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
# concat self conditioning, if needed
if self.self_cond:
self_cond = default(self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x, self_cond), dim = 1)
# concat low resolution conditioning
if exists(lowres_cond_img): if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1) x = torch.cat((x, lowres_cond_img), dim = 1)
@@ -1994,17 +2167,29 @@ class Unet(nn.Module):
c = self.norm_cond(c) c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c) mid_c = self.norm_mid_cond(mid_c)
# gradient checkpointing
can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
# make checkpointable modules
init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
# initial resnet block # initial resnet block
if exists(self.init_resnet_block): if exists(init_resnet_block):
x = self.init_resnet_block(x, t) x = init_resnet_block(x, t)
# go through the layers of the unet, down and up # go through the layers of the unet, down and up
down_hiddens = [] down_hiddens = []
up_hiddens = [] up_hiddens = []
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs: for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
if exists(pre_downsample): if exists(pre_downsample):
x = pre_downsample(x) x = pre_downsample(x)
@@ -2020,16 +2205,16 @@ class Unet(nn.Module):
if exists(post_downsample): if exists(post_downsample):
x = post_downsample(x) x = post_downsample(x)
x = self.mid_block1(x, t, mid_c) x = mid_block1(x, t, mid_c)
if exists(self.mid_attn): if exists(mid_attn):
x = self.mid_attn(x) x = mid_attn(x)
x = self.mid_block2(x, t, mid_c) x = mid_block2(x, t, mid_c)
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1) connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, resnet_blocks, attn, upsample in self.ups: for init_block, resnet_blocks, attn, upsample in ups:
x = connect_skip(x) x = connect_skip(x)
x = init_block(x, t, c) x = init_block(x, t, c)
@@ -2046,7 +2231,7 @@ class Unet(nn.Module):
x = torch.cat((x, r), dim = 1) x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t) x = final_resnet_block(x, t)
if exists(lowres_cond_img): if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1) x = torch.cat((x, lowres_cond_img), dim = 1)
@@ -2437,23 +2622,23 @@ class Decoder(nn.Module):
x = x.clamp(-s, s) / s x = x.clamp(-s, s) / s
return x return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None): def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)) pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
if learned_variance: if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1) pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
if predict_x_start: if predict_x_start:
x_recon = pred x_start = pred
else: else:
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised: if clip_denoised:
x_recon = self.dynamic_threshold(x_recon) x_start = self.dynamic_threshold(x_start)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t) model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
if learned_variance: if learned_variance:
# if learned variance, posterior variance and posterior log variance are predicted by the network # if learned variance, posterior variance and posterior log variance are predicted by the network
@@ -2469,16 +2654,17 @@ class Decoder(nn.Module):
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp() posterior_variance = posterior_log_variance.exp()
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad() @torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None): def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level) model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
noise = torch.randn_like(x) noise = torch.randn_like(x)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return pred, x_start
@torch.no_grad() @torch.no_grad()
def p_sample_loop_ddpm( def p_sample_loop_ddpm(
@@ -2504,6 +2690,8 @@ class Decoder(nn.Module):
b = shape[0] b = shape[0]
img = torch.randn(shape, device = device) img = torch.randn(shape, device = device)
x_start = None # for self-conditioning
is_inpaint = exists(inpaint_image) is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1 resample_times = inpaint_resample_times if is_inpaint else 1
@@ -2531,13 +2719,16 @@ class Decoder(nn.Module):
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times) noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
img = self.p_sample( self_cond = x_start if unet.self_cond else None
img, x_start = self.p_sample(
unet, unet,
img, img,
times, times,
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
cond_scale = cond_scale, cond_scale = cond_scale,
self_cond = self_cond,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level, lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
@@ -2596,6 +2787,8 @@ class Decoder(nn.Module):
img = torch.randn(shape, device = device) img = torch.randn(shape, device = device)
x_start = None # for self-conditioning
if not is_latent_diffusion: if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
@@ -2616,7 +2809,9 @@ class Decoder(nn.Module):
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond) noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level) self_cond = x_start if unet.self_cond else None
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
if learned_variance: if learned_variance:
pred, _ = pred.chunk(2, dim = 1) pred, _ = pred.chunk(2, dim = 1)
@@ -2676,13 +2871,35 @@ class Decoder(nn.Module):
x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise) x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
model_output = unet( # unet kwargs
x_noisy,
times, unet_kwargs = dict(
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level, lowres_noise_level = lowres_noise_level,
)
# self conditioning
self_cond = None
if unet.self_cond and random.random() < 0.5:
with torch.no_grad():
self_cond = unet(x_noisy, times, **unet_kwargs)
if learned_variance:
self_cond, _ = self_cond.chunk(2, dim = 1)
self_cond = self_cond.detach()
# forward to get model prediction
model_output = unet(
x_noisy,
times,
**unet_kwargs,
self_cond = self_cond,
image_cond_drop_prob = self.image_cond_drop_prob, image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob, text_cond_drop_prob = self.text_cond_drop_prob,
) )
@@ -2713,7 +2930,7 @@ class Decoder(nn.Module):
# if learning the variance, also include the extra weight kl loss # if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times) true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output) model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper # kl loss with detached model predicted mean, for stability reasons as in paper

View File

@@ -174,7 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
def __init__( def __init__(
self, self,
diffusion_prior, diffusion_prior,
accelerator, accelerator = None,
use_ema = True, use_ema = True,
lr = 3e-4, lr = 3e-4,
wd = 1e-2, wd = 1e-2,
@@ -186,8 +186,12 @@ class DiffusionPriorTrainer(nn.Module):
): ):
super().__init__() super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior) assert isinstance(diffusion_prior, DiffusionPrior)
assert isinstance(accelerator, Accelerator)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
if not exists(accelerator):
accelerator = Accelerator(**accelerator_kwargs)
# assign some helpful member vars # assign some helpful member vars

View File

@@ -1 +1 @@
__version__ = '1.2.1' __version__ = '1.6.3'

View File

@@ -26,7 +26,7 @@ setup(
install_requires=[ install_requires=[
'accelerate', 'accelerate',
'click', 'click',
'clip-anytorch', 'clip-anytorch>=2.4.0',
'coca-pytorch>=0.0.5', 'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7', 'ema-pytorch>=0.0.7',
'einops>=0.4', 'einops>=0.4',