Compare commits

..

8 Commits
1.2.1 ... 1.5.0

5 changed files with 201 additions and 53 deletions

View File

@@ -396,7 +396,7 @@ decoder = Decoder(
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
@@ -627,6 +627,18 @@ images = dalle2(
# save your image (in this example, of size 256x256)
```
Alternatively, you can also use <a href="https://github.com/mlfoundations/open_clip">Open Clip</a>
```bash
$ pip install open-clip-torch
```
```python
from dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter()
```
Now you'll just have to worry about training the Prior and the Decoder!
## Inpainting
@@ -861,25 +873,23 @@ unet1 = Unet(
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
dim_mults=(1, 2, 4, 8),
cond_on_text_encodings = True,
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_text_encodings = True
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1000,
condition_on_text_encodings = True
timesteps = 1000
).cuda()
decoder_trainer = DecoderTrainer(
@@ -904,8 +914,8 @@ for unet_number in (1, 2):
# after much training
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
mock_image_embed = torch.randn(32, 512).cuda()
images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
```
### Diffusion Prior Training

View File

@@ -8,6 +8,7 @@ from pathlib import Path
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch import nn, einsum
import torchvision.transforms as T
@@ -108,6 +109,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
return t
return (*t, *((fillvalue,) * remain_length))
# checkpointing helper function
def make_checkpointable(fn, **kwargs):
if isinstance(fn, nn.ModuleList):
return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
condition = kwargs.pop('condition', None)
if exists(condition) and not condition(fn):
return fn
@wraps(fn)
def inner(*args):
input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
if not input_needs_grad:
return fn(*args)
return checkpoint(fn, *args)
return inner
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
@@ -339,6 +362,75 @@ class OpenAIClipAdapter(BaseClipAdapter):
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
class OpenClipAdapter(BaseClipAdapter):
def __init__(
self,
name = 'ViT-B/32',
pretrained = 'laion400m_e32'
):
import open_clip
clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
super().__init__(clip)
self.eos_id = 49407
text_attention_final = self.find_layer('ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)
def clear(self):
if self.cleared:
return
self.handle()
def _hook(self, _, inputs, outputs):
self.text_encodings = outputs
@property
def dim_latent(self):
return 512
@property
def image_size(self):
return self.clip.visual.image_size
@property
def image_channels(self):
return 3
@property
def max_text_len(self):
return self.clip.context_length
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
assert not self.cleared
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
del self.text_encodings
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = self.validate_and_resize_image(image)
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
@@ -547,34 +639,40 @@ class NoiseScheduler(nn.Module):
# diffusion prior
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5, stable = False):
def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
super().__init__()
self.eps = eps
self.fp16_eps = fp16_eps
self.stable = stable
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
if self.stable:
x = x / x.amax(dim = -1, keepdim = True).detach()
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
return (x - mean) * (var + eps).rsqrt() * self.g
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5, stable = False):
def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
super().__init__()
self.eps = eps
self.fp16_eps = fp16_eps
self.stable = stable
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
if self.stable:
x = x / x.amax(dim = 1, keepdim = True).detach()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
return (x - mean) * (var + eps).rsqrt() * self.g
class Residual(nn.Module):
def __init__(self, fn):
@@ -695,11 +793,12 @@ class Attention(nn.Module):
dropout = 0.,
causal = False,
rotary_emb = None,
pb_relax_alpha = 128
cosine_sim = True,
cosine_sim_scale = 16
):
super().__init__()
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
self.cosine_sim = cosine_sim
self.heads = heads
inner_dim = dim_head * heads
@@ -739,6 +838,13 @@ class Attention(nn.Module):
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
# whether to use cosine sim
if self.cosine_sim:
q, k = map(l2norm, (q, k))
q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
# calculate query / key similarities
sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -764,9 +870,6 @@ class Attention(nn.Module):
# attention
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
@@ -1357,7 +1460,8 @@ class ResnetBlock(nn.Module):
*,
cond_dim = None,
time_cond_dim = None,
groups = 8
groups = 8,
cosine_sim_cross_attn = False
):
super().__init__()
@@ -1377,7 +1481,8 @@ class ResnetBlock(nn.Module):
'b (h w) c',
CrossAttention(
dim = dim_out,
context_dim = cond_dim
context_dim = cond_dim,
cosine_sim = cosine_sim_cross_attn
)
)
@@ -1412,11 +1517,12 @@ class CrossAttention(nn.Module):
heads = 8,
dropout = 0.,
norm_context = False,
pb_relax_alpha = 32 ** 2
cosine_sim = False,
cosine_sim_scale = 16
):
super().__init__()
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
self.cosine_sim = cosine_sim
self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
self.heads = heads
inner_dim = dim_head * heads
@@ -1452,7 +1558,10 @@ class CrossAttention(nn.Module):
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
if self.cosine_sim:
q, k = map(l2norm, (q, k))
q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
@@ -1462,9 +1571,6 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
@@ -1476,7 +1582,8 @@ class LinearAttention(nn.Module):
self,
dim,
dim_head = 32,
heads = 8
heads = 8,
**kwargs
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -1494,6 +1601,7 @@ class LinearAttention(nn.Module):
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
seq_len = x * y
fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
@@ -1503,7 +1611,9 @@ class LinearAttention(nn.Module):
k = k.softmax(dim = -2)
q = q * self.scale
v = v / (x * y)
v = l2norm(v)
k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))
context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context)
@@ -1591,6 +1701,8 @@ class Unet(nn.Module):
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
sparse_attn = False,
cosine_sim_cross_attn = False,
cosine_sim_self_attn = False,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
max_text_len = 256,
@@ -1609,6 +1721,7 @@ class Unet(nn.Module):
pixel_shuffle_upsample = True,
final_conv_kernel_size = 1,
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
checkpoint_during_training = False,
**kwargs
):
super().__init__()
@@ -1711,7 +1824,7 @@ class Unet(nn.Module):
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)
self_attn = cast_tuple(self_attn, num_stages)
@@ -1734,9 +1847,13 @@ class Unet(nn.Module):
upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
# prepare resnet klass
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
# give memory efficient unet an initial resnet block
self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
self.init_resnet_block = resnet_block(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
# layers
@@ -1763,17 +1880,17 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
resnet_block(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([resnet_block(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
]))
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_block1 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = create_self_attn(mid_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_block2 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
is_last = ind >= (len(in_out) - 1)
@@ -1790,8 +1907,8 @@ class Unet(nn.Module):
upsample_combiner_dims.append(dim_out)
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
]))
@@ -1807,7 +1924,7 @@ class Unet(nn.Module):
# a final resnet block
self.final_resnet_block = ResnetBlock(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
self.final_resnet_block = resnet_block(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
out_dim_in = dim + (channels if lowres_cond else 0)
@@ -1815,6 +1932,10 @@ class Unet(nn.Module):
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
# whether to checkpoint during training
self.checkpoint_during_training = checkpoint_during_training
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
@@ -1831,7 +1952,7 @@ class Unet(nn.Module):
channels == self.channels and \
cond_on_image_embeds == self.cond_on_image_embeds and \
cond_on_text_encodings == self.cond_on_text_encodings and \
cond_on_lowres_noise == self.cond_on_lowres_noise and \
lowres_noise_cond == self.lowres_noise_cond and \
channels_out == self.channels_out:
return self
@@ -1872,7 +1993,8 @@ class Unet(nn.Module):
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
blur_kernel_size = None,
disable_checkpoint = False
):
batch_size, device = x.shape[0], x.device
@@ -1994,17 +2116,29 @@ class Unet(nn.Module):
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# gradient checkpointing
can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
# make checkpointable modules
init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
# initial resnet block
if exists(self.init_resnet_block):
x = self.init_resnet_block(x, t)
if exists(init_resnet_block):
x = init_resnet_block(x, t)
# go through the layers of the unet, down and up
down_hiddens = []
up_hiddens = []
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
if exists(pre_downsample):
x = pre_downsample(x)
@@ -2020,16 +2154,16 @@ class Unet(nn.Module):
if exists(post_downsample):
x = post_downsample(x)
x = self.mid_block1(x, t, mid_c)
x = mid_block1(x, t, mid_c)
if exists(self.mid_attn):
x = self.mid_attn(x)
if exists(mid_attn):
x = mid_attn(x)
x = self.mid_block2(x, t, mid_c)
x = mid_block2(x, t, mid_c)
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, resnet_blocks, attn, upsample in self.ups:
for init_block, resnet_blocks, attn, upsample in ups:
x = connect_skip(x)
x = init_block(x, t, c)
@@ -2046,7 +2180,7 @@ class Unet(nn.Module):
x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t)
x = final_resnet_block(x, t)
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)

View File

@@ -174,7 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
accelerator,
accelerator = None,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
@@ -186,8 +186,12 @@ class DiffusionPriorTrainer(nn.Module):
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
assert isinstance(accelerator, Accelerator)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
if not exists(accelerator):
accelerator = Accelerator(**accelerator_kwargs)
# assign some helpful member vars

View File

@@ -1 +1 @@
__version__ = '1.2.1'
__version__ = '1.5.0'

View File

@@ -26,7 +26,7 @@ setup(
install_requires=[
'accelerate',
'click',
'clip-anytorch',
'clip-anytorch>=2.4.0',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.4',