mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 21:24:28 +01:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b0edf9e42 | ||
|
|
a922a539de | ||
|
|
8f2466f1cd | ||
|
|
908ab83799 | ||
|
|
46a2558d53 | ||
|
|
86109646e3 | ||
|
|
6a11b9678b | ||
|
|
b90364695d | ||
|
|
868c001199 |
25
README.md
25
README.md
@@ -368,7 +368,8 @@ unet1 = Unet(
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
@@ -385,8 +386,7 @@ decoder = Decoder(
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
text_cond_drop_prob = 0.5
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
@@ -1112,15 +1112,6 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022},
|
||||
url = {https://arxiv.org/abs/2204.01697}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Yu2021VectorquantizedIM,
|
||||
title = {Vector-quantized Image Modeling with Improved VQGAN},
|
||||
@@ -1189,4 +1180,14 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Saharia2021PaletteID,
|
||||
title = {Palette: Image-to-Image Diffusion Models},
|
||||
author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2111.05826}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -45,6 +45,11 @@ def exists(val):
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def first(arr, d = None):
|
||||
if len(arr) == 0:
|
||||
return d
|
||||
return arr[0]
|
||||
|
||||
def maybe(fn):
|
||||
@wraps(fn)
|
||||
def inner(x):
|
||||
@@ -351,7 +356,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
|
||||
steps = timesteps + 1
|
||||
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
alphas_cumprod = alphas_cumprod / first(alphas_cumprod)
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
return torch.clip(betas, 0, 0.999)
|
||||
|
||||
@@ -1088,8 +1093,16 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
# decoder
|
||||
|
||||
def Upsample(dim):
|
||||
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||
def ConvTransposeUpsample(dim, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
|
||||
|
||||
def NearestUpsample(dim, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.Sequential(
|
||||
nn.Upsample(scale_factor = 2, mode = 'nearest'),
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||
)
|
||||
|
||||
def Downsample(dim, *, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
@@ -1166,7 +1179,7 @@ class ResnetBlock(nn.Module):
|
||||
self.block2 = Block(dim_out, dim_out, groups = groups)
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None, time_emb = None):
|
||||
def forward(self, x, time_emb = None, cond = None):
|
||||
|
||||
scale_shift = None
|
||||
if exists(self.time_mlp) and exists(time_emb):
|
||||
@@ -1247,20 +1260,6 @@ class CrossAttention(nn.Module):
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class GridAttention(nn.Module):
|
||||
def __init__(self, *args, window_size = 8, **kwargs):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.attn = Attention(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
h, w = x.shape[-2:]
|
||||
wsz = self.window_size
|
||||
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
|
||||
out = self.attn(x)
|
||||
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
|
||||
return out
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1359,6 +1358,9 @@ class Unet(nn.Module):
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
memory_efficient = False,
|
||||
scale_skip_connection = False,
|
||||
nearest_upsample = False,
|
||||
final_conv_kernel_size = 1,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1440,6 +1442,10 @@ class Unet(nn.Module):
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
|
||||
# whether to scale skip connection, adopted in Imagen
|
||||
|
||||
self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
|
||||
|
||||
# attention related params
|
||||
|
||||
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
|
||||
@@ -1447,6 +1453,8 @@ class Unet(nn.Module):
|
||||
# resnet block klass
|
||||
|
||||
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
||||
top_level_resnet_group = first(resnet_groups)
|
||||
|
||||
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
@@ -1457,23 +1465,36 @@ class Unet(nn.Module):
|
||||
if cross_embed_downsample:
|
||||
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
|
||||
|
||||
# upsample klass
|
||||
|
||||
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
|
||||
|
||||
# give memory efficient unet an initial resnet block
|
||||
|
||||
self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
skip_connect_dims = [] # keeping track of skip connection dimensions
|
||||
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
|
||||
is_first = ind == 0
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
dim_layer = dim_out if memory_efficient else dim_in
|
||||
skip_connect_dims.append(dim_layer)
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
||||
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
downsample_klass(dim_out) if not is_last and not memory_efficient else None
|
||||
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_layer, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
@@ -1486,17 +1507,17 @@ class Unet(nn.Module):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
skip_connect_dim = skip_connect_dims.pop()
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
Upsample(dim_in) if not is_last or memory_efficient else nn.Identity()
|
||||
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
|
||||
]))
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, self.channels_out, 1)
|
||||
)
|
||||
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
||||
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
@@ -1660,6 +1681,11 @@ class Unet(nn.Module):
|
||||
c = self.norm_cond(c)
|
||||
mid_c = self.norm_mid_cond(mid_c)
|
||||
|
||||
# initial resnet block
|
||||
|
||||
if exists(self.init_resnet_block):
|
||||
x = self.init_resnet_block(x, t)
|
||||
|
||||
# go through the layers of the unet, down and up
|
||||
|
||||
hiddens = []
|
||||
@@ -1668,36 +1694,41 @@ class Unet(nn.Module):
|
||||
if exists(pre_downsample):
|
||||
x = pre_downsample(x)
|
||||
|
||||
x = init_block(x, c, t)
|
||||
x = init_block(x, t, c)
|
||||
x = sparse_attn(x)
|
||||
hiddens.append(x)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
|
||||
hiddens.append(x)
|
||||
x = resnet_block(x, t, c)
|
||||
hiddens.append(x)
|
||||
|
||||
if exists(post_downsample):
|
||||
x = post_downsample(x)
|
||||
|
||||
x = self.mid_block1(x, mid_c, t)
|
||||
x = self.mid_block1(x, t, mid_c)
|
||||
|
||||
if exists(self.mid_attn):
|
||||
x = self.mid_attn(x)
|
||||
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
x = self.mid_block2(x, t, mid_c)
|
||||
|
||||
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
|
||||
|
||||
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||
x = init_block(x, c, t)
|
||||
x = connect_skip(x)
|
||||
x = init_block(x, t, c)
|
||||
x = sparse_attn(x)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
x = connect_skip(x)
|
||||
x = resnet_block(x, t, c)
|
||||
|
||||
x = upsample(x)
|
||||
|
||||
x = torch.cat((x, r), dim = 1)
|
||||
return self.final_conv(x)
|
||||
|
||||
x = self.final_resnet_block(x, t)
|
||||
return self.to_out(x)
|
||||
|
||||
class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
@@ -1764,7 +1795,7 @@ class Decoder(nn.Module):
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||
blur_sigma = 0.6, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
@@ -1781,13 +1812,6 @@ class Decoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.unconditional = unconditional
|
||||
|
||||
# text conditioning
|
||||
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# clip
|
||||
|
||||
self.clip = None
|
||||
@@ -1819,12 +1843,16 @@ class Decoder(nn.Module):
|
||||
|
||||
self.channels = channels
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
# verify conditioning method
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
num_unets = len(unets)
|
||||
|
||||
self.unconditional = unconditional
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
|
||||
@@ -1860,6 +1888,10 @@ class Decoder(nn.Module):
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval())
|
||||
|
||||
# determine from unets whether conditioning on text encoding is needed
|
||||
|
||||
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
|
||||
# create noise schedulers per unet
|
||||
|
||||
if not exists(beta_schedule):
|
||||
@@ -2193,7 +2225,8 @@ class Decoder(nn.Module):
|
||||
image_embed = None,
|
||||
text_encodings = None,
|
||||
text_mask = None,
|
||||
unet_number = None
|
||||
unet_number = None,
|
||||
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
|
||||
):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
@@ -2243,7 +2276,12 @@ class Decoder(nn.Module):
|
||||
image = vae.encode(image)
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||
|
||||
if not return_lowres_cond_image:
|
||||
return losses
|
||||
|
||||
return losses, lowres_cond_img
|
||||
|
||||
# main class
|
||||
|
||||
@@ -2291,6 +2329,6 @@ class DALLE2(nn.Module):
|
||||
images = list(map(self.to_pil, images.unbind(dim = 0)))
|
||||
|
||||
if one_text:
|
||||
return images[0]
|
||||
return first(images)
|
||||
|
||||
return images
|
||||
|
||||
@@ -158,6 +158,8 @@ class UnetConfig(BaseModel):
|
||||
dim: int
|
||||
dim_mults: ListOrTuple(int)
|
||||
image_embed_dim: int = None
|
||||
text_embed_dim: int = None
|
||||
cond_on_text_encodings: bool = None
|
||||
cond_dim: int = None
|
||||
channels: int = 3
|
||||
attn_dim_head: int = 32
|
||||
@@ -170,7 +172,6 @@ class DecoderConfig(BaseModel):
|
||||
unets: ListOrTuple(UnetConfig)
|
||||
image_size: int = None
|
||||
image_sizes: ListOrTuple(int) = None
|
||||
condition_on_text_encodings: bool = False
|
||||
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
@@ -283,21 +284,27 @@ class TrainDecoderConfig(BaseModel):
|
||||
def check_has_embeddings(cls, values):
|
||||
# Makes sure that enough information is provided to get the embeddings specified for training
|
||||
data_config, decoder_config = values.get('data'), values.get('decoder')
|
||||
if data_config is None or decoder_config is None:
|
||||
|
||||
if not exists(data_config) or not exists(decoder_config):
|
||||
# Then something else errored and we should just pass through
|
||||
return values
|
||||
using_text_embeddings = decoder_config.condition_on_text_encodings
|
||||
|
||||
using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
|
||||
using_clip = exists(decoder_config.clip)
|
||||
img_emb_url = data_config.img_embeddings_url
|
||||
text_emb_url = data_config.text_embeddings_url
|
||||
|
||||
if using_text_embeddings:
|
||||
# Then we need some way to get the embeddings
|
||||
assert using_clip or text_emb_url is not None, 'If condition_on_text_encodings is true, either clip or text_embeddings_url must be provided'
|
||||
assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
|
||||
|
||||
if using_clip:
|
||||
if using_text_embeddings:
|
||||
assert text_emb_url is None or img_emb_url is None, 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||
assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
|
||||
else:
|
||||
assert img_emb_url is None, 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||
assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||
|
||||
if text_emb_url:
|
||||
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
|
||||
|
||||
return values
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.12.1'
|
||||
__version__ = '0.15.1'
|
||||
|
||||
@@ -596,9 +596,11 @@ def initialize_training(config, config_path):
|
||||
|
||||
has_img_embeddings = config.data.img_embeddings_url is not None
|
||||
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||
conditioning_on_text = config.decoder.condition_on_text_encodings
|
||||
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
|
||||
|
||||
has_clip_model = config.decoder.clip is not None
|
||||
data_source_string = ""
|
||||
|
||||
if has_img_embeddings:
|
||||
data_source_string += "precomputed image embeddings"
|
||||
elif has_clip_model:
|
||||
@@ -622,7 +624,7 @@ def initialize_training(config, config_path):
|
||||
inference_device=accelerator.device,
|
||||
load_config=config.load,
|
||||
evaluate_config=config.evaluate,
|
||||
condition_on_text_encodings=config.decoder.condition_on_text_encodings,
|
||||
condition_on_text_encodings=conditioning_on_text,
|
||||
**config.train.dict(),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user