Compare commits

...

5 Commits

3 changed files with 39 additions and 12 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -29,6 +29,9 @@ from x_clip import CLIP
 def exists(val):
     return val is not None

+def identity(t, *args, **kwargs):
+    return t
+
 def default(val, d):
     if exists(val):
         return val
@@ -596,7 +599,7 @@ class CausalTransformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        norm_out = False,
+        norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
         final_proj = True
@@ -635,12 +638,14 @@ class DiffusionPriorNetwork(nn.Module):
         self,
         dim,
         num_timesteps = None,
+        l2norm_output = False,  # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
         **kwargs
     ):
         super().__init__()
         self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim))  # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
+        self.l2norm_output = l2norm_output

     def forward_with_cond_scale(
         self,
@@ -719,7 +724,8 @@ class DiffusionPriorNetwork(nn.Module):
         pred_image_embed = tokens[..., -1, :]

-        return pred_image_embed
+        output_fn = l2norm if self.l2norm_output else identity
+        return output_fn(pred_image_embed)

 class DiffusionPrior(BaseGaussianDiffusion):
     def __init__(
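The two hunks above wire an optional unit-normalization into DiffusionPriorNetwork: when `l2norm_output=True`, the predicted image embedding is projected onto the unit sphere before being returned. A minimal sketch of that output path, assuming `l2norm` is the repo's usual `F.normalize` wrapper (the helper itself is not shown in this diff):

```python
import torch
import torch.nn.functional as F

def identity(t, *args, **kwargs):
    return t

def l2norm(t):
    # assumption: matches the repo's helper, unit-normalizing the last dim
    return F.normalize(t, dim = -1)

pred_image_embed = torch.randn(4, 512)   # (batch, image embedding dim)
l2norm_output = True                     # the new constructor flag

output_fn = l2norm if l2norm_output else identity
out = output_fn(pred_image_embed)
print(out.norm(dim = -1))                # all ones when l2norm_output is True
```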

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.90',
+  version = '0.0.92',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

train_diffusion_prior.py

@@ -36,7 +36,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
     avg_loss = (total_loss / total_samples)
     wandb.log({f'{phase} {loss_type}': avg_loss})

-def save_model(save_path,state_dict):
+def save_model(save_path, state_dict):
     # Saving State Dict
     print("====================================== Saving checkpoint ======================================")
     torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
@@ -53,6 +53,7 @@ def train(image_embed_dim,
           clip,
           dp_condition_on_text_encodings,
           dp_timesteps,
+          dp_l2norm_output,
           dp_cond_drop_prob,
           dpn_depth,
           dpn_dim_head,
@@ -62,14 +63,16 @@ def train(image_embed_dim,
           device,
           learning_rate=0.001,
           max_grad_norm=0.5,
-          weight_decay=0.01):
+          weight_decay=0.01,
+          amp=False):

     # DiffusionPriorNetwork
     prior_network = DiffusionPriorNetwork(
             dim = image_embed_dim,
             depth = dpn_depth,
             dim_head = dpn_dim_head,
-            heads = dpn_heads).to(device)
+            heads = dpn_heads,
+            l2norm_output = dp_l2norm_output).to(device)

     # DiffusionPrior with text embeddings and image embeddings pre-computed
     diffusion_prior = DiffusionPrior(
@@ -92,6 +95,7 @@ def train(image_embed_dim,
         os.makedirs(save_path)

     ### Training code ###
+    scaler = GradScaler(enabled=amp)
     optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
     epochs = num_epochs
@@ -108,23 +112,33 @@ def train(image_embed_dim,
                      text_reader(batch_size=batch_size, start=0, end=train_set_size)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-            optimizer.zero_grad()
-            loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
-            loss.backward()
+            with autocast(enabled=amp):
+                loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
+            scaler.scale(loss).backward()

             # Samples per second
             step+=1
             samples_per_sec = batch_size*step/(time.time()-t)

             # Save checkpoint every save_interval minutes
             if(int(time.time()-t) >= 60*save_interval):
                 t = time.time()
-                save_model(save_path,diffusion_prior.state_dict())
+                save_model(
+                    save_path,
+                    dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))

             # Log to wandb
             wandb.log({"Training loss": loss.item(),
                        "Steps": step,
                        "Samples per second": samples_per_sec})

+            scaler.unscale_(optimizer)
             nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
-            optimizer.step()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()

     ### Evaluate model(validation run) ###
     start = train_set_size
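The reordered loop above is the standard torch.cuda.amp recipe: run the forward pass under autocast, scale the loss before backward so fp16 gradients don't underflow, unscale before clipping so the norm is measured at its true magnitude, and let the scaler decide whether to step and how to adjust its scale factor. (One caveat: because `--amp` is declared with `type=bool` below, argparse parses any non-empty string as true, so `--amp False` on the command line still enables AMP.) A self-contained sketch of the same pattern, with a toy `nn.Linear` standing in for `diffusion_prior`:

```python
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

amp = torch.cuda.is_available()              # fp16 autocast only makes sense on GPU
device = 'cuda' if amp else 'cpu'
max_grad_norm = 0.5

model = nn.Linear(512, 512).to(device)       # stand-in for diffusion_prior
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler(enabled=amp)             # all calls become no-ops when amp=False

for _ in range(3):
    x = torch.randn(8, 512, device=device)
    with autocast(enabled=amp):              # forward + loss in mixed precision
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()            # scaled loss avoids fp16 gradient underflow
    scaler.unscale_(optimizer)               # restore true gradient magnitudes...
    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # ...so clipping is correct
    scaler.step(optimizer)                   # skips the step if gradients overflowed
    scaler.update()                          # grow/shrink the loss scale
    optimizer.zero_grad()
```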
@@ -168,15 +182,19 @@ def main():
     # DiffusionPrior(dp) parameters
     parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
     parser.add_argument("--dp-timesteps", type=int, default=100)
+    parser.add_argument("--dp-l2norm-output", type=bool, default=False)
     parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
     parser.add_argument("--dp-loss-type", type=str, default="l2")
     parser.add_argument("--clip", type=str, default=None)
+    parser.add_argument("--amp", type=bool, default=False)

     # Model checkpointing interval(minutes)
     parser.add_argument("--save-interval", type=int, default=30)
     parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
     args = parser.parse_args()

     print("Setting up wandb logging... Please wait...")

     wandb.init(
         entity=args.wandb_entity,
         project=args.wandb_project,
@@ -186,6 +204,7 @@ def main():
"dataset": args.wandb_dataset,
"epochs": args.num_epochs,
})
print("wandb logging setup done!")
# Obtain the utilized device.
@@ -207,6 +226,7 @@ def main():
           args.clip,
           args.dp_condition_on_text_encodings,
           args.dp_timesteps,
+          args.dp_l2norm_output,
           args.dp_cond_drop_prob,
           args.dpn_depth,
           args.dpn_dim_head,
@@ -216,7 +236,8 @@ def main():
           device,
           args.learning_rate,
           args.max_grad_norm,
-          args.weight_decay)
+          args.weight_decay,
+          args.amp)

 if __name__ == "__main__":
     main()
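Because `save_model` now receives a dict bundling model, optimizer, and scaler state rather than a bare state_dict, resuming training has to unpack it. A hedged sketch of a loader; `load_checkpoint` is a hypothetical helper, not part of this repo, and the fallback branch covers checkpoints written before this change:

```python
import torch

def load_checkpoint(ckpt_path, model, optimizer, scaler):
    # hypothetical helper, not part of the repo
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        # new format: dict(model=..., optimizer=..., scaler=...)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scaler.load_state_dict(checkpoint['scaler'])
    else:
        # old format: a bare model state_dict
        model.load_state_dict(checkpoint)
```

Since the save path embeds `time.time()`, resuming means picking the newest `*_saved_model.pth` under `--save-path`.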