Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-15 14:54:22 +01:00

Compare commits (7 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 11469dc0c6 |  |
|  | 2d25c89f35 |  |
|  | 3fe96c208a |  |
|  | 0fc6c9cdf3 |  |
|  | 7ee0ecc388 |  |
|  | 1924c7cc3d |  |
|  | f7df3caaf3 |  |
dalle2_pytorch/dalle2_pytorch.py

```diff
@@ -29,6 +29,9 @@ from x_clip import CLIP
 def exists(val):
     return val is not None
 
+def identity(t, *args, **kwargs):
+    return t
+
 def default(val, d):
     if exists(val):
         return val
@@ -596,7 +599,7 @@ class CausalTransformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        norm_out = False,
+        norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
         final_proj = True
@@ -635,12 +638,14 @@ class DiffusionPriorNetwork(nn.Module):
         self,
         dim,
         num_timesteps = None,
+        l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
         **kwargs
     ):
         super().__init__()
         self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
+        self.l2norm_output = l2norm_output
 
     def forward_with_cond_scale(
         self,
@@ -719,7 +724,8 @@ class DiffusionPriorNetwork(nn.Module):
 
         pred_image_embed = tokens[..., -1, :]
 
-        return pred_image_embed
+        output_fn = l2norm if self.l2norm_output else identity
+        return output_fn(pred_image_embed)
 
 class DiffusionPrior(BaseGaussianDiffusion):
     def __init__(
```
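For readers skimming the diff: the new `l2norm_output` flag gates the prior network's final image embedding through `l2norm` instead of `identity`. A minimal, self-contained sketch of that gating, assuming `l2norm` is the usual thin wrapper over `F.normalize` (the repo defines its own; the names here are illustrative):

```python
import torch
import torch.nn.functional as F

def identity(t, *args, **kwargs):
    return t

def l2norm(t):
    # project onto the unit sphere along the embedding dimension
    return F.normalize(t, dim = -1)

l2norm_output = True                               # the new constructor flag
pred_image_embed = torch.randn(4, 512)             # (batch, image embedding dim)

output_fn = l2norm if l2norm_output else identity
out = output_fn(pred_image_embed)

assert torch.allclose(out.norm(dim = -1), torch.ones(4), atol = 1e-5)
```

The hunch, per the inline comment, is that since CLIP's contrastive loss compares L2-normalized embeddings, constraining the prior's output the same way may make the target easier to learn.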
dalle2_pytorch/train.py

```diff
@@ -159,12 +159,13 @@ class DecoderTrainer(nn.Module):
         index = unet_number - 1
         unet = self.decoder.unets[index]
 
-        if exists(self.max_grad_norm):
-            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
-
         optimizer = getattr(self, f'optim{index}')
         scaler = getattr(self, f'scaler{index}')
 
+        if exists(self.max_grad_norm):
+            scaler.unscale_(optimizer)
+            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
+
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
```
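This reordering matches the standard `torch.cuda.amp` recipe: gradients must be unscaled before clipping, otherwise the norm threshold is applied to loss-scaled gradients and clips far too aggressively (or not at all). A minimal sketch of the full pattern, with a placeholder model and optimizer:

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

use_amp = torch.cuda.is_available()
device = 'cuda' if use_amp else 'cpu'

model = nn.Linear(10, 10).to(device)
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler(enabled = use_amp)
max_grad_norm = 0.5

x = torch.randn(8, 10, device = device)
with autocast(enabled = use_amp):
    loss = model(x).sum()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                                   # bring grads back to their true scale first
nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # then clip the unscaled grads
scaler.step(optimizer)                                       # skips the step if grads contain inf/nan
scaler.update()
optimizer.zero_grad()
```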
setup.py (2 changed lines)

```diff
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.89',
+  version = '0.0.92',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',
```
train_diffusion_prior.py

```diff
@@ -17,16 +17,26 @@ os.environ["WANDB_SILENT"] = "true"
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
     model.eval()
     with torch.no_grad():
-        for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
+        total_loss = 0.
+        total_samples = 0.
+
+        for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
                                        text_reader(batch_size=batch_size, start=start, end=end)):
 
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
 
+            batches = emb_images_tensor.shape[0]
+
             loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
 
-            # Log to wandb
-            wandb.log({f'{phase} {loss_type}': loss})
+            total_loss += loss.item() * batches
+            total_samples += batches
+
+        avg_loss = (total_loss / total_samples)
+        wandb.log({f'{phase} {loss_type}': avg_loss})
 
-def save_model(save_path,state_dict):
+def save_model(save_path, state_dict):
     # Saving State Dict
     print("====================================== Saving checkpoint ======================================")
     torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
```
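The validation change above replaces per-batch logging with a single sample-weighted average over the whole pass; weighting each batch's loss by its size keeps a short final batch from skewing the mean. A toy check of the arithmetic, with made-up numbers:

```python
# two batches of different sizes with different mean losses
losses, sizes = [0.9, 0.5], [32, 8]

total_loss = sum(l * n for l, n in zip(losses, sizes))  # 28.8 + 4.0 = 32.8
total_samples = sum(sizes)                              # 40
avg_loss = total_loss / total_samples                   # 0.82, not the naive (0.9 + 0.5) / 2 = 0.7
print(avg_loss)
```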
```diff
@@ -43,6 +53,7 @@ def train(image_embed_dim,
           clip,
           dp_condition_on_text_encodings,
           dp_timesteps,
+          dp_l2norm_output,
           dp_cond_drop_prob,
           dpn_depth,
           dpn_dim_head,
@@ -52,14 +63,16 @@ def train(image_embed_dim,
           device,
           learning_rate=0.001,
           max_grad_norm=0.5,
-          weight_decay=0.01):
+          weight_decay=0.01,
+          amp=False):
 
     # DiffusionPriorNetwork
     prior_network = DiffusionPriorNetwork(
             dim = image_embed_dim,
             depth = dpn_depth,
             dim_head = dpn_dim_head,
-            heads = dpn_heads).to(device)
+            heads = dpn_heads,
+            l2norm_output = dp_l2norm_output).to(device)
 
     # DiffusionPrior with text embeddings and image embeddings pre-computed
     diffusion_prior = DiffusionPrior(
@@ -82,6 +95,7 @@ def train(image_embed_dim,
         os.makedirs(save_path)
 
     ### Training code ###
+    scaler = GradScaler(enabled=amp)
    optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
    epochs = num_epochs
 
```
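A convenient property being relied on here: a `GradScaler` constructed with `enabled=False` is a transparent no-op, so the same training loop runs whether or not `--amp` is set; `scale`, `unscale_`, `step`, and `update` all become pass-throughs. A quick check:

```python
import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled = False)   # amp off: the scaler is a transparent no-op
loss = torch.tensor(2.0, requires_grad = True)
assert scaler.scale(loss) is loss      # no scaling is applied when disabled
```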
```diff
@@ -98,23 +112,33 @@ def train(image_embed_dim,
                                        text_reader(batch_size=batch_size, start=0, end=train_set_size)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-            optimizer.zero_grad()
-            loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
-            loss.backward()
+
+            with autocast(enabled=amp):
+                loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
+                scaler.scale(loss).backward()
 
             # Samples per second
             step+=1
             samples_per_sec = batch_size*step/(time.time()-t)
             # Save checkpoint every save_interval minutes
             if(int(time.time()-t) >= 60*save_interval):
                 t = time.time()
-                save_model(save_path,diffusion_prior.state_dict())
+                save_model(
+                    save_path,
+                    dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
 
             # Log to wandb
             wandb.log({"Training loss": loss.item(),
                        "Steps": step,
                        "Samples per second": samples_per_sec})
 
+            scaler.unscale_(optimizer)
             nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
-            optimizer.step()
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
 
         ### Evaluate model(validation run) ###
         start = train_set_size
```
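Two things worth flagging in this hunk. First, the unchanged context line calls `nn.init.clip_grad_norm_`, which does not exist in PyTorch (gradient clipping lives in `torch.nn.utils`), so as written that line would raise `AttributeError` at runtime. Second, the checkpoint is now a dict bundling model, optimizer, and scaler state, so resuming must unpack each piece rather than loading the file straight into the model. A hedged resume sketch with toy stand-ins (the path and the objects being restored are placeholders):

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler

# toy stand-ins; in the script these are diffusion_prior, its optimizer, and the GradScaler
model = nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler(enabled = False)

# save in the same dict layout the diff introduces
state = dict(model = model.state_dict(), optimizer = optimizer.state_dict(), scaler = scaler.state_dict())
torch.save(state, '/tmp/prior_ckpt.pth')  # placeholder path

# resuming now means unpacking each piece, not loading the file directly into the model
checkpoint = torch.load('/tmp/prior_ckpt.pth', map_location = 'cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scaler.load_state_dict(checkpoint['scaler'])
```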
```diff
@@ -158,15 +182,19 @@ def main():
     # DiffusionPrior(dp) parameters
     parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
     parser.add_argument("--dp-timesteps", type=int, default=100)
+    parser.add_argument("--dp-l2norm-output", type=bool, default=False)
     parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
     parser.add_argument("--dp-loss-type", type=str, default="l2")
     parser.add_argument("--clip", type=str, default=None)
+    parser.add_argument("--amp", type=bool, default=False)
     # Model checkpointing interval(minutes)
     parser.add_argument("--save-interval", type=int, default=30)
     parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
 
     args = parser.parse_args()
 
     print("Setting up wandb logging... Please wait...")
 
     wandb.init(
         entity=args.wandb_entity,
         project=args.wandb_project,
```
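A caveat on the new flags (which also applies to the pre-existing `--dp-condition-on-text-encodings`): `argparse` with `type=bool` just calls `bool()` on the raw string, so any non-empty value, including `--amp False`, parses as `True`. A common workaround is an explicit converter; a sketch, where `str2bool` is a suggested helper rather than anything in the script:

```python
import argparse

def str2bool(v):
    # map common truthy/falsy strings explicitly instead of relying on bool()
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {v!r}')

parser = argparse.ArgumentParser()
parser.add_argument("--amp", type=str2bool, default=False)
print(parser.parse_args(["--amp", "False"]).amp)  # False, unlike type=bool which yields True
```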
```diff
@@ -176,6 +204,7 @@ def main():
             "dataset": args.wandb_dataset,
             "epochs": args.num_epochs,
         })
+
 
     print("wandb logging setup done!")
     # Obtain the utilized device.
@@ -197,6 +226,7 @@ def main():
           args.clip,
           args.dp_condition_on_text_encodings,
           args.dp_timesteps,
+          args.dp_l2norm_output,
           args.dp_cond_drop_prob,
           args.dpn_depth,
           args.dpn_dim_head,
@@ -206,7 +236,8 @@ def main():
           device,
           args.learning_rate,
           args.max_grad_norm,
-          args.weight_decay)
+          args.weight_decay,
+          args.amp)
 
 if __name__ == "__main__":
     main()
```