Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-12 11:34:29 +01:00)
Compare commits
9 Commits
| SHA1 |
|---|
| 70282de23b |
| 83f761847e |
| 11469dc0c6 |
| 2d25c89f35 |
| 3fe96c208a |
| 0fc6c9cdf3 |
| 7ee0ecc388 |
| 1924c7cc3d |
| f7df3caaf3 |
README.md (+11 −0)
````diff
@@ -831,6 +831,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
+- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
 - [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 
 ## Citations
 
@@ -896,4 +897,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@article{Shleifer2021NormFormerIT,
+    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
+    author  = {Sam Shleifer and Jason Weston and Myle Ott},
+    journal = {ArXiv},
+    year    = {2021},
+    volume  = {abs/2110.09456}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
````
dalle2_pytorch/dalle2_pytorch.py

```diff
@@ -29,6 +29,9 @@ from x_clip import CLIP
 def exists(val):
     return val is not None
 
+def identity(t, *args, **kwargs):
+    return t
+
 def default(val, d):
     if exists(val):
         return val
```
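For readers skimming the diff, a quick sketch of how these helpers get used. Note the hunk truncates `default`'s fallback branch; the callable-aware tail below is the usual lucidrains pattern, assumed here rather than shown in the diff:

```python
import torch
import torch.nn.functional as F

def exists(val):
    return val is not None

def identity(t, *args, **kwargs):
    return t

def default(val, d):
    # fallback branch assumed: return the default, calling it if it is a factory
    if exists(val):
        return val
    return d() if callable(d) else d

# typical usage: fall back when an optional argument is omitted,
# and select a no-op when a feature flag is off
dim = default(None, 64)                                              # -> 64
l2norm_output = False
output_fn = (lambda t: F.normalize(t, dim = -1)) if l2norm_output else identity
print(output_fn(torch.randn(2, dim)).shape)                          # torch.Size([2, 64])
```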
```diff
@@ -496,7 +499,12 @@ class SwiGLU(nn.Module):
         x, gate = x.chunk(2, dim = -1)
         return x * F.silu(gate)
 
-def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
+def FeedForward(
+    dim,
+    mult = 4,
+    dropout = 0.,
+    post_activation_norm = False
+):
     """ post-activation norm https://arxiv.org/abs/2110.09456 """
 
     inner_dim = int(mult * dim)
```
```diff
@@ -519,7 +527,8 @@ class Attention(nn.Module):
         dim_head = 64,
         heads = 8,
         dropout = 0.,
-        causal = False
+        causal = False,
+        post_norm = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
```
```diff
@@ -534,7 +543,11 @@
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            LayerNorm(dim) if post_norm else nn.Identity()
+        )
 
     def forward(self, x, mask = None, attn_bias = None):
         b, n, device = *x.shape[:2], x.device
```
```diff
@@ -596,10 +609,11 @@ class CausalTransformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        norm_out = False,
+        norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
-        final_proj = True
+        final_proj = True,
+        normformer = False
     ):
         super().__init__()
         self.rel_pos_bias = RelPosBias(heads = heads)
```
```diff
@@ -607,8 +621,8 @@
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
-                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
+                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
         self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
```
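Taken together, the hunks above thread a NormFormer-style option through the transformer: `post_norm` adds a LayerNorm after the attention output projection, `post_activation_norm` adds one after the feedforward activation, and `normformer` switches both on per layer (https://arxiv.org/abs/2110.09456). A minimal self-contained sketch of the pattern, simplified from the repo's modules (no relative position bias, null key/values, or SwiGLU):

```python
import torch
from torch import nn

class PostActivationFeedForward(nn.Module):
    """feedforward with an extra LayerNorm after the activation (NormFormer)"""
    def __init__(self, dim, mult = 4, post_activation_norm = False):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x):
        return self.net(x)

class PostNormAttention(nn.Module):
    """pre-norm attention whose output projection is followed by a LayerNorm
    (NormFormer's head-scaling variant is omitted for brevity)"""
    def __init__(self, dim, heads = 8, post_norm = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)
        self.post_norm = nn.LayerNorm(dim) if post_norm else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        out, _ = self.attn(x, x, x)
        return self.post_norm(out)

# normformer = True turns on both extra norms, as in the CausalTransformer hunk above
x = torch.randn(1, 16, 64)
attn = PostNormAttention(64, post_norm = True)
ff = PostActivationFeedForward(64, post_activation_norm = True)
x = x + attn(x)
x = x + ff(x)
print(x.shape)  # torch.Size([1, 16, 64])
```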
```diff
@@ -635,12 +649,14 @@ class DiffusionPriorNetwork(nn.Module):
         self,
         dim,
         num_timesteps = None,
+        l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
         **kwargs
     ):
         super().__init__()
         self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
+        self.l2norm_output = l2norm_output
 
     def forward_with_cond_scale(
         self,
```
```diff
@@ -719,7 +735,8 @@
 
         pred_image_embed = tokens[..., -1, :]
 
-        return pred_image_embed
+        output_fn = l2norm if self.l2norm_output else identity
+        return output_fn(pred_image_embed)
 
 class DiffusionPrior(BaseGaussianDiffusion):
     def __init__(
```
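The new `l2norm_output` flag constrains the predicted image embedding to the unit hypersphere, matching the geometry of the normalized CLIP image embeddings the prior is trained to predict. A small sketch of just the output path, assuming `l2norm` is an `F.normalize` wrapper as in the repo:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    return F.normalize(t, dim = -1)

def identity(t, *args, **kwargs):
    return t

l2norm_output = True

pred_image_embed = torch.randn(4, 512)            # stand-in for tokens[..., -1, :]
output_fn = l2norm if l2norm_output else identity
out = output_fn(pred_image_embed)

print(out.norm(dim = -1))                         # all ones when l2norm_output is on
```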
dalle2_pytorch/train.py

```diff
@@ -159,12 +159,13 @@ class DecoderTrainer(nn.Module):
         index = unet_number - 1
         unet = self.decoder.unets[index]
 
-        if exists(self.max_grad_norm):
-            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
-
         optimizer = getattr(self, f'optim{index}')
         scaler = getattr(self, f'scaler{index}')
 
+        if exists(self.max_grad_norm):
+            scaler.unscale_(optimizer)
+            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
+
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
```
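The reordering above matters under mixed precision: gradients sit in loss-scaled form until `scaler.unscale_` runs, so clipping before the unscale would measure the scaled gradient norm against an unscaled threshold. A minimal sketch of the corrected order with a toy model (not the repo's trainer; the scaler is disabled here so the snippet also runs on CPU):

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
scaler = GradScaler(enabled = False)  # enabled = True requires CUDA
max_grad_norm = 0.5

loss = model(torch.randn(2, 8)).sum()
scaler.scale(loss).backward()

scaler.unscale_(optimizer)            # bring grads back to true scale first...
nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # ...then clip against the real norm

scaler.step(optimizer)                # skips the update if grads overflowed
scaler.update()
optimizer.zero_grad()
```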
setup.py (+1 −1)
```diff
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.89',
+  version = '0.0.93',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
```
train_diffusion_prior.py

```diff
@@ -17,16 +17,26 @@ os.environ["WANDB_SILENT"] = "true"
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
     model.eval()
     with torch.no_grad():
-        for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
+        total_loss = 0.
+        total_samples = 0.
+
+        for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
                 text_reader(batch_size=batch_size, start=start, end=end)):
+
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
+
+            batches = emb_images_tensor.shape[0]
+
             loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
-            # Log to wandb
-            wandb.log({f'{phase} {loss_type}': loss})
+            total_loss += loss.item() * batches
+            total_samples += batches
 
-def save_model(save_path,state_dict):
+        avg_loss = (total_loss / total_samples)
+        wandb.log({f'{phase} {loss_type}': avg_loss})
+
+def save_model(save_path, state_dict):
     # Saving State Dict
     print("====================================== Saving checkpoint ======================================")
     torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
```
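The `eval_model` rewrite replaces per-batch wandb logging with a single sample-weighted average, so a short final batch no longer skews the reported validation loss. The weighting logic in isolation, with made-up numbers:

```python
# sample-weighted average: each batch contributes in proportion to its size
batch_losses = [(0.9, 128), (1.1, 128), (2.0, 16)]   # (mean loss, batch size), toy values

total_loss = 0.
total_samples = 0
for loss, batches in batch_losses:
    total_loss += loss * batches
    total_samples += batches

avg_loss = total_loss / total_samples
print(round(avg_loss, 4))  # 1.0588, vs 1.3333 for a naive mean over batches
```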
```diff
@@ -43,6 +53,7 @@ def train(image_embed_dim,
           clip,
           dp_condition_on_text_encodings,
           dp_timesteps,
+          dp_l2norm_output,
           dp_cond_drop_prob,
           dpn_depth,
           dpn_dim_head,
```
```diff
@@ -52,14 +63,16 @@ def train(image_embed_dim,
           device,
           learning_rate=0.001,
           max_grad_norm=0.5,
-          weight_decay=0.01):
+          weight_decay=0.01,
+          amp=False):
 
     # DiffusionPriorNetwork
     prior_network = DiffusionPriorNetwork(
             dim = image_embed_dim,
             depth = dpn_depth,
             dim_head = dpn_dim_head,
-            heads = dpn_heads).to(device)
+            heads = dpn_heads,
+            l2norm_output = dp_l2norm_output).to(device)
 
     # DiffusionPrior with text embeddings and image embeddings pre-computed
     diffusion_prior = DiffusionPrior(
```
```diff
@@ -82,6 +95,7 @@ def train(image_embed_dim,
         os.makedirs(save_path)
 
     ### Training code ###
+    scaler = GradScaler(enabled=amp)
     optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
     epochs = num_epochs
```
```diff
@@ -98,23 +112,33 @@
                 text_reader(batch_size=batch_size, start=0, end=train_set_size)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-            optimizer.zero_grad()
-            loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
-            loss.backward()
+
+            with autocast(enabled=amp):
+                loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
+            scaler.scale(loss).backward()
+
             # Samples per second
             step+=1
             samples_per_sec = batch_size*step/(time.time()-t)
             # Save checkpoint every save_interval minutes
             if(int(time.time()-t) >= 60*save_interval):
                 t = time.time()
-                save_model(save_path,diffusion_prior.state_dict())
+                save_model(
+                    save_path,
+                    dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
+
             # Log to wandb
             wandb.log({"Training loss": loss.item(),
                         "Steps": step,
                         "Samples per second": samples_per_sec})
 
+            scaler.unscale_(optimizer)
             nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
-            optimizer.step()
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
 
         ### Evaluate model(validation run) ###
         start = train_set_size
```
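The training-loop hunk is the standard autocast/GradScaler recipe end to end, and the checkpoint now bundles optimizer and scaler state so a resumed run keeps its loss-scale history. One caveat: `clip_grad_norm_` lives in `torch.nn.utils`, not `torch.nn.init`, so the script's `nn.init.clip_grad_norm_` call (unchanged by this diff) would raise an AttributeError; the sketch below, with a toy model standing in for the prior, uses the `nn.utils` spelling:

```python
import time
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

amp = torch.cuda.is_available()   # mixed precision needs CUDA; falls back cleanly on CPU
device = 'cuda' if amp else 'cpu'

model = nn.Linear(512, 512).to(device)   # toy stand-in for diffusion_prior
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
scaler = GradScaler(enabled = amp)
max_grad_norm = 0.5

for step in range(10):
    text_embed = torch.randn(8, 512, device = device)
    image_embed = torch.randn(8, 512, device = device)

    with autocast(enabled = amp):                     # forward pass in mixed precision
        loss = (model(text_embed) - image_embed).pow(2).mean()
    scaler.scale(loss).backward()                     # backward on the scaled loss

    scaler.unscale_(optimizer)                        # unscale before clipping
    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

# checkpoint everything needed to resume, mirroring the save_model call above
torch.save(
    dict(model = model.state_dict(), optimizer = optimizer.state_dict(), scaler = scaler.state_dict()),
    f'{time.time()}_saved_model.pth')
```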
```diff
@@ -158,15 +182,19 @@ def main():
     # DiffusionPrior(dp) parameters
     parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
     parser.add_argument("--dp-timesteps", type=int, default=100)
+    parser.add_argument("--dp-l2norm-output", type=bool, default=False)
     parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
     parser.add_argument("--dp-loss-type", type=str, default="l2")
     parser.add_argument("--clip", type=str, default=None)
+    parser.add_argument("--amp", type=bool, default=False)
     # Model checkpointing interval(minutes)
     parser.add_argument("--save-interval", type=int, default=30)
     parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
 
     args = parser.parse_args()
 
     print("Setting up wandb logging... Please wait...")
+
     wandb.init(
         entity=args.wandb_entity,
         project=args.wandb_project,
```
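A caveat on the two new flags: argparse's `type=bool` does not actually parse booleans. Since `bool('False')` is `True`, any value passed on the command line enables the flag. A common fix, sketched here as a suggestion rather than the repo's code, is `action='store_true'`:

```python
import argparse

parser = argparse.ArgumentParser()

# type=bool pitfall: any non-empty string is truthy
parser.add_argument("--amp-bool", type=bool, default=False)
print(parser.parse_args(["--amp-bool", "False"]).amp_bool)  # True (!)

# store_true: flag present -> True, absent -> False
parser.add_argument("--amp", action="store_true")
print(parser.parse_args(["--amp"]).amp)   # True
print(parser.parse_args([]).amp)          # False
```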
```diff
@@ -176,6 +204,7 @@ def main():
             "dataset": args.wandb_dataset,
             "epochs": args.num_epochs,
     })
 
+    print("wandb logging setup done!")
     # Obtain the utilized device.
 
```
```diff
@@ -197,6 +226,7 @@ def main():
         args.clip,
         args.dp_condition_on_text_encodings,
         args.dp_timesteps,
+        args.dp_l2norm_output,
         args.dp_cond_drop_prob,
         args.dpn_depth,
         args.dpn_dim_head,
```
```diff
@@ -206,7 +236,8 @@ def main():
         device,
         args.learning_rate,
         args.max_grad_norm,
-        args.weight_decay)
+        args.weight_decay,
+        args.amp)
 
 if __name__ == "__main__":
     main()
```