Compare commits

...

11 Commits

6 changed files with 98 additions and 25 deletions

View File

@@ -830,6 +830,8 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
## Citations
@@ -895,4 +897,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{Shleifer2021NormFormerIT,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Sam Shleifer and Jason Weston and Myle Ott},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.09456}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

View File

@@ -29,6 +29,9 @@ from x_clip import CLIP
def exists(val):
return val is not None
def identity(t, *args, **kwargs):
return t
def default(val, d):
if exists(val):
return val
@@ -496,7 +499,12 @@ class SwiGLU(nn.Module):
x, gate = x.chunk(2, dim = -1)
return x * F.silu(gate)
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
def FeedForward(
dim,
mult = 4,
dropout = 0.,
post_activation_norm = False
):
""" post-activation norm https://arxiv.org/abs/2110.09456 """
inner_dim = int(mult * dim)
@@ -519,7 +527,8 @@ class Attention(nn.Module):
dim_head = 64,
heads = 8,
dropout = 0.,
causal = False
causal = False,
post_norm = False
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -534,7 +543,11 @@ class Attention(nn.Module):
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity()
)
def forward(self, x, mask = None, attn_bias = None):
b, n, device = *x.shape[:2], x.device
@@ -596,10 +609,11 @@ class CausalTransformer(nn.Module):
dim_head = 64,
heads = 8,
ff_mult = 4,
norm_out = False,
norm_out = True,
attn_dropout = 0.,
ff_dropout = 0.,
final_proj = True
final_proj = True,
normformer = False
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads)
@@ -607,8 +621,8 @@ class CausalTransformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
@@ -635,12 +649,14 @@ class DiffusionPriorNetwork(nn.Module):
self,
dim,
num_timesteps = None,
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
self.l2norm_output = l2norm_output
def forward_with_cond_scale(
self,
@@ -719,7 +735,8 @@ class DiffusionPriorNetwork(nn.Module):
pred_image_embed = tokens[..., -1, :]
return pred_image_embed
output_fn = l2norm if self.l2norm_output else identity
return output_fn(pred_image_embed)
class DiffusionPrior(BaseGaussianDiffusion):
def __init__(

View File

@@ -159,12 +159,13 @@ class DecoderTrainer(nn.Module):
index = unet_number - 1
unet = self.decoder.unets[index]
if exists(self.max_grad_norm):
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
if exists(self.max_grad_norm):
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

View File

@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x):
for enc in self.encoders:
x = enc(x)
@@ -419,6 +423,10 @@ class ConvNextEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x):
for enc in self.encoders:
x = enc(x)
@@ -606,6 +614,10 @@ class ViTEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size):
return image_size // self.patch_size
@property
def last_dec_layer(self):
return self.decoder[-3][-1].weight
def encode(self, x):
return self.encoder(x)
@@ -843,7 +855,7 @@ class VQGanVAE(nn.Module):
# calculate adaptive weight
last_dec_layer = self.decoders[-1].weight
last_dec_layer = self.enc_dec.last_dec_layer
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.88',
version = '0.0.93',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',

View File

@@ -17,16 +17,26 @@ os.environ["WANDB_SILENT"] = "true"
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
model.eval()
with torch.no_grad():
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
total_loss = 0.
total_samples = 0.
for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
text_reader(batch_size=batch_size, start=start, end=end)):
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
batches = emb_images_tensor.shape[0]
loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
# Log to wandb
wandb.log({f'{phase} {loss_type}': loss})
total_loss += loss.item() * batches
total_samples += batches
def save_model(save_path,state_dict):
avg_loss = (total_loss / total_samples)
wandb.log({f'{phase} {loss_type}': avg_loss})
def save_model(save_path, state_dict):
# Saving State Dict
print("====================================== Saving checkpoint ======================================")
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
@@ -43,6 +53,7 @@ def train(image_embed_dim,
clip,
dp_condition_on_text_encodings,
dp_timesteps,
dp_l2norm_output,
dp_cond_drop_prob,
dpn_depth,
dpn_dim_head,
@@ -52,14 +63,16 @@ def train(image_embed_dim,
device,
learning_rate=0.001,
max_grad_norm=0.5,
weight_decay=0.01):
weight_decay=0.01,
amp=False):
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork(
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads).to(device)
heads = dpn_heads,
l2norm_output = dp_l2norm_output).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(
@@ -82,6 +95,7 @@ def train(image_embed_dim,
os.makedirs(save_path)
### Training code ###
scaler = GradScaler(enabled=amp)
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
epochs = num_epochs
@@ -98,23 +112,33 @@ def train(image_embed_dim,
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
optimizer.zero_grad()
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
loss.backward()
with autocast(enabled=amp):
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
scaler.scale(loss).backward()
# Samples per second
step+=1
samples_per_sec = batch_size*step/(time.time()-t)
# Save checkpoint every save_interval minutes
if(int(time.time()-t) >= 60*save_interval):
t = time.time()
save_model(save_path,diffusion_prior.state_dict())
save_model(
save_path,
dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
# Log to wandb
wandb.log({"Training loss": loss.item(),
"Steps": step,
"Samples per second": samples_per_sec})
scaler.unscale_(optimizer)
nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
optimizer.step()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
### Evaluate model(validation run) ###
start = train_set_size
@@ -158,15 +182,19 @@ def main():
# DiffusionPrior(dp) parameters
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
parser.add_argument("--dp-loss-type", type=str, default="l2")
parser.add_argument("--clip", type=str, default=None)
parser.add_argument("--amp", type=bool, default=False)
# Model checkpointing interval(minutes)
parser.add_argument("--save-interval", type=int, default=30)
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
args = parser.parse_args()
print("Setting up wandb logging... Please wait...")
wandb.init(
entity=args.wandb_entity,
project=args.wandb_project,
@@ -176,6 +204,7 @@ def main():
"dataset": args.wandb_dataset,
"epochs": args.num_epochs,
})
print("wandb logging setup done!")
# Obtain the utilized device.
@@ -197,6 +226,7 @@ def main():
args.clip,
args.dp_condition_on_text_encodings,
args.dp_timesteps,
args.dp_l2norm_output,
args.dp_cond_drop_prob,
args.dpn_depth,
args.dpn_dim_head,
@@ -206,7 +236,8 @@ def main():
device,
args.learning_rate,
args.max_grad_norm,
args.weight_decay)
args.weight_decay,
args.amp)
if __name__ == "__main__":
main()