Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-12 11:34:29 +01:00)

Compare commits (30 commits)
| Author | SHA1 | Date |
|---|---|---|
| | e76e89f9eb | |
| | bb3ff0ac67 | |
| | 1ec4dbe64f | |
| | e0835acca9 | |
| | e055793e5d | |
| | 1d9ef99288 | |
| | bdd62c24b3 | |
| | 1f1557c614 | |
| | 1a217e99e3 | |
| | 7ea314e2f0 | |
| | 4173e88121 | |
| | 3dae43fa0e | |
| | a598820012 | |
| | 4878762627 | |
| | 47ae17b36e | |
| | b7e22f7da0 | |
| | 68de937aac | |
| | 097afda606 | |
| | 5c520db825 | |
| | 3070610231 | |
| | 870aeeca62 | |
| | f28dc6dc01 | |
| | 081d8d3484 | |
| | a71f693a26 | |
| | d7bc5fbedd | |
| | 8c823affff | |
| | ec7cab01d9 | |
| | 46be8c32d3 | |
| | 900f086a6d | |
| | b3e646fd3b | |
.github/workflows/ci.yml (vendored, new file, 33 lines)
@@ -0,0 +1,33 @@
name: Continuous integration

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install
      run: |
        python3 -m venv .env
        source .env/bin/activate
        make install
    - name: Tests
      run: |
        source .env/bin/activate
        make test
.gitignore (vendored, 2 changes)
@@ -136,3 +136,5 @@ dmypy.json

# Pyre type checker
.pyre/
.tracker_data
*.pth
Makefile (new file, 6 lines)
@@ -0,0 +1,6 @@
install:
	pip install -U pip
	pip install -e .

test:
	CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json
README.md (48 changes)
@@ -44,6 +44,8 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/malumadev">MalumaDev</a> for proposing the use of the pixel shuffle upsampler for fixing checkerboard artifacts
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -354,7 +356,8 @@ prior_network = DiffusionPriorNetwork(
diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    timesteps = 1000,
    sample_timesteps = 64,
    cond_drop_prob = 0.2
).cuda()

@@ -418,7 +421,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a

## Training on Preprocessed CLIP Embeddings

It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`

Working example below

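The README's working example is not part of this compare view; as a rough sketch of what training on precomputed embeddings amounts to after these changes (constructor arguments follow the snippets above, and `image_embed_dim` being accepted directly when no `clip` is attached is an assumption):

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,               # assumed: embedding dim given directly, no CLIP attached
    timesteps = 1000,
    sample_timesteps = 64,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = False  # only embeddings are passed in below
).cuda()

# stand-ins for embeddings you would precompute with CLIP offline
text_embed = torch.randn(4, 512).cuda()
image_embed = torch.randn(4, 512).cuda()

loss = diffusion_prior(
    text_embed = text_embed,
    image_embed = image_embed
)
loss.backward()
```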
@@ -581,7 +584,9 @@ unet1 = Unet(
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
    dim_mults=(1, 2, 4, 8),
    text_embed_dim = 512,
    cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()

unet2 = Unet(
@@ -596,14 +601,14 @@ decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 100,
    timesteps = 1000,
    sample_timesteps = (250, 27),
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
    text_cond_drop_prob = 0.5
).cuda()

for unet_number in (1, 2):
    loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss.backward()

# do above for many steps
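After training, sampling follows the README's existing pattern; a rough sketch (keyword names mirror the `Decoder.sample` signature shown later in this diff, and `image_embed` is assumed to come from the diffusion prior):

```python
# `decoder` as configured above, `image_embed` produced by the diffusion prior,
# optional raw `text` since the first unet conditions on text encodings
images = decoder.sample(image_embed = image_embed, text = text, cond_scale = 2.)

# because sample_timesteps = (250, 27) is below the 1000 training timesteps,
# each unet uses the DDIM sampling loop introduced in this change set
```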
@@ -990,34 +995,7 @@ dataset = ImageEmbeddingDataset(

#### `train_diffusion_prior.py`

This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.

#### Usage

```bash
$ python train_diffusion_prior.py
```

The most significant parameters for the script are as follows:

- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`

- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`

- `image-embed-dim`, default = `768` - 768 corresponds to the ViT-L/14 embedding size, change it to what your chosen ViT generates

- `learning-rate`, default = `1.1e-4`

- `weight-decay`, default = `6.02e-2`

- `max-grad-norm`, default = `0.5`

- `batch-size`, default = `10 ** 4`

- `num-epochs`, default = `5`

- `clip`, default = `None` # Signals the prior to use pre-computed embeddings

For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)

## CLI (wip)

@@ -30,6 +30,7 @@ Defines the configuration options for the decoder model. The unets defined above
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |

Any parameter from the `Decoder` constructor can also be given here.

@@ -39,7 +40,8 @@ Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
@@ -106,6 +108,13 @@ Tracking is split up into three sections:

**Logging:**

All loggers have the following keys:
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `log_type` | Yes | N/A | The type of logger class to use. |
| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using manually input parameters. |
| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |

If using `console` there is no further configuration than setting `log_type` to `console`.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
@@ -119,10 +128,15 @@ If using `wandb`
| `wandb_project` | Yes | N/A | The wandb project to save the run to. |
| `wandb_run_name` | No | `None` | The wandb run name. |
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
| `wandb_resume` | No | `False` | Whether to resume an old run. |

**Loading:**

All loaders have the following keys:
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | The type of loader class to use. |
| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |

If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |

@@ -20,7 +20,7 @@
    },
    "data": {
        "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
        "embeddings_url": "s3://bucket/embeddings/path/",
        "img_embeddings_url": "s3://bucket/img_embeddings/path/",
        "num_workers": 4,
        "batch_size": 64,
        "start_shard": 0,

configs/train_decoder_config.test.json (new file, 102 lines)
@@ -0,0 +1,102 @@
{
    "decoder": {
        "unets": [
            {
                "dim": 16,
                "image_embed_dim": 768,
                "cond_dim": 16,
                "channels": 3,
                "dim_mults": [1, 2, 4, 8],
                "attn_dim_head": 16,
                "attn_heads": 4,
                "self_attn": [false, true, true, true]
            }
        ],
        "clip": {
            "make": "openai",
            "model": "ViT-L/14"
        },

        "timesteps": 10,
        "image_sizes": [64],
        "channels": 3,
        "loss_type": "l2",
        "beta_schedule": ["cosine"],
        "learned_variance": true
    },
    "data": {
        "webdataset_base_url": "test_data/{}.tar",
        "num_workers": 4,
        "batch_size": 4,
        "start_shard": 0,
        "end_shard": 9,
        "shard_width": 1,
        "index_width": 1,
        "splits": {
            "train": 0.75,
            "val": 0.15,
            "test": 0.1
        },
        "shuffle_train": false,
        "resample_train": true,
        "preprocessing": {
            "RandomResizedCrop": {
                "size": [224, 224],
                "scale": [0.75, 1.0],
                "ratio": [1.0, 1.0]
            },
            "ToTensor": true
        }
    },
    "train": {
        "epochs": 1,
        "lr": 1e-16,
        "wd": 0.01,
        "max_grad_norm": 0.5,
        "save_every_n_samples": 100,
        "n_sample_images": 1,
        "device": "cpu",
        "epoch_samples": 50,
        "validation_samples": 5,
        "use_ema": true,
        "ema_beta": 0.99,
        "amp": false,
        "save_all": false,
        "save_latest": true,
        "save_best": true,
        "unet_training_mask": [true]
    },
    "evaluate": {
        "n_evaluation_samples": 2,
        "FID": {
            "feature": 64
        },
        "IS": {
            "feature": 64,
            "splits": 10
        },
        "KID": {
            "feature": 64,
            "subset_size": 2
        },
        "LPIPS": {
            "net_type": "vgg",
            "reduction": "mean"
        }
    },
    "tracker": {
        "overwrite_data_path": true,

        "log": {
            "log_type": "console"
        },

        "load": {
            "load_from": null
        },

        "save": [{
            "save_to": "local"
        }]
    }
}
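This is the configuration that the new `make test` target feeds to `train_decoder.py`. A small sketch of loading it for inspection (plain `json` here; the training script's own config handling may differ):

```python
import json

with open('configs/train_decoder_config.test.json') as f:
    config = json.load(f)

# the single tiny unet and the CPU device line up with the Makefile's
# CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file ... test target
print(config['decoder']['unets'][0]['dim'])  # 16
print(config['train']['device'])             # cpu
```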
@@ -77,6 +77,11 @@ def cast_tuple(val, length = None):
|
||||
def module_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
def zero_init_(m):
|
||||
nn.init.zeros_(m.weight)
|
||||
if exists(m.bias):
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
@contextmanager
|
||||
def null_context(*args, **kwargs):
|
||||
yield
|
||||
@@ -125,14 +130,28 @@ def log(t, eps = 1e-12):
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1)
|
||||
|
||||
def resize_image_to(image, target_image_size):
|
||||
def resize_image_to(
|
||||
image,
|
||||
target_image_size,
|
||||
clamp_range = None,
|
||||
nearest = False,
|
||||
**kwargs
|
||||
):
|
||||
orig_image_size = image.shape[-1]
|
||||
|
||||
if orig_image_size == target_image_size:
|
||||
return image
|
||||
|
||||
scale_factors = target_image_size / orig_image_size
|
||||
return resize(image, scale_factors = scale_factors)
|
||||
if not nearest:
|
||||
scale_factors = target_image_size / orig_image_size
|
||||
out = resize(image, scale_factors = scale_factors, **kwargs)
|
||||
else:
|
||||
out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False)
|
||||
|
||||
if exists(clamp_range):
|
||||
out = out.clamp(*clamp_range)
|
||||
|
||||
return out
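For orientation, a hedged usage sketch of the extended helper; the import path is an assumption, and the new `clamp_range` / `nearest` options are the ones the reworked `LowresConditioner` further down passes in:

```python
import torch
from dalle2_pytorch.dalle2_pytorch import resize_image_to  # import path is an assumption

images = torch.randn(2, 3, 64, 64).clamp(-1., 1.)

# default path: smooth resize, then clamp back into the ddpm image range of -1 to 1
upsampled = resize_image_to(images, 256, clamp_range = (-1., 1.))
print(upsampled.shape)  # torch.Size([2, 3, 256, 256])
```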
|
||||
|
||||
# image normalization functions
|
||||
# ddpms expect images to be in the range of -1 to 1
|
||||
@@ -146,7 +165,7 @@ def unnormalize_zero_to_one(normed_img):
|
||||
|
||||
# clip related adapters
|
||||
|
||||
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
|
||||
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings'])
|
||||
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
|
||||
|
||||
class BaseClipAdapter(nn.Module):
|
||||
@@ -155,6 +174,11 @@ class BaseClipAdapter(nn.Module):
|
||||
self.clip = clip
|
||||
self.overrides = kwargs
|
||||
|
||||
def validate_and_resize_image(self, image):
|
||||
image_size = image.shape[-1]
|
||||
assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'
|
||||
return resize_image_to(image, self.image_size)
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
raise NotImplementedError
|
||||
@@ -201,11 +225,12 @@ class XClipAdapter(BaseClipAdapter):
|
||||
encoder_output = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||
text_embed = self.clip.to_text_latent(text_cls)
|
||||
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
|
||||
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
|
||||
return EmbeddedText(l2norm(text_embed), text_encodings)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.validate_and_resize_image(image)
|
||||
encoder_output = self.clip.visual_transformer(image)
|
||||
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
@@ -236,11 +261,12 @@ class CoCaAdapter(BaseClipAdapter):
|
||||
text = text[..., :self.max_text_len]
|
||||
text_mask = text != 0
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
return EmbeddedText(text_embed, text_encodings, text_mask)
|
||||
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
|
||||
return EmbeddedText(text_embed, text_encodings)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.validate_and_resize_image(image)
|
||||
image_embed, image_encodings = self.clip.embed_image(image)
|
||||
return EmbeddedImage(image_embed, image_encodings)
|
||||
|
||||
@@ -295,13 +321,14 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
|
||||
text_embed = self.clip.encode_text(text)
|
||||
text_encodings = self.text_encodings
|
||||
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
|
||||
del self.text_encodings
|
||||
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
|
||||
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
assert not self.cleared
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.validate_and_resize_image(image)
|
||||
image = self.clip_normalize(image)
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||
@@ -486,6 +513,12 @@ class NoiseScheduler(nn.Module):
|
||||
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
||||
)
|
||||
|
||||
def predict_noise_from_start(self, x_t, t, x0):
|
||||
return (
|
||||
(x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \
|
||||
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
||||
)
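For reference, the standard DDPM identities this pair of helpers moves between (a reference derivation using the buffer names from `predict_start_from_noise` above, not a claim about the exact sign convention at this commit):

$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon \quad\Longrightarrow\quad x_0 = \frac{x_t}{\sqrt{\bar\alpha_t}} - \sqrt{\tfrac{1}{\bar\alpha_t}-1}\;\epsilon, \qquad \epsilon = \frac{x_t/\sqrt{\bar\alpha_t} \;-\; x_0}{\sqrt{1/\bar\alpha_t - 1}}$$

Here $1/\sqrt{\bar\alpha_t}$ is `sqrt_recip_alphas_cumprod` and $\sqrt{1/\bar\alpha_t - 1}$ is `sqrt_recipm1_alphas_cumprod`; the DDIM loops added below use this inversion to recover a predicted noise whenever the network predicts $x_0$ (the image embedding) directly.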
|
||||
|
||||
def p2_reweigh_loss(self, loss, times):
|
||||
if not self.has_p2_loss_reweighting:
|
||||
return loss
|
||||
@@ -838,8 +871,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
if not exists(text_encodings):
|
||||
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
|
||||
|
||||
if not exists(mask):
|
||||
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
|
||||
mask = torch.any(text_encodings != 0., dim = -1)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
@@ -856,9 +888,8 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
|
||||
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
|
||||
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.to_time_embeds(diffusion_timesteps)
|
||||
|
||||
@@ -892,19 +923,23 @@ class DiffusionPrior(nn.Module):
|
||||
image_size = None,
|
||||
image_channels = 3,
|
||||
timesteps = 1000,
|
||||
sample_timesteps = None,
|
||||
cond_drop_prob = 0.,
|
||||
loss_type = "l2",
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
sampling_clamp_l2norm = False,
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
|
||||
sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
|
||||
training_clamp_l2norm = False,
|
||||
init_image_embed_l2norm = False,
|
||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
clip_adapter_overrides = dict()
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_timesteps = sample_timesteps
|
||||
|
||||
self.noise_scheduler = NoiseScheduler(
|
||||
beta_schedule = beta_schedule,
|
||||
timesteps = timesteps,
|
||||
@@ -935,23 +970,32 @@ class DiffusionPrior(nn.Module):
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
|
||||
self.predict_x_start = predict_x_start
|
||||
|
||||
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
|
||||
self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
|
||||
|
||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||
|
||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||
self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
|
||||
|
||||
self.training_clamp_l2norm = training_clamp_l2norm
|
||||
self.init_image_embed_l2norm = init_image_embed_l2norm
|
||||
|
||||
# device tracker
|
||||
|
||||
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._dummy.device
|
||||
|
||||
def l2norm_clamp_embed(self, image_embed):
|
||||
return l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
@@ -959,8 +1003,6 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
if self.predict_x_start:
|
||||
x_recon = pred
|
||||
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
|
||||
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
|
||||
else:
|
||||
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
@@ -983,21 +1025,81 @@ class DiffusionPrior(nn.Module):
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
|
||||
device = self.device
|
||||
|
||||
b = shape[0]
|
||||
image_embed = torch.randn(shape, device=device)
|
||||
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
|
||||
batch, device = shape[0], self.device
|
||||
image_embed = torch.randn(shape, device = device)
|
||||
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
|
||||
times = torch.full((b,), i, device = device, dtype = torch.long)
|
||||
times = torch.full((batch,), i, device = device, dtype = torch.long)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
|
||||
|
||||
if self.sampling_final_clamp_l2norm and self.predict_x_start:
|
||||
image_embed = self.l2norm_clamp_embed(image_embed)
|
||||
|
||||
return image_embed
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
|
||||
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
|
||||
|
||||
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
|
||||
|
||||
times = list(reversed(times.int().tolist()))
|
||||
time_pairs = list(zip(times[:-1], times[1:]))
|
||||
|
||||
image_embed = torch.randn(shape, device = device)
|
||||
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
|
||||
alpha = alphas[time]
|
||||
alpha_next = alphas[time_next]
|
||||
|
||||
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
||||
|
||||
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_start = pred
|
||||
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
|
||||
else:
|
||||
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
|
||||
pred_noise = pred
|
||||
|
||||
if not self.predict_x_start:
|
||||
x_start.clamp_(-1., 1.)
|
||||
|
||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||
x_start = self.l2norm_clamp_embed(x_start)
|
||||
|
||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
|
||||
|
||||
image_embed = x_start * alpha_next.sqrt() + \
|
||||
c1 * noise + \
|
||||
c2 * pred_noise
|
||||
|
||||
if self.predict_x_start and self.sampling_final_clamp_l2norm:
|
||||
image_embed = self.l2norm_clamp_embed(image_embed)
|
||||
|
||||
return image_embed
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, *args, timesteps = None, **kwargs):
|
||||
timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
|
||||
assert timesteps <= self.noise_scheduler.num_timesteps
|
||||
is_ddim = timesteps < self.noise_scheduler.num_timesteps
|
||||
|
||||
if not is_ddim:
|
||||
return self.p_sample_loop_ddpm(*args, **kwargs)
|
||||
|
||||
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
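Written out, the update both new DDIM loops implement (this one for the prior, and the analogous `p_sample_loop_ddim` added to `Decoder` further down) is the standard DDIM step, with $\hat x_0$ the predicted start, $\hat\epsilon$ the predicted noise, and $z \sim \mathcal{N}(0, I)$ drawn only while $t_{\text{next}} > 0$:

$$x_{t_{\text{next}}} = \sqrt{\bar\alpha_{t_{\text{next}}}}\;\hat x_0 + c_1 z + c_2\,\hat\epsilon, \qquad c_1 = \eta\sqrt{\frac{\bigl(1-\bar\alpha_t/\bar\alpha_{t_{\text{next}}}\bigr)\bigl(1-\bar\alpha_{t_{\text{next}}}\bigr)}{1-\bar\alpha_t}}, \qquad c_2 = \sqrt{1-\bar\alpha_{t_{\text{next}}}-c_1^2}$$

Setting $\eta = 0$ removes the injected noise and makes sampling deterministic, which is what the `ddim_sampling_eta` comment added to `Decoder` below alludes to.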
|
||||
|
||||
def p_losses(self, image_embed, times, text_cond, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||
|
||||
@@ -1011,7 +1113,7 @@ class DiffusionPrior(nn.Module):
|
||||
)
|
||||
|
||||
if self.predict_x_start and self.training_clamp_l2norm:
|
||||
pred = l2norm(pred) * self.image_embed_scale
|
||||
pred = self.l2norm_clamp_embed(pred)
|
||||
|
||||
target = noise if not self.predict_x_start else image_embed
|
||||
|
||||
@@ -1032,7 +1134,15 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
|
||||
def sample(
|
||||
self,
|
||||
text,
|
||||
num_samples_per_batch = 2,
|
||||
cond_scale = 1.,
|
||||
timesteps = None
|
||||
):
|
||||
timesteps = default(timesteps, self.sample_timesteps)
|
||||
|
||||
# in the paper, what they did was
|
||||
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
|
||||
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
|
||||
@@ -1040,14 +1150,14 @@ class DiffusionPrior(nn.Module):
|
||||
batch_size = text.shape[0]
|
||||
image_embed_dim = self.image_embed_dim
|
||||
|
||||
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
if self.condition_on_text_encodings:
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings}
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
|
||||
|
||||
# retrieve original unscaled image embed
|
||||
|
||||
@@ -1073,7 +1183,6 @@ class DiffusionPrior(nn.Module):
|
||||
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
|
||||
image_embed = None,
|
||||
text_encodings = None, # as well as CLIP text encodings
|
||||
text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
@@ -1087,13 +1196,13 @@ class DiffusionPrior(nn.Module):
|
||||
# calculate text conditionings, based on what is passed in
|
||||
|
||||
if exists(text):
|
||||
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
|
||||
text_cond = dict(text_embed = text_embed)
|
||||
|
||||
if self.condition_on_text_encodings:
|
||||
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings}
|
||||
|
||||
# timestep conditioning from ddpm
|
||||
|
||||
@@ -1110,16 +1219,35 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
# decoder
|
||||
|
||||
def ConvTransposeUpsample(dim, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
|
||||
class PixelShuffleUpsample(nn.Module):
|
||||
"""
|
||||
code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts
|
||||
https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
|
||||
"""
|
||||
def __init__(self, dim, dim_out = None):
|
||||
super().__init__()
|
||||
dim_out = default(dim_out, dim)
|
||||
conv = nn.Conv2d(dim, dim_out * 4, 1)
|
||||
|
||||
def NearestUpsample(dim, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.Sequential(
|
||||
nn.Upsample(scale_factor = 2, mode = 'nearest'),
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||
)
|
||||
self.net = nn.Sequential(
|
||||
conv,
|
||||
nn.SiLU(),
|
||||
nn.PixelShuffle(2)
|
||||
)
|
||||
|
||||
self.init_conv_(conv)
|
||||
|
||||
def init_conv_(self, conv):
|
||||
o, i, h, w = conv.weight.shape
|
||||
conv_weight = torch.empty(o // 4, i, h, w)
|
||||
nn.init.kaiming_uniform_(conv_weight)
|
||||
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
|
||||
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
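The `init_conv_` above repeats each 1x1 filter four times along the output channels, so at initialization the conv / SiLU / pixel-shuffle stack produces the same value across every 2x2 block, i.e. a nearest-neighbour upsample, which is the usual remedy for checkerboard artifacts. A small self-contained check of that property (illustration only, not library code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

# with the 1x1 kernel repeated 4x along the output channels, conv followed by
# PixelShuffle(2) reproduces a nearest-neighbour upsample of the 1x1 conv output
# (the library also places a SiLU in between; being elementwise it preserves this)
dim = 8
weight = torch.randn(dim, dim, 1, 1)

conv = nn.Conv2d(dim, dim * 4, 1, bias = False)
conv.weight.data.copy_(repeat(weight, 'o ... -> (o 4) ...'))

x = torch.randn(1, dim, 5, 5)
shuffled = nn.PixelShuffle(2)(conv(x))
nearest = F.interpolate(F.conv2d(x, weight), scale_factor = 2, mode = 'nearest')

assert torch.allclose(shuffled, nearest, atol = 1e-6)
```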
|
||||
|
||||
def Downsample(dim, *, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
@@ -1383,7 +1511,7 @@ class Unet(nn.Module):
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
memory_efficient = False,
|
||||
scale_skip_connection = False,
|
||||
nearest_upsample = False,
|
||||
pixel_shuffle_upsample = True,
|
||||
final_conv_kernel_size = 1,
|
||||
**kwargs
|
||||
):
|
||||
@@ -1449,10 +1577,12 @@ class Unet(nn.Module):
|
||||
# text encoding conditioning (optional)
|
||||
|
||||
self.text_to_cond = None
|
||||
self.text_embed_dim = None
|
||||
|
||||
if cond_on_text_encodings:
|
||||
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
|
||||
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
|
||||
self.text_embed_dim = text_embed_dim
|
||||
|
||||
# finer control over whether to condition on image embeddings and text encodings
|
||||
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||
@@ -1495,7 +1625,7 @@ class Unet(nn.Module):
|
||||
|
||||
# upsample klass
|
||||
|
||||
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
|
||||
upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
|
||||
|
||||
# give memory efficient unet an initial resnet block
|
||||
|
||||
@@ -1559,6 +1689,8 @@ class Unet(nn.Module):
|
||||
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
||||
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
||||
|
||||
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
def cast_model_parameters(
|
||||
@@ -1609,7 +1741,6 @@ class Unet(nn.Module):
|
||||
image_embed,
|
||||
lowres_cond_img = None,
|
||||
text_encodings = None,
|
||||
text_mask = None,
|
||||
image_cond_drop_prob = 0.,
|
||||
text_cond_drop_prob = 0.,
|
||||
blur_sigma = None,
|
||||
@@ -1681,21 +1812,26 @@ class Unet(nn.Module):
|
||||
text_tokens = None
|
||||
|
||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
|
||||
|
||||
text_mask = torch.any(text_encodings != 0., dim = -1)
|
||||
|
||||
text_tokens = self.text_to_cond(text_encodings)
|
||||
|
||||
text_tokens = text_tokens[:, :self.max_text_len]
|
||||
text_mask = text_mask[:, :self.max_text_len]
|
||||
|
||||
text_tokens_len = text_tokens.shape[1]
|
||||
remainder = self.max_text_len - text_tokens_len
|
||||
|
||||
if remainder > 0:
|
||||
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
|
||||
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
||||
|
||||
if exists(text_mask):
|
||||
if remainder > 0:
|
||||
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
||||
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
||||
|
||||
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
||||
text_keep_mask = text_mask & text_keep_mask
|
||||
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'
|
||||
text_keep_mask = text_mask & text_keep_mask
|
||||
|
||||
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
|
||||
|
||||
@@ -1776,11 +1912,17 @@ class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
downsample_first = True,
|
||||
downsample_mode_nearest = False,
|
||||
blur_sigma = 0.6,
|
||||
blur_kernel_size = 3,
|
||||
input_image_range = None
|
||||
):
|
||||
super().__init__()
|
||||
self.downsample_first = downsample_first
|
||||
self.downsample_mode_nearest = downsample_mode_nearest
|
||||
|
||||
self.input_image_range = input_image_range
|
||||
|
||||
self.blur_sigma = blur_sigma
|
||||
self.blur_kernel_size = blur_kernel_size
|
||||
|
||||
@@ -1794,7 +1936,7 @@ class LowresConditioner(nn.Module):
|
||||
blur_kernel_size = None
|
||||
):
|
||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||
cond_fmap = resize_image_to(cond_fmap, downsample_image_size)
|
||||
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
|
||||
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
@@ -1814,7 +1956,7 @@ class LowresConditioner(nn.Module):
|
||||
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
|
||||
|
||||
return cond_fmap
|
||||
|
||||
@@ -1828,6 +1970,7 @@ class Decoder(nn.Module):
|
||||
channels = 3,
|
||||
vae = tuple(),
|
||||
timesteps = 1000,
|
||||
sample_timesteps = None,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
loss_type = 'l2',
|
||||
@@ -1837,6 +1980,7 @@ class Decoder(nn.Module):
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
|
||||
blur_sigma = 0.6, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
clip_denoised = True,
|
||||
@@ -1850,7 +1994,8 @@ class Decoder(nn.Module):
|
||||
use_dynamic_thres = False, # from the Imagen paper
|
||||
dynamic_thres_percentile = 0.9,
|
||||
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
|
||||
p2_loss_weight_k = 1
|
||||
p2_loss_weight_k = 1,
|
||||
ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -1930,9 +2075,10 @@ class Decoder(nn.Module):
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval())
|
||||
|
||||
# determine from unets whether conditioning on text encoding is needed
|
||||
# sampling timesteps, defaults to non-ddim with full timesteps sampling
|
||||
|
||||
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
|
||||
self.ddim_sampling_eta = ddim_sampling_eta
|
||||
|
||||
# create noise schedulers per unet
|
||||
|
||||
@@ -1944,7 +2090,9 @@ class Decoder(nn.Module):
|
||||
|
||||
self.noise_schedulers = nn.ModuleList([])
|
||||
|
||||
for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
|
||||
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
|
||||
assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
|
||||
|
||||
noise_scheduler = NoiseScheduler(
|
||||
beta_schedule = unet_beta_schedule,
|
||||
timesteps = timesteps,
|
||||
@@ -1972,6 +2120,10 @@ class Decoder(nn.Module):
|
||||
|
||||
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||
|
||||
# input image range
|
||||
|
||||
self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
|
||||
|
||||
# cascading ddpm related stuff
|
||||
|
||||
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
||||
@@ -1979,8 +2131,10 @@ class Decoder(nn.Module):
|
||||
|
||||
self.to_lowres_cond = LowresConditioner(
|
||||
downsample_first = lowres_downsample_first,
|
||||
downsample_mode_nearest = lowres_downsample_mode_nearest,
|
||||
blur_sigma = blur_sigma,
|
||||
blur_kernel_size = blur_kernel_size,
|
||||
input_image_range = self.input_image_range
|
||||
)
|
||||
|
||||
# classifier free guidance
|
||||
@@ -2012,6 +2166,10 @@ class Decoder(nn.Module):
|
||||
def device(self):
|
||||
return self._dummy.device
|
||||
|
||||
@property
|
||||
def condition_on_text_encodings(self):
|
||||
return any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
index = unet_number - 1
|
||||
@@ -2035,10 +2193,30 @@ class Decoder(nn.Module):
|
||||
for unet, device in zip(self.unets, devices):
|
||||
unet.to(device)
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
||||
def dynamic_threshold(self, x):
|
||||
""" proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
|
||||
|
||||
# s is the threshold amount
|
||||
# static thresholding would just be s = 1
|
||||
s = 1.
|
||||
if self.use_dynamic_thres:
|
||||
s = torch.quantile(
|
||||
rearrange(x, 'b ... -> b (...)').abs(),
|
||||
self.dynamic_thres_percentile,
|
||||
dim = -1
|
||||
)
|
||||
|
||||
s.clamp_(min = 1.)
|
||||
s = s.view(-1, *((1,) * (x.ndim - 1)))
|
||||
|
||||
# clip by threshold, depending on whether static or dynamic
|
||||
x = x.clamp(-s, s) / s
|
||||
return x
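In formula form, with $p$ = `dynamic_thres_percentile` (and plain static thresholding recovered when `use_dynamic_thres` is off, i.e. $s = 1$):

$$s = \max\bigl(1,\; \operatorname{quantile}_p\lvert x_0\rvert\bigr), \qquad x_0 \leftarrow \operatorname{clamp}(x_0, -s, s)\,/\,s$$

so heavily saturated predictions are rescaled into $[-1, 1]$ rather than hard-clipped, per the Imagen paper linked in the docstring.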
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
|
||||
|
||||
if learned_variance:
|
||||
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
||||
@@ -2049,21 +2227,7 @@ class Decoder(nn.Module):
|
||||
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised:
|
||||
# s is the threshold amount
|
||||
# static thresholding would just be s = 1
|
||||
s = 1.
|
||||
if self.use_dynamic_thres:
|
||||
s = torch.quantile(
|
||||
rearrange(x_recon, 'b ... -> b (...)').abs(),
|
||||
self.dynamic_thres_percentile,
|
||||
dim = -1
|
||||
)
|
||||
|
||||
s.clamp_(min = 1.)
|
||||
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
|
||||
|
||||
# clip by threshold, depending on whether static or dynamic
|
||||
x_recon = x_recon.clamp(-s, s) / s
|
||||
x_recon = self.dynamic_threshold(x_recon)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
|
||||
@@ -2084,16 +2248,16 @@ class Decoder(nn.Module):
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
|
||||
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
|
||||
noise = torch.randn_like(x)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
|
||||
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
|
||||
device = self.device
|
||||
|
||||
b = shape[0]
|
||||
@@ -2109,7 +2273,6 @@ class Decoder(nn.Module):
|
||||
torch.full((b,), i, device = device, dtype = torch.long),
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start,
|
||||
@@ -2121,7 +2284,63 @@ class Decoder(nn.Module):
|
||||
unnormalize_img = self.unnormalize_img(img)
|
||||
return unnormalize_img
|
||||
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
|
||||
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
|
||||
|
||||
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
|
||||
|
||||
times = list(reversed(times.int().tolist()))
|
||||
time_pairs = list(zip(times[:-1], times[1:]))
|
||||
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
|
||||
alpha = alphas[time]
|
||||
alpha_next = alphas[time_next]
|
||||
|
||||
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
||||
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
if learned_variance:
|
||||
pred, _ = pred.chunk(2, dim = 1)
|
||||
|
||||
if predict_x_start:
|
||||
x_start = pred
|
||||
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
|
||||
else:
|
||||
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
|
||||
pred_noise = pred
|
||||
|
||||
if clip_denoised:
|
||||
x_start = self.dynamic_threshold(x_start)
|
||||
|
||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||
noise = torch.randn_like(img) if time_next > 0 else 0.
|
||||
|
||||
img = x_start * alpha_next.sqrt() + \
|
||||
c1 * noise + \
|
||||
c2 * pred_noise
|
||||
|
||||
img = self.unnormalize_img(img)
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
|
||||
num_timesteps = noise_scheduler.num_timesteps
|
||||
|
||||
timesteps = default(timesteps, num_timesteps)
|
||||
assert timesteps <= num_timesteps
|
||||
is_ddim = timesteps < num_timesteps
|
||||
|
||||
if not is_ddim:
|
||||
return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
|
||||
|
||||
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
|
||||
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
# normalize to [-1, 1]
|
||||
@@ -2139,7 +2358,6 @@ class Decoder(nn.Module):
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
text_mask = text_mask,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
image_cond_drop_prob = self.image_cond_drop_prob,
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
@@ -2199,7 +2417,6 @@ class Decoder(nn.Module):
|
||||
self,
|
||||
image_embed = None,
|
||||
text = None,
|
||||
text_mask = None,
|
||||
text_encodings = None,
|
||||
batch_size = 1,
|
||||
cond_scale = 1.,
|
||||
@@ -2213,7 +2430,7 @@ class Decoder(nn.Module):
|
||||
|
||||
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
||||
assert exists(self.clip)
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
_, text_encodings = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
@@ -2221,7 +2438,7 @@ class Decoder(nn.Module):
|
||||
img = None
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
|
||||
|
||||
@@ -2243,14 +2460,14 @@ class Decoder(nn.Module):
|
||||
shape,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
predict_x_start = predict_x_start,
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = not is_latent_diffusion,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
is_latent_diffusion = is_latent_diffusion,
|
||||
noise_scheduler = noise_scheduler
|
||||
noise_scheduler = noise_scheduler,
|
||||
timesteps = sample_timesteps
|
||||
)
|
||||
|
||||
img = vae.decode(img)
|
||||
@@ -2266,7 +2483,6 @@ class Decoder(nn.Module):
|
||||
text = None,
|
||||
image_embed = None,
|
||||
text_encodings = None,
|
||||
text_mask = None,
|
||||
unet_number = None,
|
||||
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
|
||||
):
|
||||
@@ -2295,7 +2511,7 @@ class Decoder(nn.Module):
|
||||
|
||||
if exists(text) and not exists(text_encodings) and not self.unconditional:
|
||||
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
_, text_encodings = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
@@ -2318,7 +2534,7 @@ class Decoder(nn.Module):
|
||||
image = vae.encode(image)
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||
|
||||
if not return_lowres_cond_image:
|
||||
return losses
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import webdataset as wds
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
import fsspec
|
||||
import shutil
|
||||
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
|
||||
)
|
||||
if shuffle_num is not None and shuffle_num > 0:
|
||||
ds.shuffle(1000)
|
||||
return wds.WebLoader(
|
||||
return DataLoader(
|
||||
ds,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import urllib.request
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from itertools import zip_longest
|
||||
@@ -37,14 +38,17 @@ class BaseLogger:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
verbose (bool): Whether or not to always print logs to the console.
|
||||
"""
|
||||
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
|
||||
def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
self.resume = resume
|
||||
self.auto_resume = auto_resume
|
||||
self.verbose = verbose
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
"""
|
||||
Initializes the logger.
|
||||
Errors if the logger is invalid.
|
||||
full_config is the config file dict while extra_config is anything else from the script that is not defined in the config file.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -60,6 +64,14 @@ class BaseLogger:
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_resume_data(self, **kwargs) -> dict:
|
||||
"""
|
||||
Sets tracker attributes that along with { "resume": True } will be used to resume training.
|
||||
It is assumed that after init is called this data will be complete.
|
||||
If the logger does not have any resume functionality, it should return an empty dict.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
class ConsoleLogger(BaseLogger):
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
print("Logging to console")
|
||||
@@ -76,6 +88,9 @@ class ConsoleLogger(BaseLogger):
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
print(error_string)
|
||||
|
||||
def get_resume_data(self, **kwargs) -> dict:
|
||||
return {}
|
||||
|
||||
class WandbLogger(BaseLogger):
|
||||
"""
|
||||
Logs to a wandb run.
|
||||
@@ -85,7 +100,6 @@ class WandbLogger(BaseLogger):
|
||||
wandb_project (str): The wandb project to log to.
|
||||
wandb_run_id (str): The wandb run id to resume.
|
||||
wandb_run_name (str): The wandb run name to use.
|
||||
wandb_resume (bool): Whether to resume a wandb run.
|
||||
"""
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
@@ -93,7 +107,6 @@ class WandbLogger(BaseLogger):
|
||||
wandb_project: str,
|
||||
wandb_run_id: Optional[str] = None,
|
||||
wandb_run_name: Optional[str] = None,
|
||||
wandb_resume: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(data_path, **kwargs)
|
||||
@@ -101,7 +114,6 @@ class WandbLogger(BaseLogger):
|
||||
self.project = wandb_project
|
||||
self.run_id = wandb_run_id
|
||||
self.run_name = wandb_run_name
|
||||
self.resume = wandb_resume
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
|
||||
@@ -149,6 +161,14 @@ class WandbLogger(BaseLogger):
|
||||
print(error_string)
|
||||
self.wandb.log({"error": error_string, **kwargs}, step=step)
|
||||
|
||||
def get_resume_data(self, **kwargs) -> dict:
|
||||
# In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
|
||||
return {
|
||||
"entity": self.entity,
|
||||
"project": self.project,
|
||||
"run_id": self.wandb.run.id
|
||||
}
|
||||
|
||||
logger_type_map = {
|
||||
'console': ConsoleLogger,
|
||||
'wandb': WandbLogger,
|
||||
@@ -168,8 +188,9 @@ class BaseLoader:
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
"""
|
||||
def __init__(self, data_path: str, **kwargs):
|
||||
def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
self.only_auto_resume = only_auto_resume
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
@@ -304,6 +325,10 @@ class LocalSaver(BaseSaver):
|
||||
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
|
||||
# Copy the file to save_path
|
||||
save_path_file_name = Path(save_path).name
|
||||
# Make sure parent directory exists
|
||||
save_path_parent = Path(save_path).parent
|
||||
if not save_path_parent.exists():
|
||||
save_path_parent.mkdir(parents=True)
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
|
||||
shutil.copy(local_path, save_path)
|
||||
|
||||
@@ -385,11 +410,7 @@ class Tracker:
|
||||
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
|
||||
self.data_path = Path(data_path)
|
||||
if not dummy_mode:
|
||||
if overwrite_data_path:
|
||||
if self.data_path.exists():
|
||||
shutil.rmtree(self.data_path)
|
||||
self.data_path.mkdir(parents=True)
|
||||
else:
|
||||
if not overwrite_data_path:
|
||||
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
|
||||
if not self.data_path.exists():
|
||||
self.data_path.mkdir(parents=True)
|
||||
@@ -398,7 +419,46 @@ class Tracker:
|
||||
self.savers: List[BaseSaver]= []
|
||||
self.dummy_mode = dummy_mode
|
||||
|
||||
def _load_auto_resume(self) -> bool:
|
||||
# If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
|
||||
if not self.auto_resume_path.exists():
|
||||
if self.logger.auto_resume:
|
||||
print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
|
||||
return False
|
||||
|
||||
# Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
|
||||
if not self.logger.auto_resume:
|
||||
print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
|
||||
self.auto_resume_path.unlink()
|
||||
return False
|
||||
|
||||
# Otherwise we read the json into a dictionary that will override parts of logger.__dict__
|
||||
with open(self.auto_resume_path, 'r') as f:
|
||||
auto_resume_dict = json.load(f)
|
||||
# Check if the logger is of the same type as the autoresume save
|
||||
if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
|
||||
raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
|
||||
# Then we are ready to override the logger with the autoresume save
|
||||
self.logger.__dict__["resume"] = True
|
||||
print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
|
||||
self.logger.__dict__.update(auto_resume_dict)
|
||||
return True
|
||||
|
||||
def _save_auto_resume(self):
|
||||
# Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
|
||||
auto_resume_dict = self.logger.get_resume_data()
|
||||
auto_resume_dict['logger_type'] = self.logger.__class__.__name__
|
||||
with open(self.auto_resume_path, 'w') as f:
|
||||
json.dump(auto_resume_dict, f)
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict):
|
||||
self.auto_resume_path = self.data_path / 'auto_resume.json'
|
||||
# Check for resuming the run
|
||||
self.did_auto_resume = self._load_auto_resume()
|
||||
if self.did_auto_resume:
|
||||
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
|
||||
print(f"New logger config: {self.logger.__dict__}")
|
||||
|
||||
assert self.logger is not None, '`logger` must be set before `init` is called'
|
||||
if self.dummy_mode:
|
||||
# The only thing we need is a loader
|
||||
@@ -406,12 +466,17 @@ class Tracker:
|
||||
self.loader.init(self.logger)
|
||||
return
|
||||
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
|
||||
|
||||
self.logger.init(full_config, extra_config)
|
||||
if self.loader is not None:
|
||||
self.loader.init(self.logger)
|
||||
for saver in self.savers:
|
||||
saver.init(self.logger)
|
||||
|
||||
if self.logger.auto_resume:
|
||||
# Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
|
||||
self._save_auto_resume()
|
||||
|
||||
def add_logger(self, logger: BaseLogger):
|
||||
self.logger = logger
|
||||
|
||||
@@ -503,11 +568,16 @@ class Tracker:
|
||||
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
|
||||
print(f'Error saving checkpoint: {e}')
|
||||
|
||||
@property
|
||||
def can_recall(self):
|
||||
# Defines whether a recall can be performed.
|
||||
return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
|
||||
|
||||
def recall(self):
|
||||
if self.loader is not None:
|
||||
if self.can_recall:
|
||||
return self.loader.recall()
|
||||
else:
|
||||
raise ValueError('No loader specified')
|
||||
raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
|
||||
|
||||
|
||||
|
||||
@@ -47,6 +47,8 @@ class TrainSplitConfig(BaseModel):
|
||||
|
||||
class TrackerLogConfig(BaseModel):
|
||||
log_type: str = 'console'
|
||||
resume: bool = False # For logs that are saved to unique locations, resume a previous run
|
||||
auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed
|
||||
verbose: bool = False
|
||||
|
||||
class Config:
|
||||
@@ -59,6 +61,7 @@ class TrackerLogConfig(BaseModel):
|
||||
|
||||
class TrackerLoadConfig(BaseModel):
|
||||
load_from: Optional[str] = None
|
||||
only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
@@ -151,6 +154,7 @@ class DiffusionPriorConfig(BaseModel):
|
||||
image_size: int
|
||||
image_channels: int = 3
|
||||
timesteps: int = 1000
|
||||
sample_timesteps: Optional[int] = None
|
||||
cond_drop_prob: float = 0.
|
||||
loss_type: str = 'l2'
|
||||
predict_x_start: bool = True
|
||||
@@ -230,6 +234,7 @@ class DecoderConfig(BaseModel):
|
||||
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
sample_timesteps: Optional[SingularOrIterable(int)] = None
|
||||
loss_type: str = 'l2'
|
||||
beta_schedule: ListOrTuple(str) = 'cosine'
|
||||
learned_variance: bool = True
|
||||
|
||||
@@ -21,7 +21,7 @@ import pytorch_warmup as warmup
|
||||
|
||||
from ema_pytorch import EMA
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
|
||||
def inner(model, *args, **kwargs):
|
||||
device = kwargs.pop('_device', next(model.parameters()).device)
|
||||
cast_device = kwargs.pop('_cast_device', True)
|
||||
cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
|
||||
|
||||
kwargs_keys = kwargs.keys()
|
||||
all_args = (*args, *kwargs.values())
|
||||
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
|
||||
if cast_device:
|
||||
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
|
||||
if cast_deepspeed_precision:
|
||||
try:
|
||||
accelerator = model.accelerator
|
||||
if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
cast_type_map = {
|
||||
"fp16": torch.half,
|
||||
"bf16": torch.bfloat16,
|
||||
"no": torch.float
|
||||
}
|
||||
precision_type = cast_type_map[accelerator.mixed_precision]
|
||||
all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||
except AttributeError:
|
||||
# Then this model doesn't have an accelerator
|
||||
pass
|
||||
|
||||
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
|
||||
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
|
||||
|
||||
@@ -192,6 +208,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
self.device = diffusion_prior_device
|
||||
else:
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
diffusion_prior.to(self.device)
|
||||
|
||||
# save model
|
||||
|
||||
@@ -445,6 +462,7 @@ class DecoderTrainer(nn.Module):
|
||||
self,
|
||||
decoder,
|
||||
accelerator = None,
|
||||
dataloaders = None,
|
||||
use_ema = True,
|
||||
lr = 1e-4,
|
||||
wd = 1e-2,
|
||||
@@ -507,11 +525,31 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
|
||||
|
||||
if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
|
||||
# Then we need to make sure clip is using the correct precision or else deepspeed will error
|
||||
cast_type_map = {
|
||||
"fp16": torch.half,
|
||||
"bf16": torch.bfloat16,
|
||||
"no": torch.float
|
||||
}
|
||||
precision_type = cast_type_map[accelerator.mixed_precision]
|
||||
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
|
||||
clip = decoder.clip
|
||||
clip.to(precision_type)
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
schedulers = list(self.accelerator.prepare(*schedulers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
# prepare dataloaders
|
||||
|
||||
train_loader = val_loader = None
|
||||
if exists(dataloaders):
|
||||
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
|
||||
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
|
||||
# store optimizers
|
||||
|
||||
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
|
||||
@@ -526,6 +564,17 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.warmup_schedulers = warmup_schedulers
|
||||
|
||||
def validate_and_return_unet_number(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
|
||||
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
|
||||
return unet_number
|
||||
|
||||
def num_steps_taken(self, unet_number = None):
|
||||
unet_number = self.validate_and_return_unet_number(unet_number)
|
||||
return self.steps[unet_number - 1].item()
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
@@ -594,10 +643,7 @@ class DecoderTrainer(nn.Module):
|
||||
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
|
||||
|
||||
def update(self, unet_number = None):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
|
||||
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
|
||||
unet_number = self.validate_and_return_unet_number(unet_number)
|
||||
index = unet_number - 1
|
||||
|
||||
optimizer = getattr(self, f'optim{index}')
|
||||
@@ -663,11 +709,13 @@ class DecoderTrainer(nn.Module):
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
if self.num_unets == 1:
|
||||
unet_number = default(unet_number, 1)
|
||||
unet_number = self.validate_and_return_unet_number(unet_number)
|
||||
|
||||
total_loss = 0.
|
||||
|
||||
|
||||
using_amp = self.accelerator.mixed_precision != 'no'
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with self.accelerator.autocast():
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.16.14'
|
||||
__version__ = '0.22.0'
|
||||
|
||||
183 prior.md Normal file
@@ -0,0 +1,183 @@
# Diffusion Prior
This readme serves as an introduction to the diffusion prior.

## Intro

A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected in some way, then the ability to translate between them could be extremely helpful.

### Motivation

Before we dive into the model, let's look at a quick example of where the model may be helpful.

For demonstration purposes, we will imagine that we wish to generate images from text using CLIP and a Decoder.

> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and caption. However, there is no guarantee that these embeddings are in the same space. While the embeddings generated are ***close***, the image and text embeddings occupy two disjoint sets.

```python
# Load Models
clip_model = clip.load("ViT-L/14")
decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings

# Retrieve prompt from user and encode with CLIP
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = clip_model.encode_text(tokenized_text)

# Now, pass the text embedding to the decoder
predicted_image = decoder.sample(text_embedding)
```

> **Question**: *Can you spot the issue here?*
>
> **Answer**: *We’re trying to generate an image from a text embedding!*

Unfortunately, we run into the issue previously mentioned: the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.

```python
# Load Models
prior = Prior(checkpoint="prior.pth") # A prior trained to go from: text -> clip text emb -> clip img emb
decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings

# Retrieve prompt from user and encode with a prior
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!

# Now, pass the predicted image embedding to the decoder
predicted_image = decoder.sample(text_embedding)
```

With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.

> **You may be asking yourself the following question:**
>
> *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"*
>
> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also...it's just an example :smile:

## Usage

Utilizing a pre-trained prior is quite simple.

### Loading Checkpoints
```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path):
    # pick a device for the trainer
    device = "cuda" if torch.cuda.is_available() else "cpu"

    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )

    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
```

Here we instantiate a model that matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*).

### Sampling
Once we have a pre-trained model, generating embeddings is quite simple!
```python
# tokenize the text
tokenized_text = clip.tokenize("<your amazing prompt>")
# predict an embedding
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
```

The resulting tensor returned from `.sample()` is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape `(1, 768)`.

> For CLIP priors, this is quite handy as it means that you can use `prior.sample(tokenized_text)` as a drop-in replacement for `clip.encode_text()`.

**Some things to note:**
* It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is `n=2`). Put simply, the idea here is that you avoid getting unlucky with a bad embedding generation by creating two and selecting the one with the higher cosine similarity to the prompt (a sketch of this re-ranking idea follows below).
* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0`, but *ymmv*.

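The snippet below is a minimal sketch of that best-of-`n` re-ranking idea, which is roughly what `n_samples_per_batch` already does for you internally. It assumes a loaded `prior` and a CLIP model as in the earlier examples; `sample_best_of_n` is an illustrative helper, not part of this repository's API.

```python
import torch
import torch.nn.functional as F

def sample_best_of_n(prior, clip_model, tokenized_text, n=2, cond_scale=1.0):
    # Sample several candidate image embeddings and keep the one whose cosine
    # similarity to the CLIP text embedding of the prompt is highest.
    with torch.no_grad():
        text_embedding = clip_model.encode_text(tokenized_text)          # (1, 768) for ViT-L/14
        candidates = torch.cat([
            prior.sample(tokenized_text, cond_scale=cond_scale)
            for _ in range(n)
        ], dim=0)                                                         # (n, 768)
        sims = F.cosine_similarity(candidates, text_embedding, dim=-1)    # (n,)
    return candidates[sims.argmax()].unsqueeze(0)                         # (1, 768)
```
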
---

## Training

### Overview

Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step that is required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended, as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.

## Dataset

To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2dataset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) for generating the actual embeddings that can be used in the prior's dataloader.

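As a rough illustration (not the repository's actual dataloader), the precomputed embeddings end up as paired arrays of CLIP text and image embeddings that can be batched and fed to the prior. The shard paths below are hypothetical, `diffusion_prior` is assumed to be instantiated as in the Loading Checkpoints example, and the sketch assumes the prior accepts precomputed `text_embed` / `image_embed` keyword arguments.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical shard paths; real layouts come from img2dataset / clip-retrieval output.
image_embs = np.load("embeddings/img_emb_0000.npy")   # (N, 768) CLIP image embeddings
text_embs = np.load("embeddings/text_emb_0000.npy")   # (N, 768) CLIP text embeddings

dataset = TensorDataset(torch.from_numpy(text_embs).float(), torch.from_numpy(image_embs).float())
loader = DataLoader(dataset, batch_size=256, shuffle=True)

optimizer = torch.optim.AdamW(diffusion_prior.parameters(), lr=1.1e-4)

for text_embed, image_embed in loader:
    # the prior learns to map a text embedding to its paired image embedding
    loss = diffusion_prior(text_embed=text_embed, image_embed=image_embed)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
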
## Configuration

The configuration file allows you to easily track and reproduce experiments. It is a simple JSON file that specifies the architecture, dataset, and training parameters. For more information and specifics, please see the configuration README.

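For instance, a quick way to inspect a config before launching (the path below is hypothetical; use your own config file):

```python
import json

with open("configs/my_prior_config.json") as f:  # hypothetical path
    config = json.load(f)

# print the top-level sections so you can verify the architecture, dataset, and training parameters
for section, values in config.items():
    print(section, "->", list(values) if isinstance(values, dict) else values)
```
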
## Distributed Training

If you would like to train in a distributed manner, we have opted to leverage huggingface's Accelerate library. Accelerate makes it extremely simple to distribute work across multiple GPUs and nodes. All that is required of you is to follow the simple CLI configuration tool ([more information here](https://huggingface.co/docs/accelerate/accelerator)).

## Evaluation

There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:

| Metric | Description | Comments |
| --- | --- | --- |
| Online Model Validation | The validation loss associated with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 billion samples seen. |
| EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform it in the long term. |
| Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and associated image embeddings. This will serve as a guide for your prior's performance in cosine similarity. | Generally `0.3` is considered a good cosine similarity for caption similarity. |
| Similarity With Original Image | This metric measures the cosine similarity between your prior's predicted image embedding and the actual image that the caption was associated with. This is useful for determining whether your prior is generating embeddings with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity, then you are likely suffering from some kind of training error or inefficiency (i.e. not using EMA). |
| Difference From Baseline Similarity | Sometimes it's useful to visualize a metric in another light. This metric shows you how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0` with some room for variation. After a billion samples seen, values are within `0.01` +/- of `0.0`. If this climbs too high (~>`0.02`), it may be a sign that your model is overfitting somehow. |
| Similarity With Text | This metric is your bread-and-butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be one of your main focuses and it is probably the second most important metric behind your loss. | As mentioned, this value should be close to the baseline similarity. We have observed an early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity, it could be an indication that your model is overfitting. |
| Similarity With Unrelated Caption | This metric attempts to expose an overfit prior by feeding it arbitrary prompts (from your dataset) and then measuring the similarity of the predicted embedding with some other image. | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between a caption and a completely unrelated image was high. With this in mind, a low value is ideal; anything below `0.1` is probably safe. |

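As a rough sketch of how several of these cosine-similarity metrics could be computed from a validation batch (the function below is illustrative, not the training script's implementation; it assumes batched CLIP embeddings and a batch of predictions from the prior):

```python
import torch.nn.functional as F

def cosine_similarity_report(text_embed, image_embed, predicted_image_embed):
    # Baseline similarity: how close the dataset's text and image embeddings already are.
    baseline = F.cosine_similarity(text_embed, image_embed, dim=-1).mean()
    # Similarity with the original image: does the prediction land near the real image embedding?
    sim_image = F.cosine_similarity(predicted_image_embed, image_embed, dim=-1).mean()
    # Similarity with text: how close the prediction is to the conditioning text embedding.
    sim_text = F.cosine_similarity(predicted_image_embed, text_embed, dim=-1).mean()
    return {
        "baseline_similarity": baseline.item(),
        "similarity_with_original_image": sim_image.item(),
        "similarity_with_text": sim_text.item(),
        "difference_from_baseline_similarity": (sim_text - baseline).item(),
    }
```
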
## Launching the script

Now that you've done all the prep, it's time for the easy part! 🚀

To actually launch the script, you will either use `accelerate launch train_diffusion_prior.py --config_path <path to your config>` to launch with distributed training & huggingface accelerate, or `python train_diffusion_prior.py` if you would like to train on your gpu/cpu without huggingface accelerate.

## Checkpointing

Checkpoints will be saved to the directory specified in your configuration file.

Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled "latest.pth". This is to avoid problems where your `save_every` configuration does not overlap with the number of steps required to do a complete pass through the data.

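For example, assuming the `load_diffusion_model` helper from the Loading Checkpoints section above and a hypothetical checkpoint directory, resuming sampling from that final checkpoint looks like:

```python
# "checkpoints/" is a hypothetical directory; use the directory from your configuration file.
prior = load_diffusion_model("checkpoints/latest.pth")

# sample exactly as in the Sampling section above (the `clip` module is assumed to be imported)
tokenized_text = clip.tokenize("A corgi wearing sunglasses")
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
```
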
## Things To Keep In Mind

The prior has not been trained for tasks other than the traditional CLIP embedding translation…at least not yet.

As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.

With that in mind, you are more or less a pioneer in embedding-translation if you are reading this and attempting something you don't see documentation for!

BIN test_data/0.tar Normal file (binary file not shown)
BIN test_data/1.tar Normal file (binary file not shown)
BIN test_data/2.tar Normal file (binary file not shown)
BIN test_data/3.tar Normal file (binary file not shown)
BIN test_data/4.tar Normal file (binary file not shown)
BIN test_data/5.tar Normal file (binary file not shown)
BIN test_data/6.tar Normal file (binary file not shown)
BIN test_data/7.tar Normal file (binary file not shown)
BIN test_data/8.tar Normal file (binary file not shown)
BIN test_data/9.tar Normal file (binary file not shown)
@@ -132,7 +132,7 @@ def get_example_data(dataloader, device, n=5):
|
||||
break
|
||||
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
|
||||
|
||||
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
|
||||
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
|
||||
"""
|
||||
Takes example data and generates images from the embeddings
|
||||
Returns three lists: real images, generated images, and captions
|
||||
@@ -160,6 +160,9 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
|
||||
samples = trainer.sample(**sample_params)
|
||||
generated_images = list(samples)
|
||||
captions = [text_prepend + txt for txt in txts]
|
||||
if match_image_size:
|
||||
generated_image_size = generated_images[0].shape[-1]
|
||||
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
|
||||
return real_images, generated_images, captions
|
||||
|
||||
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
|
||||
@@ -167,14 +170,6 @@ def generate_grid_samples(trainer, examples, condition_on_text_encodings=False,
|
||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||
"""
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
|
||||
|
||||
real_image_size = real_images[0].shape[-1]
|
||||
generated_image_size = generated_images[0].shape[-1]
|
||||
|
||||
# training images may be larger than the generated one
|
||||
if real_image_size > generated_image_size:
|
||||
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
|
||||
|
||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||
return grid_images, captions
|
||||
|
||||
@@ -279,6 +274,7 @@ def train(
|
||||
trainer = DecoderTrainer(
|
||||
decoder=decoder,
|
||||
accelerator=accelerator,
|
||||
dataloaders=dataloaders,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -289,9 +285,8 @@ def train(
|
||||
sample = 0
|
||||
samples_seen = 0
|
||||
val_sample = 0
|
||||
step = lambda: int(trainer.step.item())
|
||||
|
||||
if tracker.loader is not None:
|
||||
if tracker.can_recall:
|
||||
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
|
||||
if next_task == 'train':
|
||||
sample = recalled_sample
|
||||
@@ -304,6 +299,8 @@ def train(
|
||||
if not exists(unet_training_mask):
|
||||
# Then the unet mask should be true for all unets in the decoder
|
||||
unet_training_mask = [True] * trainer.num_unets
|
||||
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
|
||||
step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
|
||||
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
||||
|
||||
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
|
||||
@@ -326,7 +323,7 @@ def train(
|
||||
last_snapshot = sample
|
||||
|
||||
if next_task == 'train':
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
|
||||
for i, (img, emb, txt) in enumerate(trainer.train_loader):
|
||||
# We want to count the total number of samples across all processes
|
||||
sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||
@@ -419,7 +416,7 @@ def train(
|
||||
timer = Timer()
|
||||
accelerator.wait_for_everyone()
|
||||
i = 0
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
||||
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
|
||||
val_sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||
total_samples = all_samples.sum().item()
|
||||
@@ -524,6 +521,20 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
# Set up accelerator for configurable distributed training
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
if accelerator.num_processes > 1:
|
||||
# We are using distributed training and want to immediately ensure all can connect
|
||||
accelerator.print("Waiting for all processes to connect...")
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print("All processes online and connected")
|
||||
|
||||
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
|
||||
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
|
||||
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
|
||||
|
||||
if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
|
||||
# This is an invalid configuration until we figure out how to handle this
|
||||
raise ValueError("DeepSpeed does not support multi-node distributed training")
|
||||
|
||||
# Set up data
|
||||
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
|
||||
@@ -546,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
|
||||
# Create the decoder model and print basic info
|
||||
decoder = config.decoder.create()
|
||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
|
||||
|
||||
# Create and initialize the tracker if we are the master
|
||||
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
|
||||
@@ -575,7 +586,10 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
accelerator.print(print_ribbon("Loaded Config", repeat=40))
|
||||
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
|
||||
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
|
||||
accelerator.print(f"Number of parameters: {num_parameters}")
|
||||
accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
|
||||
for i, unet in enumerate(decoder.unets):
|
||||
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
|
||||
|
||||
train(dataloaders, decoder, accelerator,
|
||||
tracker=tracker,
|
||||
inference_device=accelerator.device,
|
||||
|
||||
@@ -126,9 +126,9 @@ def report_cosine_sims(
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
|
||||
text_embedding, text_encodings = trainer.embed_text(text_data)
|
||||
text_cond = dict(
|
||||
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
|
||||
text_embed=text_embedding, text_encodings=text_encodings
|
||||
)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
@@ -146,15 +146,12 @@ def report_cosine_sims(
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
text_mask_shuffled = text_mask[rolled_idx]
|
||||
else:
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(
|
||||
text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled,
|
||||
mask=text_mask_shuffled,
|
||||
text_encodings=text_encodings_shuffled
|
||||
)
|
||||
|
||||
# prepare the text embedding
|
||||
|
||||
Reference in New Issue
Block a user