Compare commits

...

44 Commits

Author SHA1 Message Date
Phil Wang
4b912a38c6 0.26.2 2022-07-19 17:50:36 -07:00
Aidan Dempster
f97e55ec6b Quality of life improvements for tracker savers (#210)
The default save location is now none so if keys are not specified the
corresponding checkpoint type is not saved.

Models and checkpoints are now both saved with version number and the
config used to create them in order to simplify loading.

Documentation was fixed to be in line with current usage.
2022-07-19 17:50:18 -07:00
Phil Wang
291377bb9c @jacobwjs reports dynamic thresholding works very well and 0.95 is a better value 2022-07-19 11:31:56 -07:00
Phil Wang
7f120a8b56 cleanup, CLI no longer necessary since Zion + Aidan have https://github.com/LAION-AI/dalle2-laion and colab notebook going 2022-07-19 09:47:44 -07:00
Phil Wang
8c003ab1e1 readme and citation 2022-07-19 09:36:45 -07:00
Phil Wang
723bf0abba complete inpainting ability using inpaint_image and inpaint_mask passed into sample function for decoder 2022-07-19 09:26:55 -07:00
Phil Wang
d88c7ba56c fix a bug with ddim and predict x0 objective 2022-07-18 19:04:26 -07:00
Phil Wang
3676a8ce78 comments 2022-07-18 15:02:04 -07:00
Phil Wang
da8e99ada0 fix sample bug 2022-07-18 13:50:22 -07:00
Phil Wang
6afb886cf4 complete imagen-like noise level conditioning 2022-07-18 13:43:57 -07:00
Phil Wang
c7fe4f2f44 project management 2022-07-17 17:27:44 -07:00
Phil Wang
a2ee3fa3cc offer way to turn off initial cross embed convolutional module, for debugging upsampler artifacts 2022-07-15 17:29:10 -07:00
Phil Wang
a58a370d75 takes care of a grad strides error at https://github.com/lucidrains/DALLE2-pytorch/issues/196 thanks to @YUHANG-Ma 2022-07-14 15:28:34 -07:00
Phil Wang
1662bbf226 protect against random cropping for base unet 2022-07-14 12:49:43 -07:00
Phil Wang
5be1f57448 update 2022-07-14 12:03:42 -07:00
Phil Wang
c52ce58e10 update 2022-07-14 10:54:51 -07:00
Phil Wang
a34f60962a let the neural network peek at the low resolution conditioning one last time before making prediction, for upsamplers 2022-07-14 10:27:04 -07:00
Phil Wang
0b40cbaa54 just always use nearest neighbor interpolation when resizing for low resolution conditioning, for https://github.com/lucidrains/DALLE2-pytorch/pull/181 2022-07-13 20:59:43 -07:00
Phil Wang
f141144a6d allow for using classifier free guidance for some unets but not others, by passing in a tuple of cond_scale during sampling for decoder, just in case it is causing issues for upsamplers 2022-07-13 13:12:30 -07:00
Phil Wang
f988207718 hack around some inplace error, also make sure for openai clip text encoding, only tokens after eos_id is masked out 2022-07-13 12:56:02 -07:00
Phil Wang
b2073219f0 foolproof sampling for decoder to always use eval mode (and restore training state afterwards) 2022-07-13 10:21:00 -07:00
Phil Wang
cc0f7a935c fix non pixel shuffle upsample 2022-07-13 10:16:02 -07:00
Phil Wang
95a512cb65 fix a potential bug with conditioning with blurred low resolution image, blur should be applied only 50% of the time 2022-07-13 10:11:49 -07:00
Phil Wang
972ee973bc fix issue with ddim and normalization of lowres conditioning image 2022-07-13 09:48:40 -07:00
Phil Wang
79e2a3bc77 only use the stable layernorm for final output norm in transformer 2022-07-13 07:56:30 -07:00
Aidan Dempster
544cdd0b29 Reverted to using basic dataloaders (#205)
Accelerate removes the ability to collate strings. Likely since it
cannot gather strings.
2022-07-12 18:22:27 -07:00
Phil Wang
349aaca56f add yet another transformer stability measure 2022-07-12 17:49:16 -07:00
Phil Wang
3ee3c56d2a add learned padding tokens, same strategy as dalle1, for diffusion prior, and get rid of masking in causal transformer 2022-07-12 17:33:14 -07:00
Phil Wang
cd26c6b17d 0.22.3 2022-07-12 17:08:31 -07:00
Phil Wang
775abc4df6 add setting to attend to all text encodings regardless of padding, for diffusion prior 2022-07-12 17:08:12 -07:00
Phil Wang
11b1d533a0 make sure text encodings being passed in has the correct batch dimension 2022-07-12 16:00:19 -07:00
Phil Wang
e76e89f9eb remove text masking altogether in favor of deriving from text encodings (padded text encodings must be pad value of 0.) 2022-07-12 15:40:31 -07:00
Phil Wang
bb3ff0ac67 protect against bad text mask being passed into decoder 2022-07-12 15:33:13 -07:00
Phil Wang
1ec4dbe64f one more fix for text mask, if the length of the text encoding exceeds max_text_len, add an assert for better error msg 2022-07-12 15:01:46 -07:00
Phil Wang
e0835acca9 generate text mask within the unet and diffusion prior itself from the text encodings, if not given 2022-07-12 12:54:59 -07:00
Phil Wang
e055793e5d shoutout for @MalumaDev 2022-07-11 16:12:35 -07:00
Phil Wang
1d9ef99288 add PixelShuffleUpsample thanks to @MalumaDev and @marunine for running the experiment and verifyng absence of checkboard artifacts 2022-07-11 16:07:23 -07:00
Phil Wang
bdd62c24b3 zero init final projection in unet, since openai and @crowsonkb are both doing it 2022-07-11 13:22:06 -07:00
Phil Wang
1f1557c614 make it so even if text mask is omitted, it will be derived based on whether text encodings are all 0s or not, simplify dataloading 2022-07-11 10:56:19 -07:00
Aidan Dempster
1a217e99e3 Unet parameter count is now shown (#202) 2022-07-10 16:45:59 -07:00
Phil Wang
7ea314e2f0 allow for final l2norm clamping of the sampled image embed 2022-07-10 09:44:38 -07:00
Phil Wang
4173e88121 more accurate readme 2022-07-09 20:57:26 -07:00
Phil Wang
3dae43fa0e fix misnamed variable, thanks to @nousr 2022-07-09 19:01:37 -07:00
Phil Wang
a598820012 do not noise for the last step in ddim 2022-07-09 18:38:40 -07:00
12 changed files with 580 additions and 196 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
github: [lucidrains]
github: [nousr, Veldrovive, lucidrains]

115
README.md
View File

@@ -45,6 +45,7 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/malumadev">MalumaDev</a> for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -355,7 +356,8 @@ prior_network = DiffusionPriorNetwork(
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
timesteps = 1000,
sample_timesteps = 64,
cond_drop_prob = 0.2
).cuda()
@@ -419,7 +421,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
## Training on Preprocessed CLIP Embeddings
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`
Working example below
@@ -626,6 +628,82 @@ images = dalle2(
Now you'll just have to worry about training the Prior and the Decoder!
## Inpainting
Inpainting is also built into the `Decoder`. You simply have to pass in the `inpaint_image` and `inpaint_mask` (boolean tensor where `True` indicates which regions of the inpaint image to keep)
This repository uses the formulation put forth by <a href="https://arxiv.org/abs/2201.09865">Lugmayr et al. in Repaint</a>
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# 2 unets for the decoder (a la cascading DDPM)
unet = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 1, 1, 1)
).cuda()
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
unet = (unet,), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256,), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
loss = decoder(images, unet_number = 1)
loss.backward()
# do the above for many steps for both unets
mock_image_embed = torch.randn(1, 512).cuda()
# then to do inpainting
inpaint_image = torch.randn(1, 3, 256, 256).cuda() # (batch, channels, height, width)
inpaint_mask = torch.ones(1, 256, 256).bool().cuda() # (batch, height, width)
inpainted_images = decoder.sample(
image_embed = mock_image_embed,
inpaint_image = inpaint_image, # just pass in the inpaint image
inpaint_mask = inpaint_mask # and the mask
)
inpainted_images.shape # (1, 3, 256, 256)
```
## Experimental
### DALL-E2 with Latent Diffusion
@@ -989,26 +1067,12 @@ dataset = ImageEmbeddingDataset(
)
```
### Scripts (wip)
### Scripts
#### `train_diffusion_prior.py`
For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)
## CLI (wip)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked
<a href="https://github.com/lucidrains/big-sleep">template</a>
## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
@@ -1046,11 +1110,10 @@ Once built, images will be saved to the same directory the command is invoked
- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)
- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
- [x] allow for unet to be able to condition non-cross attention style as well
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [x] speed up inference, read up on papers (ddim)
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
## Citations
@@ -1168,4 +1231,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{Lugmayr2022RePaintIU,
title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
journal = {ArXiv},
year = {2022},
volume = {abs/2201.09865}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -74,9 +74,6 @@ Settings for controlling the training hyperparameters.
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
**<ins>Evaluate</ins>:**
@@ -163,9 +160,10 @@ All save locations have these configuration options
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
| `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. |
| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). |
If using `local`
| Option | Required | Default | Description |
@@ -177,7 +175,6 @@ If using `huggingface`
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `huggingface`. |
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
If using `wandb`

View File

@@ -56,9 +56,6 @@
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
@@ -96,14 +93,15 @@
},
"save": [{
"save_to": "wandb"
"save_to": "wandb",
"save_latest_to": "latest.pth"
}, {
"save_to": "huggingface",
"huggingface_repo": "Veldrovive/test_model",
"save_all": true,
"save_latest": true,
"save_best": true,
"save_latest_to": "path/to/model_dir/latest.pth",
"save_best_to": "path/to/model_dir/best.pth",
"save_meta_to": "path/to/directory/for/assorted/files",
"save_type": "model"
}]

View File

@@ -61,9 +61,6 @@
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
@@ -96,7 +93,8 @@
},
"save": [{
"save_to": "local"
"save_to": "local",
"save_latest_to": "latest.pth"
}]
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,13 +4,15 @@ import json
from pathlib import Path
import shutil
from itertools import zip_longest
from typing import Optional, List, Union
from typing import Any, Optional, List, Union
from pydantic import BaseModel
import torch
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.version import __version__
from packaging import version
# constants
@@ -21,16 +23,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
# load file functions
def load_wandb_file(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return file_reference.name
def load_local_file(file_path, **kwargs):
return file_path
class BaseLogger:
"""
An abstract class representing an object that can log data.
@@ -234,7 +226,7 @@ class LocalLoader(BaseLoader):
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be loaded
if not self.file_path.exists():
if not self.file_path.exists() and not self.only_auto_resume:
raise FileNotFoundError(f'Model not found at {self.file_path}')
def recall(self) -> dict:
@@ -283,9 +275,9 @@ def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
class BaseSaver:
def __init__(self,
data_path: str,
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
save_best_to: Optional[Union[str, bool]] = 'best.pth',
save_meta_to: str = './',
save_latest_to: Optional[Union[str, bool]] = None,
save_best_to: Optional[Union[str, bool]] = None,
save_meta_to: Optional[str] = None,
save_type: str = 'checkpoint',
**kwargs
):
@@ -295,10 +287,10 @@ class BaseSaver:
self.save_best_to = save_best_to
self.saving_best = save_best_to is not None and save_best_to is not False
self.save_meta_to = save_meta_to
self.saving_meta = save_meta_to is not None
self.save_type = save_type
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
@@ -459,6 +451,11 @@ class Tracker:
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
print(f"New logger config: {self.logger.__dict__}")
self.save_metadata = dict(
version = version.parse(__version__)
) # Data that will be saved alongside the checkpoint or model
self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
@@ -507,8 +504,15 @@ class Tracker:
# Save the config under config_name in the root folder of data_path
shutil.copy(current_config_path, self.data_path / config_name)
for saver in self.savers:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
if saver.saving_meta:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
def add_save_metadata(self, state_dict_key: str, metadata: Any):
"""
Adds a new piece of metadata that will be saved along with the model or decoder.
"""
self.save_metadata[state_dict_key] = metadata
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
"""
@@ -518,24 +522,34 @@ class Tracker:
"""
assert save_type in ['checkpoint', 'model']
if save_type == 'checkpoint':
trainer.save(file_path, overwrite=True, **kwargs)
# Create a metadata dict without the blacklisted keys so we do not error when we create the state dict
metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
trainer.save(file_path, overwrite=True, **kwargs, **metadata)
elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
state_dict = trainer.unwrap_model(prior).state_dict()
torch.save(state_dict, file_path)
prior: DiffusionPrior = trainer.unwrap_model(prior)
# Remove CLIP if it is part of the model
prior.clip = None
model_state_dict = prior.state_dict()
elif isinstance(trainer, DecoderTrainer):
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
# Remove CLIP if it is part of the model
decoder.clip = None
if trainer.use_ema:
trainable_unets = decoder.unets
decoder.unets = trainer.unets # Swap EMA unets in
state_dict = decoder.state_dict()
model_state_dict = decoder.state_dict()
decoder.unets = trainable_unets # Swap back
else:
state_dict = decoder.state_dict()
torch.save(state_dict, file_path)
model_state_dict = decoder.state_dict()
else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
state_dict = {
**self.save_metadata,
'model': model_state_dict
}
torch.save(state_dict, file_path)
return Path(file_path)
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):

View File

@@ -129,6 +129,7 @@ class AdapterConfig(BaseModel):
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
max_text_len: int = None
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
@@ -136,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_in: bool = False
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.
@@ -223,6 +225,7 @@ class UnetConfig(BaseModel):
self_attn: ListOrTuple(int)
attn_dim_head: int = 32
attn_heads: int = 16
init_cross_embed: bool = True
class Config:
extra = "allow"

View File

@@ -673,8 +673,14 @@ class DecoderTrainer(nn.Module):
def sample(self, *args, **kwargs):
distributed = self.accelerator.num_processes > 1
base_decoder = self.accelerator.unwrap_model(self.decoder)
was_training = base_decoder.training
base_decoder.eval()
if kwargs.pop('use_non_ema', False) or not self.use_ema:
return base_decoder.sample(*args, **kwargs, distributed = distributed)
out = base_decoder.sample(*args, **kwargs, distributed = distributed)
base_decoder.train(was_training)
return out
trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
@@ -687,6 +693,7 @@ class DecoderTrainer(nn.Module):
for ema in self.ema_unets:
ema.restore_ema_model_device()
base_decoder.train(was_training)
return output
@torch.no_grad()

View File

@@ -1 +1 @@
__version__ = '0.19.3'
__version__ = '0.26.2'

View File

@@ -323,7 +323,7 @@ def train(
last_snapshot = sample
if next_task == 'train':
for i, (img, emb, txt) in enumerate(trainer.train_loader):
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -358,6 +358,7 @@ def train(
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts
loss = trainer.forward(img, **forward_params, unet_number=unet)
trainer.update(unet_number=unet)
@@ -416,7 +417,7 @@ def train(
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
@@ -512,6 +513,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
}
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
return tracker
def initialize_training(config: TrainDecoderConfig, config_path):
@@ -557,7 +559,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Create the decoder model and print basic info
decoder = config.decoder.create()
num_parameters = sum(p.numel() for p in decoder.parameters())
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
@@ -586,7 +588,10 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(print_ribbon("Loaded Config", repeat=40))
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
accelerator.print(f"Number of parameters: {num_parameters}")
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
for i, unet in enumerate(decoder.unets):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,

View File

@@ -126,9 +126,9 @@ def report_cosine_sims(
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
text_embedding, text_encodings = trainer.embed_text(text_data)
text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
text_embed=text_embedding, text_encodings=text_encodings
)
else:
text_embedding = text_data
@@ -146,15 +146,12 @@ def report_cosine_sims(
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
text_mask_shuffled = text_mask[rolled_idx]
else:
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(
text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled,
mask=text_mask_shuffled,
text_encodings=text_encodings_shuffled
)
# prepare the text embedding