Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-12 11:34:29 +01:00)

Compare commits: 1 commit (f8b005510c)

README.md | 21
@@ -49,7 +49,6 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
- <a href="https://github.com/arogozhnikov">Alex</a> for <a href="https://github.com/arogozhnikov/einops">einops</a>, indispensable tool for tensor manipulation

... and many others. Thank you! 🙏

@@ -1275,24 +1274,4 @@ For detailed information on training the diffusion prior, please refer to the [d
}
```

```bibtex
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
```

```bibtex
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
```

*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
@@ -250,15 +250,9 @@ class XClipAdapter(BaseClipAdapter):
text = text[..., :self.max_text_len]
text_mask = text != 0
encoder_output = self.clip.text_transformer(text)

encoder_output_is_cls = encoder_output.ndim == 3

text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)

if exists(text_encodings):
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)

text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(l2norm(text_embed), text_encodings)

@torch.no_grad()

@@ -879,8 +873,6 @@ class Attention(nn.Module):
# attention

attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = attn.type(sim.dtype)

attn = self.dropout(attn)

# aggregate values
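
The attention hunk above keeps the softmax in float32 even when the rest of the module runs in half precision under autocast, then casts the weights back. A minimal sketch of that pattern, with an illustrative function name not taken from the repo:

```python
import torch

def stable_softmax(sim: torch.Tensor) -> torch.Tensor:
    # compute attention weights in float32 for numerical stability under autocast,
    # then cast back to the similarity tensor's original dtype
    attn = sim.softmax(dim = -1, dtype = torch.float32)
    return attn.type(sim.dtype)
```
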
@@ -978,10 +970,7 @@ class DiffusionPriorNetwork(nn.Module):
# dalle1 learned padding strategy

self.max_text_len = max_text_len

self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
self.null_image_embed = nn.Parameter(torch.randn(1, dim))
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))

# whether to use self conditioning, Hinton's group's new ddpm technique

@@ -998,7 +987,7 @@ class DiffusionPriorNetwork(nn.Module):
if cond_scale == 1:
return logits

null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale

def forward(
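
The hunk above is the standard classifier-free guidance blend: the network is run once with conditioning kept and once with it fully dropped, and the two outputs are extrapolated by `cond_scale`. A minimal self-contained sketch, where `model`, `x`, and `cond` are hypothetical stand-ins rather than names from this repo:

```python
import torch

def guided_prediction(model, x, cond, cond_scale = 1.5):
    # prediction with conditioning kept
    logits = model(x, cond, cond_drop_prob = 0.)
    if cond_scale == 1:
        return logits
    # prediction with conditioning fully dropped (the "null" branch)
    null_logits = model(x, cond, cond_drop_prob = 1.)
    # extrapolate away from the unconditional prediction
    return null_logits + (logits - null_logits) * cond_scale
```
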
@@ -1009,8 +998,7 @@ class DiffusionPriorNetwork(nn.Module):
text_embed,
text_encodings = None,
self_cond = None,
text_cond_drop_prob = 0.,
image_cond_drop_prob = 0.
cond_drop_prob = 0.
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype

@@ -1028,14 +1016,6 @@ class DiffusionPriorNetwork(nn.Module):
text_embed = self.to_text_embeds(text_embed)
image_embed = self.to_image_embeds(image_embed)

# classifier free guidance masks

text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')

image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')

# make text encodings optional
# although the paper seems to suggest it is present <--
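
The keep-masks above come from a `prob_mask_like` helper that decides, per batch element, whether conditioning is kept or dropped for that training step. A minimal sketch of what such a helper typically looks like (the repo's exact implementation may differ):

```python
import torch

def prob_mask_like(shape, prob, device = None):
    # boolean mask where each element is True (keep conditioning) with probability `prob`
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
```
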
@@ -1056,39 +1036,32 @@ class DiffusionPriorNetwork(nn.Module):
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
mask = F.pad(mask, (0, remainder), value = False)

# mask out text encodings with null encodings

null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)

text_encodings = torch.where(
rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
rearrange(mask, 'b n -> b n 1').clone(),
text_encodings,
null_text_encodings
)

# mask out text embeddings with null text embeddings

null_text_embeds = self.null_text_embeds.to(text_embed.dtype)

text_embeds = torch.where(
text_keep_mask,
text_embed,
null_text_embeds
)

# mask out image embeddings with null image embeddings
# classifier free guidance

null_image_embed = self.null_image_embed.to(image_embed.dtype)
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1')

image_embed = torch.where(
image_keep_mask,
image_embed,
null_image_embed
)
mask &= keep_mask

# whether text embedding is masked or not depends on the classifier free guidance conditional masking

keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
mask = torch.cat((mask, keep_mask), dim = 1)

# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right

attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query

time_embed = self.to_time_embeds(diffusion_timesteps)

learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)

@@ -1126,8 +1099,6 @@ class DiffusionPrior(nn.Module):
timesteps = 1000,
sample_timesteps = None,
cond_drop_prob = 0.,
text_cond_drop_prob = None,
image_cond_drop_prob = None,
loss_type = "l2",
predict_x_start = True,
beta_schedule = "cosine",

@@ -1166,16 +1137,10 @@ class DiffusionPrior(nn.Module):

self.net = net
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)

assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}'
assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}'

self.channels = default(image_channels, lambda: clip.image_channels)

self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)

self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
self.cond_drop_prob = cond_drop_prob
self.can_classifier_guidance = cond_drop_prob > 0.
self.condition_on_text_encodings = condition_on_text_encodings

# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
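
The comment above refers to the standard DDPM reparameterization: the forward process gives x_t = sqrt(ᾱ_t) · x_0 + sqrt(1 − ᾱ_t) · ε, so a network can be trained to predict either x_0 or the noise ε, and each target can be recovered from the other. A small illustrative sketch (tensor and function names are hypothetical, not the repo's):

```python
import torch

def predict_start_from_noise(x_t, noise, alpha_cumprod_t):
    # x_0 = (x_t - sqrt(1 - ᾱ_t) * ε) / sqrt(ᾱ_t)
    return (x_t - (1 - alpha_cumprod_t).sqrt() * noise) / alpha_cumprod_t.sqrt()

def predict_noise_from_start(x_t, x_start, alpha_cumprod_t):
    # ε = (x_t - sqrt(ᾱ_t) * x_0) / sqrt(1 - ᾱ_t)
    return (x_t - alpha_cumprod_t.sqrt() * x_start) / (1 - alpha_cumprod_t).sqrt()
```
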
@@ -1259,7 +1224,7 @@ class DiffusionPrior(nn.Module):
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps

times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]

times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
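
As a concrete check of the schedule built above, here is what the `linspace` variant with `steps = timesteps + 2` produces for, say, `total_timesteps = 1000` and `timesteps = 10` (purely a worked example, not values from the repo):

```python
import torch

total_timesteps, timesteps = 1000, 10

times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
# times      -> [909, 818, 727, 636, 545, 454, 363, 272, 181, 90, 0]

time_pairs = list(zip(times[:-1], times[1:]))
# time_pairs -> [(909, 818), (818, 727), ..., (90, 0)], i.e. exactly `timesteps` DDIM steps
```
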
@@ -1294,10 +1259,6 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = self.l2norm_clamp_embed(x_start)

if time_next < 0:
image_embed = x_start
continue

c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
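
The `c1`/`c2` terms above are the usual DDIM coefficients (Song et al., 2020): `c1` is the stochasticity σ_t scaled by `eta`, and `c2` weights the predicted noise so the variances sum to 1 − ᾱ_next. A sketch of how one full update step typically combines them; names such as `pred_noise` are illustrative stand-ins, not necessarily the repo's:

```python
import torch

def ddim_step(x_start, pred_noise, alpha, alpha_next, eta = 1.):
    # alpha / alpha_next are the cumulative ᾱ values at the current and next step (tensors)
    # eta = 0 gives the deterministic DDIM update, eta = 1 injects DDPM-like noise
    c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
    c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
    noise = torch.randn_like(x_start)
    return x_start * alpha_next.sqrt() + c2 * pred_noise + c1 * noise
```
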
@@ -1339,8 +1300,7 @@ class DiffusionPrior(nn.Module):
image_embed_noisy,
times,
self_cond = self_cond,
text_cond_drop_prob = self.text_cond_drop_prob,
image_cond_drop_prob = self.image_cond_drop_prob,
cond_drop_prob = self.cond_drop_prob,
**text_cond
)

@@ -1487,20 +1447,19 @@ class PixelShuffleUpsample(nn.Module):
def forward(self, x):
return self.net(x)

def Downsample(dim, dim_out = None):
# https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
# named SP-conv in the paper, but basically a pixel unshuffle
def Downsample(dim, *, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Sequential(
Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
nn.Conv2d(dim * 4, dim_out, 1)
)
return nn.Conv2d(dim, dim_out, 4, 2, 1)

class WeightStandardizedConv2d(nn.Conv2d):
"""
https://arxiv.org/abs/1903.10520
weight standardization purportedly works synergistically with group normalization
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
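
The hunk cuts off after the `eps` line, so for context, here is what a weight-standardized forward pass typically looks like (per the linked paper, arXiv:1903.10520): the kernel is normalized to zero mean and unit variance per output channel before the convolution. This is a sketch, not necessarily the repo's exact implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightStandardizedConv2d(nn.Conv2d):
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        weight = self.weight
        # standardize the kernel over (in_channels, kh, kw) for each output channel
        mean = weight.mean(dim = (1, 2, 3), keepdim = True)
        var = weight.var(dim = (1, 2, 3), keepdim = True, unbiased = False)
        weight = (weight - mean) * (var + eps).rsqrt()
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
```
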
@@ -1676,7 +1635,6 @@ class CrossAttention(nn.Module):
sim = sim.masked_fill(~mask, max_neg_value)

attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = attn.type(sim.dtype)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

@@ -2849,13 +2807,12 @@ class Decoder(nn.Module):
inpaint_mask = None,
inpaint_resample_times = 5
):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod, self.ddim_sampling_eta
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta

times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]

times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
time_pairs = list(filter(lambda t: t[0] > t[1], time_pairs))

is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1

@@ -241,7 +241,7 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
sample_timesteps: Optional[SingularOrIterable[int]] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple[str] = None # None means all cosine
learned_variance: SingularOrIterable[bool] = True

@@ -9,7 +9,7 @@ from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler

from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior

@@ -181,8 +181,7 @@ class DiffusionPriorTrainer(nn.Module):
eps = 1e-6,
max_grad_norm = None,
group_wd_params = True,
warmup_steps = None,
cosine_decay_max_steps = None,
warmup_steps = 1,
**kwargs
):
super().__init__()

@@ -234,11 +233,8 @@ class DiffusionPriorTrainer(nn.Module):
**self.optim_kwargs,
**kwargs
)

if exists(cosine_decay_max_steps):
self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
else:
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None

@@ -275,7 +271,6 @@ class DiffusionPriorTrainer(nn.Module):
# FIXME: LambdaLR can't be saved due to pickling issues
save_obj = dict(
optimizer = self.optimizer.state_dict(),
scheduler = self.scheduler.state_dict(),
warmup_scheduler = self.warmup_scheduler,
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
version = version.parse(__version__),

@@ -322,9 +317,7 @@ class DiffusionPriorTrainer(nn.Module):
# unwrap the model when loading from checkpoint
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))

self.optimizer.load_state_dict(loaded_obj['optimizer'])
self.scheduler.load_state_dict(loaded_obj['scheduler'])

# set warmupstep
if exists(self.warmup_scheduler):

@@ -357,8 +350,7 @@ class DiffusionPriorTrainer(nn.Module):

# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
if not self.accelerator.optimizer_step_was_skipped:
sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
with sched_context():
with self.warmup_scheduler.dampening():
self.scheduler.step()

if self.use_ema:
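
The `dampening()` context manager in the hunk above comes from the pytorch-warmup package (the `warmup` module used by this trainer): it scales the learning rate during the warmup period and then lets the wrapped scheduler step. A minimal usage sketch with a placeholder model and hyperparameters:

```python
import torch
import pytorch_warmup as warmup
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)
scheduler = LambdaLR(optimizer, lr_lambda = lambda _: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 100)

for _ in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    # dampen the LR while warming up, then advance the main scheduler
    with warmup_scheduler.dampening():
        scheduler.step()
```
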
@@ -441,7 +433,6 @@ class DecoderTrainer(nn.Module):
wd = 1e-2,
eps = 1e-8,
warmup_steps = None,
cosine_decay_max_steps = None,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,

@@ -463,7 +454,7 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet

lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))

assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'

@@ -471,7 +462,7 @@ class DecoderTrainer(nn.Module):
schedulers = []
warmup_schedulers = []

for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
if isinstance(unet, nn.Identity):
optimizers.append(None)
schedulers.append(None)

@@ -487,11 +478,7 @@ class DecoderTrainer(nn.Module):
)

optimizers.append(optimizer)

if exists(unet_cosine_decay_max_steps):
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
else:
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)

@@ -571,15 +558,9 @@ class DecoderTrainer(nn.Module):

for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
scheduler_key = f'sched{ind}'

optimizer = getattr(self, optimizer_key)
scheduler = getattr(self, scheduler_key)

optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None

save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
state_dict = optimizer.state_dict() if optimizer is not None else None
save_obj = {**save_obj, optimizer_key: state_dict}

if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

@@ -600,18 +581,10 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
scheduler_key = f'sched{ind}'
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
warmup_scheduler = self.warmup_schedulers[ind]
|
||||
|
||||
if exists(optimizer):
|
||||
if optimizer is not None:
|
||||
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if exists(scheduler):
|
||||
scheduler.load_state_dict(loaded_obj[scheduler_key])
|
||||
|
||||
if exists(warmup_scheduler):
|
||||
warmup_scheduler.last_step = last_step
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.10.3'
|
||||
__version__ = '1.7.0'
|
||||
|
||||
@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))

def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions

@@ -144,9 +144,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
if img_embeddings[0] is None:
# Generate image embeddings from clip
imgs_tensor = torch.stack(real_images)
assert clip is not None, "clip is None, but img_embeddings is None"
imgs_tensor.to(device=device)
img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
sample_params["image_embed"] = img_embeddings
else:
# Then we are using precomputed image embeddings

@@ -155,10 +153,8 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
if condition_on_text_encodings:
if text_embeddings[0] is None:
# Generate text embeddings from text
assert clip is not None, "clip is None, but text_embeddings is None"
tokenized_texts = tokenize(txts, truncate=True)
text_embed, text_encodings = clip.embed_text(tokenized_texts)
sample_params["text_encodings"] = text_encodings
sample_params["text"] = tokenized_texts
else:
# Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)

@@ -170,7 +166,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
sample_params["image"] = torch.stack(real_images)
if device is not None:
sample_params["_device"] = device
samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
if match_image_size:

@@ -178,15 +174,15 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
return real_images, generated_images, captions

def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions

def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""

@@ -196,7 +192,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
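
For context on the conversion noted in the last line above: image metrics such as FID in torchmetrics expect uint8 images in [0, 255] by default. A small sketch of that usage with placeholder tensors (the repo's actual evaluation wiring may differ):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature = 2048)

# convert [0, 1] floats to [0, 255] uint8, as the metric expects by default
real_uint8 = (torch.rand(8, 3, 299, 299) * 255).to(torch.uint8)
fake_uint8 = (torch.rand(8, 3, 299, 299) * 255).to(torch.uint8)

fid.update(real_uint8, real = True)
fid.update(fake_uint8, real = False)
print(fid.compute())
```
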
@@ -269,7 +265,6 @@ def train(
accelerator: Accelerator,
tracker: Tracker,
inference_device,
clip=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,

@@ -376,19 +371,15 @@ def train(
forward_params['image_embed'] = img_emb
else:
# Forward pass automatically generates embedding
assert clip is not None
img_embed, img_encoding = clip.embed_image(img)
forward_params['image_embed'] = img_embed
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb
else:
# Then we need to pass the text instead
assert clip is not None
tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss

@@ -428,7 +419,7 @@ def train(
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

if epoch_samples is not None and sample >= epoch_samples:

@@ -471,19 +462,15 @@ def train(
forward_params['image_embed'] = img_emb.float()
else:
# Forward pass automatically generates embedding
assert clip is not None
img_embed, img_encoding = clip.embed_image(img)
forward_params['image_embed'] = img_embed
pass
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb.float()
else:
# Then we need to pass the text instead
assert clip is not None
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings
forward_params['text'] = tokenized_texts
loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
average_val_loss_tensor[0, unet-1] += loss

@@ -511,7 +498,7 @@ def train(
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
if is_master:
tracker.log(evaluation, step=step())
next_task = 'sample'

@@ -522,8 +509,8 @@ def train(
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

@@ -545,7 +532,6 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision
}
accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())

@@ -569,6 +555,10 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
raise ValueError("DeepSpeed fp16 mode does not support learned variance")

if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
# This is an invalid configuration until we figure out how to handle this
raise ValueError("DeepSpeed does not support multi-node distributed training")

# Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))

@@ -589,11 +579,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
seed = config.seed,
)

# If clip is in the model, we need to remove it for compatibility with deepspeed
clip = None
if config.decoder.clip is not None:
clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
config.decoder.clip = None
# Create the decoder model and print basic info
decoder = config.decoder.create()
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))

@@ -605,7 +590,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
has_text_embeddings = config.data.text_embeddings_url is not None
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])

has_clip_model = clip is not None
has_clip_model = config.decoder.clip is not None
data_source_string = ""

if has_img_embeddings:

@@ -630,7 +615,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")

train(dataloaders, decoder, accelerator,
clip=clip,
tracker=tracker,
inference_device=accelerator.device,
evaluate_config=config.evaluate,