Compare commits

..

1 Commits

8 changed files with 101 additions and 221 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
github: [nousr, Veldrovive, lucidrains]
github: [lucidrains]

View File

@@ -45,7 +45,6 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/malumadev">MalumaDev</a> for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -356,8 +355,7 @@ prior_network = DiffusionPriorNetwork(
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 1000,
sample_timesteps = 64,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
@@ -421,7 +419,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
## Training on Preprocessed CLIP Embeddings
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
Working example below

View File

@@ -77,11 +77,6 @@ def cast_tuple(val, length = None):
def module_device(module):
return next(module.parameters()).device
def zero_init_(m):
nn.init.zeros_(m.weight)
if exists(m.bias):
nn.init.zeros_(m.bias)
@contextmanager
def null_context(*args, **kwargs):
yield
@@ -146,7 +141,7 @@ def resize_image_to(
scale_factors = target_image_size / orig_image_size
out = resize(image, scale_factors = scale_factors, **kwargs)
else:
out = F.interpolate(image, target_image_size, mode = 'nearest')
out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False)
if exists(clamp_range):
out = out.clamp(*clamp_range)
@@ -165,7 +160,7 @@ def unnormalize_zero_to_one(normed_img):
# clip related adapters
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings'])
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module):
@@ -225,8 +220,7 @@ class XClipAdapter(BaseClipAdapter):
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(l2norm(text_embed), text_encodings)
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
@torch.no_grad()
def embed_image(self, image):
@@ -261,8 +255,7 @@ class CoCaAdapter(BaseClipAdapter):
text = text[..., :self.max_text_len]
text_mask = text != 0
text_embed, text_encodings = self.clip.embed_text(text)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(text_embed, text_encodings)
return EmbeddedText(text_embed, text_encodings, text_mask)
@torch.no_grad()
def embed_image(self, image):
@@ -278,7 +271,6 @@ class OpenAIClipAdapter(BaseClipAdapter):
import clip
openai_clip, preprocess = clip.load(name)
super().__init__(openai_clip)
self.eos_id = 49407 # for handling 0 being also '!'
text_attention_final = self.find_layer('ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
@@ -317,17 +309,13 @@ class OpenAIClipAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text != 0
assert not self.cleared
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
del self.text_encodings
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
@torch.no_grad()
def embed_image(self, image):
@@ -531,31 +519,25 @@ class NoiseScheduler(nn.Module):
# diffusion prior
class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5, stable = False):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.stable = stable
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
if self.stable:
x = x / x.amax(dim = -1, keepdim = True).detach()
x = x / x.amax(dim = -1, keepdim = True).detach()
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5, stable = False):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.stable = stable
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
if self.stable:
x = x / x.amax(dim = 1, keepdim = True).detach()
x = x / x.amax(dim = 1, keepdim = True).detach()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
@@ -679,7 +661,7 @@ class Attention(nn.Module):
dropout = 0.,
causal = False,
rotary_emb = None,
pb_relax_alpha = 128
pb_relax_alpha = 32 ** 2
):
super().__init__()
self.pb_relax_alpha = pb_relax_alpha
@@ -770,7 +752,6 @@ class CausalTransformer(nn.Module):
dim_head = 64,
heads = 8,
ff_mult = 4,
norm_in = False,
norm_out = True,
attn_dropout = 0.,
ff_dropout = 0.,
@@ -779,8 +760,6 @@ class CausalTransformer(nn.Module):
rotary_emb = True
):
super().__init__()
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
self.rel_pos_bias = RelPosBias(heads = heads)
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
@@ -792,18 +771,20 @@ class CausalTransformer(nn.Module):
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
def forward(self, x):
def forward(
self,
x,
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
):
n, device = x.shape[1], x.device
x = self.init_norm(x)
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
for attn, ff in self.layers:
x = attn(x, attn_bias = attn_bias) + x
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = ff(x) + x
out = self.norm(x)
@@ -817,7 +798,6 @@ class DiffusionPriorNetwork(nn.Module):
num_time_embeds = 1,
num_image_embeds = 1,
num_text_embeds = 1,
max_text_len = 256,
**kwargs
):
super().__init__()
@@ -843,11 +823,6 @@ class DiffusionPriorNetwork(nn.Module):
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
# dalle1 learned padding strategy
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
def forward_with_cond_scale(
self,
*args,
@@ -869,6 +844,7 @@ class DiffusionPriorNetwork(nn.Module):
*,
text_embed,
text_encodings = None,
mask = None,
cond_drop_prob = 0.
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -886,28 +862,9 @@ class DiffusionPriorNetwork(nn.Module):
if not exists(text_encodings):
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
mask = torch.any(text_encodings != 0., dim = -1)
# replace any padding in the text encodings with learned padding tokens unique across position
text_encodings = text_encodings[:, :self.max_text_len]
mask = mask[:, :self.max_text_len]
text_len = text_encodings.shape[-2]
remainder = self.max_text_len - text_len
if remainder > 0:
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
mask = F.pad(mask, (0, remainder), value = False)
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
text_encodings = torch.where(
rearrange(mask, 'b n -> b n 1').clone(),
text_encodings,
null_text_embeds
)
if not exists(mask):
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
# classifier free guidance
@@ -924,8 +881,9 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
if exists(mask):
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.to_time_embeds(diffusion_timesteps)
@@ -941,7 +899,7 @@ class DiffusionPriorNetwork(nn.Module):
# attend
tokens = self.causal_transformer(tokens)
tokens = self.causal_transformer(tokens, mask = mask)
# get learned query, which should predict the image embedding (per DDPM timestep)
@@ -964,12 +922,11 @@ class DiffusionPrior(nn.Module):
loss_type = "l2",
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False,
training_clamp_l2norm = False,
init_image_embed_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__()
@@ -1006,32 +963,23 @@ class DiffusionPrior(nn.Module):
self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
self.predict_x_start = predict_x_start
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
# device tracker
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def l2norm_clamp_embed(self, image_embed):
return l2norm(image_embed) * self.image_embed_scale
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -1072,9 +1020,6 @@ class DiffusionPrior(nn.Module):
times = torch.full((batch,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed
@torch.no_grad()
@@ -1110,18 +1055,15 @@ class DiffusionPrior(nn.Module):
x_start.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = self.l2norm_clamp_embed(x_start)
x_start = l2norm(x_start) * self.image_embed_scale
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
new_noise = torch.randn_like(image_embed)
image_embed = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
if self.predict_x_start and self.sampling_final_clamp_l2norm:
image_embed = self.l2norm_clamp_embed(image_embed)
img = x_start * alpha_next.sqrt() + \
c1 * new_noise + \
c2 * pred_noise
return image_embed
@@ -1149,7 +1091,7 @@ class DiffusionPrior(nn.Module):
)
if self.predict_x_start and self.training_clamp_l2norm:
pred = self.l2norm_clamp_embed(pred)
pred = l2norm(pred) * self.image_embed_scale
target = noise if not self.predict_x_start else image_embed
@@ -1186,12 +1128,12 @@ class DiffusionPrior(nn.Module):
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_embed, text_encodings = self.clip.embed_text(text)
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings}
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
@@ -1219,6 +1161,7 @@ class DiffusionPrior(nn.Module):
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
image_embed = None,
text_encodings = None, # as well as CLIP text encodings
text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
*args,
**kwargs
):
@@ -1232,13 +1175,13 @@ class DiffusionPrior(nn.Module):
# calculate text conditionings, based on what is passed in
if exists(text):
text_embed, text_encodings = self.clip.embed_text(text)
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
text_cond = {**text_cond, 'text_encodings': text_encodings}
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
# timestep conditioning from ddpm
@@ -1255,44 +1198,17 @@ class DiffusionPrior(nn.Module):
# decoder
def ConvTransposeUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
def NearestUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv2d(dim, dim_out, 3, padding = 1)
)
class PixelShuffleUpsample(nn.Module):
"""
code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
"""
def __init__(self, dim, dim_out = None):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * 4, 1)
self.net = nn.Sequential(
conv,
nn.SiLU(),
nn.PixelShuffle(2)
)
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
return self.net(x)
def Downsample(dim, *, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Conv2d(dim, dim_out, 4, 2, 1)
@@ -1555,7 +1471,7 @@ class Unet(nn.Module):
cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False,
scale_skip_connection = False,
pixel_shuffle_upsample = True,
nearest_upsample = False,
final_conv_kernel_size = 1,
**kwargs
):
@@ -1621,12 +1537,10 @@ class Unet(nn.Module):
# text encoding conditioning (optional)
self.text_to_cond = None
self.text_embed_dim = None
if cond_on_text_encodings:
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
self.text_embed_dim = text_embed_dim
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1669,7 +1583,7 @@ class Unet(nn.Module):
# upsample klass
upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
# give memory efficient unet an initial resnet block
@@ -1731,12 +1645,7 @@ class Unet(nn.Module):
]))
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
out_dim_in = dim + (channels if lowres_cond else 0)
self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
@@ -1788,6 +1697,7 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
text_mask = None,
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
blur_sigma = None,
@@ -1859,27 +1769,21 @@ class Unet(nn.Module):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
text_mask = torch.any(text_encodings != 0., dim = -1)
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]
text_mask = text_mask[:, :self.max_text_len]
text_tokens_len = text_tokens.shape[1]
remainder = self.max_text_len - text_tokens_len
if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
text_mask = F.pad(text_mask, (0, remainder), value = False)
text_mask = rearrange(text_mask, 'b n -> b n 1')
if exists(text_mask):
if remainder > 0:
text_mask = F.pad(text_mask, (0, remainder), value = False)
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'
text_keep_mask = text_mask & text_keep_mask
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
@@ -1926,7 +1830,7 @@ class Unet(nn.Module):
hiddens.append(x)
x = attn(x)
hiddens.append(x.contiguous())
hiddens.append(x)
if exists(post_downsample):
x = post_downsample(x)
@@ -1954,26 +1858,23 @@ class Unet(nn.Module):
x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t)
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
return self.to_out(x)
class LowresConditioner(nn.Module):
def __init__(
self,
downsample_first = True,
blur_prob = 0.5,
downsample_mode_nearest = False,
blur_sigma = 0.6,
blur_kernel_size = 3,
input_image_range = None
):
super().__init__()
self.downsample_first = downsample_first
self.downsample_mode_nearest = downsample_mode_nearest
self.input_image_range = input_image_range
self.blur_prob = blur_prob
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
@@ -1986,27 +1887,20 @@ class LowresConditioner(nn.Module):
blur_sigma = None,
blur_kernel_size = None
):
if self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = True)
# blur is only applied 50% of the time
# section 3.1 in https://arxiv.org/abs/2106.15282
if random.random() < self.blur_prob:
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
# allow for drawing a random sigma between lo and hi float values
if isinstance(blur_sigma, tuple):
blur_sigma = tuple(map(float, blur_sigma))
blur_sigma = random.uniform(*blur_sigma)
# allow for drawing a random kernel size between lo and hi int values
if isinstance(blur_kernel_size, tuple):
blur_kernel_size = tuple(map(int, blur_kernel_size))
kernel_size_lo, kernel_size_hi = blur_kernel_size
@@ -2014,7 +1908,8 @@ class LowresConditioner(nn.Module):
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
return cond_fmap
class Decoder(nn.Module):
@@ -2037,7 +1932,7 @@ class Decoder(nn.Module):
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
blur_sigma = 0.6, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
clip_denoised = True,
@@ -2148,7 +2043,7 @@ class Decoder(nn.Module):
self.noise_schedulers = nn.ModuleList([])
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
assert sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,
@@ -2172,7 +2067,6 @@ class Decoder(nn.Module):
# random crop sizes (for super-resoluting unets at the end of cascade?)
self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
assert not exists(self.random_crop_sizes[0]), 'you would not need to randomly crop the image for the base unet'
# predict x0 config
@@ -2189,7 +2083,7 @@ class Decoder(nn.Module):
self.to_lowres_cond = LowresConditioner(
downsample_first = lowres_downsample_first,
blur_prob = blur_prob,
downsample_mode_nearest = lowres_downsample_mode_nearest,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
input_image_range = self.input_image_range
@@ -2271,10 +2165,10 @@ class Decoder(nn.Module):
x = x.clamp(-s, s) / s
return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
@@ -2306,16 +2200,16 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.device
b = shape[0]
@@ -2331,6 +2225,7 @@ class Decoder(nn.Module):
torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
@@ -2343,7 +2238,7 @@ class Decoder(nn.Module):
return unnormalize_img
@torch.no_grad()
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
@@ -2353,16 +2248,13 @@ class Decoder(nn.Module):
img = torch.randn(shape, device = device)
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
@@ -2379,10 +2271,9 @@ class Decoder(nn.Module):
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if time_next > 0 else 0.
img = x_start * alpha_next.sqrt() + \
c1 * noise + \
c1 * torch.randn_like(img) + \
c2 * pred_noise
img = self.unnormalize_img(img)
@@ -2401,7 +2292,7 @@ class Decoder(nn.Module):
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
@@ -2419,6 +2310,7 @@ class Decoder(nn.Module):
times,
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
lowres_cond_img = lowres_cond_img,
image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
@@ -2478,6 +2370,7 @@ class Decoder(nn.Module):
self,
image_embed = None,
text = None,
text_mask = None,
text_encodings = None,
batch_size = 1,
cond_scale = 1.,
@@ -2491,7 +2384,7 @@ class Decoder(nn.Module):
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip)
_, text_encodings = self.clip.embed_text(text)
_, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -2499,10 +2392,7 @@ class Decoder(nn.Module):
img = None
is_cuda = next(self.parameters()).is_cuda
num_unets = len(self.unets)
cond_scale = cast_tuple(cond_scale, num_unets)
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps, cond_scale)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
@@ -2511,7 +2401,7 @@ class Decoder(nn.Module):
shape = (batch_size, channel, image_size, image_size)
if unet.lowres_cond:
lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
is_latent_diffusion = isinstance(vae, VQGanVAE)
image_size = vae.get_encoded_fmap_size(image_size)
@@ -2524,7 +2414,8 @@ class Decoder(nn.Module):
shape,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = unet_cond_scale,
text_mask = text_mask,
cond_scale = cond_scale,
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
@@ -2547,6 +2438,7 @@ class Decoder(nn.Module):
text = None,
image_embed = None,
text_encodings = None,
text_mask = None,
unet_number = None,
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
):
@@ -2575,13 +2467,13 @@ class Decoder(nn.Module):
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
_, text_encodings = self.clip.embed_text(text)
_, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size, nearest = True)
image = resize_image_to(image, target_image_size)
if exists(random_crop_size):
aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
@@ -2598,7 +2490,7 @@ class Decoder(nn.Module):
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
if not return_lowres_cond_image:
return losses

View File

@@ -129,7 +129,6 @@ class AdapterConfig(BaseModel):
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
max_text_len: int = None
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
@@ -137,7 +136,6 @@ class DiffusionPriorNetworkConfig(BaseModel):
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_in: bool = False
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.
@@ -236,7 +234,7 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable(int)] = None
sampling_timesteps: Optional[SingularOrIterable(int)] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True

View File

@@ -673,14 +673,8 @@ class DecoderTrainer(nn.Module):
def sample(self, *args, **kwargs):
distributed = self.accelerator.num_processes > 1
base_decoder = self.accelerator.unwrap_model(self.decoder)
was_training = base_decoder.training
base_decoder.eval()
if kwargs.pop('use_non_ema', False) or not self.use_ema:
out = base_decoder.sample(*args, **kwargs, distributed = distributed)
base_decoder.train(was_training)
return out
return base_decoder.sample(*args, **kwargs, distributed = distributed)
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
@@ -693,7 +687,6 @@ class DecoderTrainer(nn.Module):
for ema in self.ema_unets:
ema.restore_ema_model_device()
base_decoder.train(was_training)
return output
@torch.no_grad()

View File

@@ -1 +1 @@
__version__ = '0.24.2'
__version__ = '0.19.0'

View File

@@ -323,7 +323,7 @@ def train(
last_snapshot = sample
if next_task == 'train':
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
for i, (img, emb, txt) in enumerate(trainer.train_loader):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -358,7 +358,6 @@ def train(
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet)
trainer.update(unet_number=unet)
@@ -417,7 +416,7 @@ def train(
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
@@ -558,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Create the decoder model and print basic info
decoder = config.decoder.create()
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
num_parameters = sum(p.numel() for p in decoder.parameters())
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
@@ -587,10 +586,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
for i, unet in enumerate(decoder.unets):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
accelerator.print(f"Number of parameters: {num_parameters}")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,

View File

@@ -126,9 +126,9 @@ def report_cosine_sims(
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings = trainer.embed_text(text_data)
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
)
else:
text_embedding = text_data
@@ -146,12 +146,15 @@ def report_cosine_sims(
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
text_mask_shuffled = text_mask[rolled_idx]
else:
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(
text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled
text_encodings=text_encodings_shuffled,
mask=text_mask_shuffled,
)
# prepare the text embedding