Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-12 11:34:29 +01:00)

Compare commits (22 commits)
Commit SHAs:

1d9ef99288
bdd62c24b3
1f1557c614
1a217e99e3
7ea314e2f0
4173e88121
3dae43fa0e
a598820012
4878762627
47ae17b36e
b7e22f7da0
68de937aac
097afda606
5c520db825
3070610231
870aeeca62
f28dc6dc01
081d8d3484
a71f693a26
d7bc5fbedd
8c823affff
ec7cab01d9
.github/workflows/ci.yml (vendored, new file, 33 lines)
@@ -0,0 +1,33 @@
name: Continuous integration

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install
        run: |
          python3 -m venv .env
          source .env/bin/activate
          make install
      - name: Tests
        run: |
          source .env/bin/activate
          make test
.gitignore (vendored, 2 additions)
@@ -136,3 +136,5 @@ dmypy.json

# Pyre type checker
.pyre/
+.tracker_data
+*.pth
Makefile (new file, 6 lines)
@@ -0,0 +1,6 @@
install:
    pip install -U pip
    pip install -e .

test:
    CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json
@@ -355,7 +355,8 @@ prior_network = DiffusionPriorNetwork(
diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
-   timesteps = 100,
+   timesteps = 1000,
+   sample_timesteps = 64,
    cond_drop_prob = 0.2
).cuda()
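With `sample_timesteps` set below the training `timesteps`, the prior now samples with DDIM instead of full DDPM (see the `p_sample_loop` dispatch later in this diff). A minimal usage sketch, assuming the `diffusion_prior` constructed above; the `timesteps` keyword on `sample` is added in this same changeset, and the token values are hypothetical:

```python
import torch

text = torch.randint(0, 49408, (4, 256)).cuda()  # hypothetical tokenized text

# defaults to the configured sample_timesteps = 64; pass timesteps to override
image_embed = diffusion_prior.sample(text, cond_scale = 1., timesteps = 32)
```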
@@ -583,6 +584,7 @@ unet1 = Unet(
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8),
+   text_embed_dim = 512,
    cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()
@@ -598,7 +600,8 @@ decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
-   timesteps = 100,
+   timesteps = 1000,
+   sample_timesteps = (250, 27),
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()
@@ -30,6 +30,7 @@ Defines the configuration options for the decoder model. The unets defined above
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
+| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |

Any parameter from the `Decoder` constructor can also be given here.
@@ -39,7 +40,8 @@ Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
-| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
+| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
+| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
@@ -106,6 +108,13 @@ Tracking is split up into three sections:

**Logging:**

+All loggers have the following keys:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | The type of logger class to use. |
+| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using manually input parameters. |
+| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
+
If using `console` there is no further configuration than setting `log_type` to `console`.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |

@@ -119,10 +128,15 @@ If using `wandb`
| `wandb_project` | Yes | N/A | The wandb project to save the run to. |
| `wandb_run_name` | No | `None` | The wandb run name. |
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
-| `wandb_resume` | No | `False` | Whether to resume an old run. |

**Loading:**

+All loaders have the following keys:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | The type of loader class to use. |
+| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
+
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
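Putting the logging and loading options together: a sketch of a tracker section using `wandb` with auto-resume, written as the Python dict the JSON config parses to (field names from the tables above; the entity and project values are hypothetical):

```python
tracker_config = {
    "log": {
        "log_type": "wandb",
        "wandb_entity": "my-entity",        # hypothetical; required for the wandb logger
        "wandb_project": "dalle2-decoder",  # hypothetical
        "auto_resume": True                 # reattach to a crashed run using its saved parameters
    },
    "load": {
        "load_from": "local",
        "only_auto_resume": True            # only load a checkpoint when auto-resuming
    }
}
```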
@@ -20,7 +20,7 @@
    },
    "data": {
        "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
-       "embeddings_url": "s3://bucket/embeddings/path/",
+       "img_embeddings_url": "s3://bucket/img_embeddings/path/",
        "num_workers": 4,
        "batch_size": 64,
        "start_shard": 0,
configs/train_decoder_config.test.json (new file, 102 lines)
@@ -0,0 +1,102 @@
{
    "decoder": {
        "unets": [
            {
                "dim": 16,
                "image_embed_dim": 768,
                "cond_dim": 16,
                "channels": 3,
                "dim_mults": [1, 2, 4, 8],
                "attn_dim_head": 16,
                "attn_heads": 4,
                "self_attn": [false, true, true, true]
            }
        ],
        "clip": {
            "make": "openai",
            "model": "ViT-L/14"
        },

        "timesteps": 10,
        "image_sizes": [64],
        "channels": 3,
        "loss_type": "l2",
        "beta_schedule": ["cosine"],
        "learned_variance": true
    },
    "data": {
        "webdataset_base_url": "test_data/{}.tar",
        "num_workers": 4,
        "batch_size": 4,
        "start_shard": 0,
        "end_shard": 9,
        "shard_width": 1,
        "index_width": 1,
        "splits": {
            "train": 0.75,
            "val": 0.15,
            "test": 0.1
        },
        "shuffle_train": false,
        "resample_train": true,
        "preprocessing": {
            "RandomResizedCrop": {
                "size": [224, 224],
                "scale": [0.75, 1.0],
                "ratio": [1.0, 1.0]
            },
            "ToTensor": true
        }
    },
    "train": {
        "epochs": 1,
        "lr": 1e-16,
        "wd": 0.01,
        "max_grad_norm": 0.5,
        "save_every_n_samples": 100,
        "n_sample_images": 1,
        "device": "cpu",
        "epoch_samples": 50,
        "validation_samples": 5,
        "use_ema": true,
        "ema_beta": 0.99,
        "amp": false,
        "save_all": false,
        "save_latest": true,
        "save_best": true,
        "unet_training_mask": [true]
    },
    "evaluate": {
        "n_evaluation_samples": 2,
        "FID": {
            "feature": 64
        },
        "IS": {
            "feature": 64,
            "splits": 10
        },
        "KID": {
            "feature": 64,
            "subset_size": 2
        },
        "LPIPS": {
            "net_type": "vgg",
            "reduction": "mean"
        }
    },
    "tracker": {
        "overwrite_data_path": true,

        "log": {
            "log_type": "console"
        },

        "load": {
            "load_from": null
        },

        "save": [{
            "save_to": "local"
        }]
    }
}
@@ -77,6 +77,11 @@ def cast_tuple(val, length = None):
def module_device(module):
    return next(module.parameters()).device

+def zero_init_(m):
+    nn.init.zeros_(m.weight)
+    if exists(m.bias):
+        nn.init.zeros_(m.bias)
+
@contextmanager
def null_context(*args, **kwargs):
    yield
@@ -125,14 +130,23 @@ def log(t, eps = 1e-12):
def l2norm(t):
    return F.normalize(t, dim = -1)

-def resize_image_to(image, target_image_size, clamp_range = None):
+def resize_image_to(
+    image,
+    target_image_size,
+    clamp_range = None,
+    nearest = False,
+    **kwargs
+):
    orig_image_size = image.shape[-1]

    if orig_image_size == target_image_size:
        return image

-   scale_factors = target_image_size / orig_image_size
-   out = resize(image, scale_factors = scale_factors)
+   if not nearest:
+       scale_factors = target_image_size / orig_image_size
+       out = resize(image, scale_factors = scale_factors, **kwargs)
+   else:
+       out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False)

    if exists(clamp_range):
        out = out.clamp(*clamp_range)
@@ -160,6 +174,11 @@ class BaseClipAdapter(nn.Module):
        self.clip = clip
        self.overrides = kwargs

+   def validate_and_resize_image(self, image):
+       image_size = image.shape[-1]
+       assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'
+       return resize_image_to(image, self.image_size)
+
    @property
    def dim_latent(self):
        raise NotImplementedError
@@ -206,11 +225,12 @@ class XClipAdapter(BaseClipAdapter):
        encoder_output = self.clip.text_transformer(text)
        text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
        text_embed = self.clip.to_text_latent(text_cls)
+       text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)

    @torch.no_grad()
    def embed_image(self, image):
-       image = resize_image_to(image, self.image_size)
+       image = self.validate_and_resize_image(image)
        encoder_output = self.clip.visual_transformer(image)
        image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
        image_embed = self.clip.to_visual_latent(image_cls)
@@ -241,11 +261,12 @@ class CoCaAdapter(BaseClipAdapter):
        text = text[..., :self.max_text_len]
        text_mask = text != 0
        text_embed, text_encodings = self.clip.embed_text(text)
+       text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        return EmbeddedText(text_embed, text_encodings, text_mask)

    @torch.no_grad()
    def embed_image(self, image):
-       image = resize_image_to(image, self.image_size)
+       image = self.validate_and_resize_image(image)
        image_embed, image_encodings = self.clip.embed_image(image)
        return EmbeddedImage(image_embed, image_encodings)
@@ -300,13 +321,14 @@ class OpenAIClipAdapter(BaseClipAdapter):

        text_embed = self.clip.encode_text(text)
        text_encodings = self.text_encodings
+       text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        del self.text_encodings
        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)

    @torch.no_grad()
    def embed_image(self, image):
        assert not self.cleared
-       image = resize_image_to(image, self.image_size)
+       image = self.validate_and_resize_image(image)
        image = self.clip_normalize(image)
        image_embed = self.clip.encode_image(image)
        return EmbeddedImage(l2norm(image_embed.float()), None)
@@ -491,6 +513,12 @@ class NoiseScheduler(nn.Module):
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

+   def predict_noise_from_start(self, x_t, t, x0):
+       return (
+           (x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \
+           extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+       )
+
    def p2_reweigh_loss(self, loss, times):
        if not self.has_p2_loss_reweighting:
            return loss
@@ -897,19 +925,23 @@ class DiffusionPrior(nn.Module):
        image_size = None,
        image_channels = 3,
        timesteps = 1000,
+       sample_timesteps = None,
        cond_drop_prob = 0.,
        loss_type = "l2",
        predict_x_start = True,
        beta_schedule = "cosine",
        condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
-       sampling_clamp_l2norm = False,
+       sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
+       sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
+       training_clamp_l2norm = False,
+       init_image_embed_l2norm = False,
        image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
        clip_adapter_overrides = dict()
    ):
        super().__init__()

+       self.sample_timesteps = sample_timesteps
+
        self.noise_scheduler = NoiseScheduler(
            beta_schedule = beta_schedule,
            timesteps = timesteps,
@@ -940,23 +972,32 @@ class DiffusionPrior(nn.Module):
        self.condition_on_text_encodings = condition_on_text_encodings

        # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

        self.predict_x_start = predict_x_start

        # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132

        self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)

        # whether to force an l2norm, similar to clipping denoised, when sampling

        self.sampling_clamp_l2norm = sampling_clamp_l2norm
+       self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
+
+       self.training_clamp_l2norm = training_clamp_l2norm
+       self.init_image_embed_l2norm = init_image_embed_l2norm

        # device tracker

        self.register_buffer('_dummy', torch.tensor([True]), persistent = False)

    @property
    def device(self):
        return self._dummy.device

+   def l2norm_clamp_embed(self, image_embed):
+       return l2norm(image_embed) * self.image_embed_scale
+
    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -964,8 +1005,6 @@ class DiffusionPrior(nn.Module):

        if self.predict_x_start:
            x_recon = pred
-           # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
-           # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
        else:
            x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -988,21 +1027,81 @@ class DiffusionPrior(nn.Module):
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
-   def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
-       device = self.device
-
-       b = shape[0]
-       image_embed = torch.randn(shape, device=device)
+   def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
+       batch, device = shape[0], self.device
+       image_embed = torch.randn(shape, device = device)

+       if self.init_image_embed_l2norm:
+           image_embed = l2norm(image_embed) * self.image_embed_scale
+
        for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
-           times = torch.full((b,), i, device = device, dtype = torch.long)
+           times = torch.full((batch,), i, device = device, dtype = torch.long)
            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

+       if self.sampling_final_clamp_l2norm and self.predict_x_start:
+           image_embed = self.l2norm_clamp_embed(image_embed)
+
        return image_embed

+   @torch.no_grad()
+   def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
+       batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
+
+       times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+
+       times = list(reversed(times.int().tolist()))
+       time_pairs = list(zip(times[:-1], times[1:]))
+
+       image_embed = torch.randn(shape, device = device)
+
+       if self.init_image_embed_l2norm:
+           image_embed = l2norm(image_embed) * self.image_embed_scale
+
+       for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+           alpha = alphas[time]
+           alpha_next = alphas[time_next]
+
+           time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+           pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+
+           if self.predict_x_start:
+               x_start = pred
+               pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
+           else:
+               x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
+               pred_noise = pred
+
+           if not self.predict_x_start:
+               x_start.clamp_(-1., 1.)
+
+           if self.predict_x_start and self.sampling_clamp_l2norm:
+               x_start = self.l2norm_clamp_embed(x_start)
+
+           c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+           c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+           noise = torch.randn_like(image_embed) if time_next > 0 else 0.
+
+           image_embed = x_start * alpha_next.sqrt() + \
+               c1 * noise + \
+               c2 * pred_noise
+
+           if self.predict_x_start and self.sampling_final_clamp_l2norm:
+               image_embed = self.l2norm_clamp_embed(image_embed)
+
+       return image_embed
+
+   @torch.no_grad()
+   def p_sample_loop(self, *args, timesteps = None, **kwargs):
+       timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
+       assert timesteps <= self.noise_scheduler.num_timesteps
+       is_ddim = timesteps < self.noise_scheduler.num_timesteps
+
+       if not is_ddim:
+           return self.p_sample_loop_ddpm(*args, **kwargs)
+
+       return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+
    def p_losses(self, image_embed, times, text_cond, noise = None):
        noise = default(noise, lambda: torch.randn_like(image_embed))
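A compact restatement of the DDIM update used in `p_sample_loop_ddim` above (a sketch for readers following the math; `eta = 1.` keeps DDPM-level stochasticity, `eta = 0.` makes the step deterministic, and the loop zeroes the noise on the final step):

```python
import torch

def ddim_step(x_start, pred_noise, alpha, alpha_next, eta = 1., last_step = False):
    # coefficients exactly as computed in the sampling loop above
    c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
    c2 = ((1 - alpha_next) - c1 ** 2).sqrt()
    noise = 0. if last_step else torch.randn_like(x_start)
    return x_start * alpha_next.sqrt() + c1 * noise + c2 * pred_noise
```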
@@ -1016,7 +1115,7 @@ class DiffusionPrior(nn.Module):
        )

        if self.predict_x_start and self.training_clamp_l2norm:
-           pred = l2norm(pred) * self.image_embed_scale
+           pred = self.l2norm_clamp_embed(pred)

        target = noise if not self.predict_x_start else image_embed
@@ -1037,7 +1136,15 @@ class DiffusionPrior(nn.Module):

    @torch.no_grad()
    @eval_decorator
-   def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
+   def sample(
+       self,
+       text,
+       num_samples_per_batch = 2,
+       cond_scale = 1.,
+       timesteps = None
+   ):
+       timesteps = default(timesteps, self.sample_timesteps)
+
        # in the paper, what they did was
        # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
        text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)

@@ -1052,7 +1159,7 @@ class DiffusionPrior(nn.Module):
        if self.condition_on_text_encodings:
            text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

-       image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
+       image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)

        # retrieve original unscaled image embed
@@ -1098,6 +1205,7 @@ class DiffusionPrior(nn.Module):

        if self.condition_on_text_encodings:
            assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
+           text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
            text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

        # timestep conditioning from ddpm
@@ -1115,16 +1223,35 @@

# decoder

def ConvTransposeUpsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)
    return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)

-def NearestUpsample(dim, dim_out = None):
-    dim_out = default(dim_out, dim)
-    return nn.Sequential(
-        nn.Upsample(scale_factor = 2, mode = 'nearest'),
-        nn.Conv2d(dim, dim_out, 3, padding = 1)
-    )
+class PixelShuffleUpsample(nn.Module):
+    """
+    code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts
+    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
+    """
+    def __init__(self, dim, dim_out = None):
+        super().__init__()
+        dim_out = default(dim_out, dim)
+        conv = nn.Conv2d(dim, dim_out * 4, 1)
+
+        self.net = nn.Sequential(
+            conv,
+            nn.SiLU(),
+            nn.PixelShuffle(2)
+        )
+
+        self.init_conv_(conv)
+
+    def init_conv_(self, conv):
+        o, i, h, w = conv.weight.shape
+        conv_weight = torch.empty(o // 4, i, h, w)
+        nn.init.kaiming_uniform_(conv_weight)
+        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
+
+        conv.weight.data.copy_(conv_weight)
+        nn.init.zeros_(conv.bias.data)
+
+    def forward(self, x):
+        return self.net(x)

def Downsample(dim, *, dim_out = None):
    dim_out = default(dim_out, dim)
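The `init_conv_` above is an ICNR-style initialization: `repeat(conv_weight, 'o ... -> (o 4) ...')` gives the four sub-pixel positions of each output channel identical filters, so at initialization the conv plus pixel shuffle behaves exactly like nearest-neighbor upsampling, which is what suppresses checkerboard artifacts. A self-contained sketch verifying that equivalence (the SiLU is omitted here, since a pointwise nonlinearity preserves it):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

dim = 8
conv = nn.Conv2d(dim, dim * 4, 1, bias = False)

# ICNR-style init, as in init_conv_ above
w = torch.empty(dim, dim, 1, 1)
nn.init.kaiming_uniform_(w)
conv.weight.data.copy_(repeat(w, 'o ... -> (o 4) ...'))

x = torch.randn(1, dim, 4, 4)
upsampled = nn.PixelShuffle(2)(conv(x))

# channels 0, 4, 8, ... hold one copy of each original filter's response
reference = F.interpolate(conv(x)[:, ::4], scale_factor = 2, mode = 'nearest')
assert torch.allclose(upsampled, reference, atol = 1e-6)
```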
@@ -1388,7 +1515,7 @@ class Unet(nn.Module):
        cross_embed_downsample_kernel_sizes = (2, 4),
        memory_efficient = False,
        scale_skip_connection = False,
-       nearest_upsample = False,
+       pixel_shuffle_upsample = True,
        final_conv_kernel_size = 1,
        **kwargs
    ):
@@ -1454,10 +1581,12 @@ class Unet(nn.Module):
        # text encoding conditioning (optional)

        self.text_to_cond = None
+       self.text_embed_dim = None

        if cond_on_text_encodings:
            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
+           self.text_embed_dim = text_embed_dim

        # finer control over whether to condition on image embeddings and text encodings
        # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1500,7 +1629,7 @@ class Unet(nn.Module):

        # upsample klass

-       upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
+       upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample

        # give memory efficient unet an initial resnet block
@@ -1564,6 +1693,8 @@ class Unet(nn.Module):
        self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
        self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)

+       zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
+
    # if the current settings for the unet are not correct
    # for cascading DDPM, then reinit the unet with the right settings
    def cast_model_parameters(
@@ -1686,6 +1817,8 @@ class Unet(nn.Module):
        text_tokens = None

        if exists(text_encodings) and self.cond_on_text_encodings:
+           assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
+
            text_tokens = self.text_to_cond(text_encodings)
            text_tokens = text_tokens[:, :self.max_text_len]
@@ -1781,12 +1914,15 @@ class LowresConditioner(nn.Module):
    def __init__(
        self,
        downsample_first = True,
+       downsample_mode_nearest = False,
        blur_sigma = 0.6,
        blur_kernel_size = 3,
        input_image_range = None
    ):
        super().__init__()
        self.downsample_first = downsample_first
+       self.downsample_mode_nearest = downsample_mode_nearest
+
        self.input_image_range = input_image_range

        self.blur_sigma = blur_sigma
@@ -1802,7 +1938,7 @@ class LowresConditioner(nn.Module):
        blur_kernel_size = None
    ):
        if self.training and self.downsample_first and exists(downsample_image_size):
-           cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range)
+           cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)

        if self.training:
            # when training, blur the low resolution conditional image
@@ -1836,6 +1972,7 @@ class Decoder(nn.Module):
        channels = 3,
        vae = tuple(),
        timesteps = 1000,
+       sample_timesteps = None,
        image_cond_drop_prob = 0.1,
        text_cond_drop_prob = 0.5,
        loss_type = 'l2',
@@ -1845,6 +1982,7 @@ class Decoder(nn.Module):
        image_sizes = None, # for cascading ddpm, image size at each stage
        random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
        lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
+       lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
        blur_sigma = 0.6, # cascading ddpm - blur sigma
        blur_kernel_size = 3, # cascading ddpm - blur kernel size
        clip_denoised = True,
@@ -1858,7 +1996,8 @@ class Decoder(nn.Module):
        use_dynamic_thres = False, # from the Imagen paper
        dynamic_thres_percentile = 0.9,
        p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
-       p2_loss_weight_k = 1
+       p2_loss_weight_k = 1,
+       ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
    ):
        super().__init__()
@@ -1938,6 +2077,11 @@ class Decoder(nn.Module):
            self.unets.append(one_unet)
            self.vaes.append(one_vae.copy_for_eval())

+       # sampling timesteps, defaults to non-ddim with full timesteps sampling
+
+       self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
+       self.ddim_sampling_eta = ddim_sampling_eta
+
        # create noise schedulers per unet

        if not exists(beta_schedule):
@@ -1948,7 +2092,9 @@ class Decoder(nn.Module):

        self.noise_schedulers = nn.ModuleList([])

-       for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
+       for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
+           assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
+
            noise_scheduler = NoiseScheduler(
                beta_schedule = unet_beta_schedule,
                timesteps = timesteps,
@@ -1987,6 +2133,7 @@ class Decoder(nn.Module):

        self.to_lowres_cond = LowresConditioner(
            downsample_first = lowres_downsample_first,
+           downsample_mode_nearest = lowres_downsample_mode_nearest,
            blur_sigma = blur_sigma,
            blur_kernel_size = blur_kernel_size,
            input_image_range = self.input_image_range
@@ -2048,6 +2195,26 @@ class Decoder(nn.Module):
        for unet, device in zip(self.unets, devices):
            unet.to(device)

+   def dynamic_threshold(self, x):
+       """ proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
+
+       # s is the threshold amount
+       # static thresholding would just be s = 1
+       s = 1.
+       if self.use_dynamic_thres:
+           s = torch.quantile(
+               rearrange(x, 'b ... -> b (...)').abs(),
+               self.dynamic_thres_percentile,
+               dim = -1
+           )
+
+           s.clamp_(min = 1.)
+           s = s.view(-1, *((1,) * (x.ndim - 1)))
+
+       # clip by threshold, depending on whether static or dynamic
+       x = x.clamp(-s, s) / s
+       return x
+
    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
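A worked example of `dynamic_threshold` above (dynamic thresholding from Imagen, arXiv:2205.11487), showing its effect on one sample with `dynamic_thres_percentile = 0.9`; the input values are hypothetical:

```python
import torch

x = torch.tensor([[-3.0, -0.5, 0.2, 2.0]])

s = torch.quantile(x.abs(), 0.9, dim = -1)  # 90th percentile of |x|, here 2.7
s.clamp_(min = 1.)                          # never shrink below the static threshold s = 1
s = s.view(-1, 1)

print(x.clamp(-s, s) / s)  # tensor([[-1.0000, -0.1852, 0.0741, 0.7407]])
```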
@@ -2062,21 +2229,7 @@ class Decoder(nn.Module):
            x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

        if clip_denoised:
-           # s is the threshold amount
-           # static thresholding would just be s = 1
-           s = 1.
-           if self.use_dynamic_thres:
-               s = torch.quantile(
-                   rearrange(x_recon, 'b ... -> b (...)').abs(),
-                   self.dynamic_thres_percentile,
-                   dim = -1
-               )
-
-               s.clamp_(min = 1.)
-               s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
-
-           # clip by threshold, depending on whether static or dynamic
-           x_recon = x_recon.clamp(-s, s) / s
+           x_recon = self.dynamic_threshold(x_recon)

        model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -2106,7 +2259,7 @@ class Decoder(nn.Module):
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
-   def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+   def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
        device = self.device

        b = shape[0]
@@ -2134,6 +2287,62 @@ class Decoder(nn.Module):
        unnormalize_img = self.unnormalize_img(img)
        return unnormalize_img

+   @torch.no_grad()
+   def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+       batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
+
+       times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+
+       times = list(reversed(times.int().tolist()))
+       time_pairs = list(zip(times[:-1], times[1:]))
+
+       img = torch.randn(shape, device = device)
+
+       for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+           alpha = alphas[time]
+           alpha_next = alphas[time_next]
+
+           time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+           pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+
+           if learned_variance:
+               pred, _ = pred.chunk(2, dim = 1)
+
+           if predict_x_start:
+               x_start = pred
+               pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+           else:
+               x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
+               pred_noise = pred
+
+           if clip_denoised:
+               x_start = self.dynamic_threshold(x_start)
+
+           c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+           c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+           noise = torch.randn_like(img) if time_next > 0 else 0.
+
+           img = x_start * alpha_next.sqrt() + \
+               c1 * noise + \
+               c2 * pred_noise
+
+       img = self.unnormalize_img(img)
+       return img
+
+   @torch.no_grad()
+   def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
+       num_timesteps = noise_scheduler.num_timesteps
+
+       timesteps = default(timesteps, num_timesteps)
+       assert timesteps <= num_timesteps
+       is_ddim = timesteps < num_timesteps
+
+       if not is_ddim:
+           return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
+
+       return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
+
    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
        noise = default(noise, lambda: torch.randn_like(x_start))
@@ -2231,10 +2440,13 @@ class Decoder(nn.Module):
        assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
        assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

+       if self.condition_on_text_encodings:
+           text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
+
        img = None
        is_cuda = next(self.parameters()).is_cuda

-       for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
+       for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):

            context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
@@ -2263,7 +2475,8 @@ class Decoder(nn.Module):
                    clip_denoised = not is_latent_diffusion,
                    lowres_cond_img = lowres_cond_img,
                    is_latent_diffusion = is_latent_diffusion,
-                   noise_scheduler = noise_scheduler
+                   noise_scheduler = noise_scheduler,
+                   timesteps = sample_timesteps
                )

                img = vae.decode(img)
@@ -2313,6 +2526,9 @@ class Decoder(nn.Module):
        assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
        assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

+       if self.condition_on_text_encodings:
+           text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
+
        lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
        image = resize_image_to(image, target_image_size)
@@ -1,6 +1,7 @@
import os
import webdataset as wds
import torch
+from torch.utils.data import DataLoader
import numpy as np
import fsspec
import shutil

@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
    )
    if shuffle_num is not None and shuffle_num > 0:
        ds.shuffle(1000)
-   return wds.WebLoader(
+   return DataLoader(
        ds,
        num_workers=num_workers,
        batch_size=batch_size,
@@ -1,5 +1,6 @@
import urllib.request
import os
+import json
from pathlib import Path
import shutil
from itertools import zip_longest
@@ -37,14 +38,17 @@ class BaseLogger:
        data_path (str): A file path for storing temporary data.
        verbose (bool): Whether or not to always print logs to the console.
    """
-   def __init__(self, data_path: str, verbose: bool = False, **kwargs):
+   def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
        self.data_path = Path(data_path)
+       self.resume = resume
+       self.auto_resume = auto_resume
        self.verbose = verbose

    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
        """
        Initializes the logger.
        Errors if the logger is invalid.
        full_config is the config file dict while extra_config is anything else from the script that is not defined in the config file.
        """
        raise NotImplementedError
@@ -60,6 +64,14 @@ class BaseLogger:
    def log_error(self, error_string, **kwargs) -> None:
        raise NotImplementedError

+   def get_resume_data(self, **kwargs) -> dict:
+       """
+       Sets tracker attributes that along with { "resume": True } will be used to resume training.
+       It is assumed that after init is called this data will be complete.
+       If the logger does not have any resume functionality, it should return an empty dict.
+       """
+       raise NotImplementedError
+
class ConsoleLogger(BaseLogger):
    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
        print("Logging to console")
@@ -76,6 +88,9 @@ class ConsoleLogger(BaseLogger):
    def log_error(self, error_string, **kwargs) -> None:
        print(error_string)

+   def get_resume_data(self, **kwargs) -> dict:
+       return {}
+
class WandbLogger(BaseLogger):
    """
    Logs to a wandb run.
@@ -85,7 +100,6 @@ class WandbLogger(BaseLogger):
        wandb_project (str): The wandb project to save the run to.
        wandb_run_id (str): The wandb run id to resume.
        wandb_run_name (str): The wandb run name to use.
-       wandb_resume (bool): Whether to resume a wandb run.
    """
    def __init__(self,
        data_path: str,

@@ -93,7 +107,6 @@ class WandbLogger(BaseLogger):
        wandb_project: str,
        wandb_run_id: Optional[str] = None,
        wandb_run_name: Optional[str] = None,
-       wandb_resume: bool = False,
        **kwargs
    ):
        super().__init__(data_path, **kwargs)

@@ -101,7 +114,6 @@ class WandbLogger(BaseLogger):
        self.project = wandb_project
        self.run_id = wandb_run_id
        self.run_name = wandb_run_name
-       self.resume = wandb_resume

    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
        assert self.entity is not None, "wandb_entity must be specified for wandb logger"
@@ -149,6 +161,14 @@ class WandbLogger(BaseLogger):
        print(error_string)
        self.wandb.log({"error": error_string, **kwargs}, step=step)

+   def get_resume_data(self, **kwargs) -> dict:
+       # In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
+       return {
+           "entity": self.entity,
+           "project": self.project,
+           "run_id": self.wandb.run.id
+       }
+
logger_type_map = {
    'console': ConsoleLogger,
    'wandb': WandbLogger,
@@ -168,8 +188,9 @@ class BaseLoader:
    Parameters:
        data_path (str): A file path for storing temporary data.
    """
-   def __init__(self, data_path: str, **kwargs):
+   def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
        self.data_path = Path(data_path)
+       self.only_auto_resume = only_auto_resume

    def init(self, logger: BaseLogger, **kwargs) -> None:
        raise NotImplementedError
@@ -304,6 +325,10 @@ class LocalSaver(BaseSaver):
    def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
        # Copy the file to save_path
        save_path_file_name = Path(save_path).name
+       # Make sure parent directory exists
+       save_path_parent = Path(save_path).parent
+       if not save_path_parent.exists():
+           save_path_parent.mkdir(parents=True)
        print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
        shutil.copy(local_path, save_path)
@@ -385,11 +410,7 @@ class Tracker:
    def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
        self.data_path = Path(data_path)
        if not dummy_mode:
-           if overwrite_data_path:
-               if self.data_path.exists():
-                   shutil.rmtree(self.data_path)
-               self.data_path.mkdir(parents=True)
-           else:
+           if not overwrite_data_path:
                assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
+           if not self.data_path.exists():
                self.data_path.mkdir(parents=True)
@@ -398,7 +419,46 @@ class Tracker:
        self.savers: List[BaseSaver]= []
        self.dummy_mode = dummy_mode

+   def _load_auto_resume(self) -> bool:
+       # If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
+       if not self.auto_resume_path.exists():
+           if self.logger.auto_resume:
+               print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
+           return False
+
+       # Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
+       if not self.logger.auto_resume:
+           print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
+           self.auto_resume_path.unlink()
+           return False
+
+       # Otherwise we read the json into a dictionary that will override parts of logger.__dict__
+       with open(self.auto_resume_path, 'r') as f:
+           auto_resume_dict = json.load(f)
+       # Check if the logger is of the same type as the autoresume save
+       if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
+           raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
+       # Then we are ready to override the logger with the autoresume save
+       self.logger.__dict__["resume"] = True
+       print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
+       self.logger.__dict__.update(auto_resume_dict)
+       return True
+
+   def _save_auto_resume(self):
+       # Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
+       auto_resume_dict = self.logger.get_resume_data()
+       auto_resume_dict['logger_type'] = self.logger.__class__.__name__
+       with open(self.auto_resume_path, 'w') as f:
+           json.dump(auto_resume_dict, f)
+
    def init(self, full_config: BaseModel, extra_config: dict):
+       self.auto_resume_path = self.data_path / 'auto_resume.json'
+       # Check for resuming the run
+       self.did_auto_resume = self._load_auto_resume()
+       if self.did_auto_resume:
+           print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
+           print(f"New logger config: {self.logger.__dict__}")
+
        assert self.logger is not None, '`logger` must be set before `init` is called'
        if self.dummy_mode:
            # The only thing we need is a loader
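For concreteness, the auto-resume round trip: `_save_auto_resume` writes the dict returned by the logger's `get_resume_data` plus a `logger_type` key, and `_load_auto_resume` later merges it back into `logger.__dict__`. A sketch of the resulting file for a wandb run (keys from the code above; values hypothetical):

```python
# contents of <data_path>/auto_resume.json after init for a WandbLogger
auto_resume = {
    "entity": "my-entity",        # hypothetical
    "project": "dalle2-decoder",  # hypothetical
    "run_id": "1a2b3c4d",         # wandb run id to reattach to
    "logger_type": "WandbLogger"  # checked against the active logger class on load
}
```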
@@ -406,12 +466,17 @@ class Tracker:
            self.loader.init(self.logger)
            return
        assert len(self.savers) > 0, '`savers` must be set before `init` is called'

        self.logger.init(full_config, extra_config)
        if self.loader is not None:
            self.loader.init(self.logger)
        for saver in self.savers:
            saver.init(self.logger)

+       if self.logger.auto_resume:
+           # Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
+           self._save_auto_resume()
+
    def add_logger(self, logger: BaseLogger):
        self.logger = logger
@@ -503,11 +568,16 @@ class Tracker:
            self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
            print(f'Error saving checkpoint: {e}')

+   @property
+   def can_recall(self):
+       # Defines whether a recall can be performed.
+       return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
+
    def recall(self):
-       if self.loader is not None:
+       if self.can_recall:
            return self.loader.recall()
        else:
-           raise ValueError('No loader specified')
+           raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
@@ -47,6 +47,8 @@ class TrainSplitConfig(BaseModel):

class TrackerLogConfig(BaseModel):
    log_type: str = 'console'
+   resume: bool = False # For logs that are saved to unique locations, resume a previous run
+   auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed
    verbose: bool = False

    class Config:
@@ -59,6 +61,7 @@ class TrackerLogConfig(BaseModel):

class TrackerLoadConfig(BaseModel):
    load_from: Optional[str] = None
+   only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming

    class Config:
        extra = "allow"
@@ -151,6 +154,7 @@ class DiffusionPriorConfig(BaseModel):
    image_size: int
    image_channels: int = 3
    timesteps: int = 1000
+   sample_timesteps: Optional[int] = None
    cond_drop_prob: float = 0.
    loss_type: str = 'l2'
    predict_x_start: bool = True
@@ -230,6 +234,7 @@ class DecoderConfig(BaseModel):
    clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
    channels: int = 3
    timesteps: int = 1000
+   sample_timesteps: Optional[SingularOrIterable(int)] = None
    loss_type: str = 'l2'
    beta_schedule: ListOrTuple(str) = 'cosine'
    learned_variance: bool = True
@@ -21,7 +21,7 @@ import pytorch_warmup as warmup

from ema_pytorch import EMA

-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType

import numpy as np
@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', next(model.parameters()).device)
        cast_device = kwargs.pop('_cast_device', True)
+       cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)

        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

+       if cast_deepspeed_precision:
+           try:
+               accelerator = model.accelerator
+               if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
+                   cast_type_map = {
+                       "fp16": torch.half,
+                       "bf16": torch.bfloat16,
+                       "no": torch.float
+                   }
+                   precision_type = cast_type_map[accelerator.mixed_precision]
+                   all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+           except AttributeError:
+               # Then this model doesn't have an accelerator
+               pass
+
        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
@@ -192,6 +208,7 @@ class DiffusionPriorTrainer(nn.Module):
            self.device = diffusion_prior_device
        else:
            self.device = accelerator.device if exists(accelerator) else device
+           diffusion_prior.to(self.device)

        # save model
@@ -445,6 +462,7 @@ class DecoderTrainer(nn.Module):
        self,
        decoder,
        accelerator = None,
+       dataloaders = None,
        use_ema = True,
        lr = 1e-4,
        wd = 1e-2,
@@ -507,11 +525,31 @@ class DecoderTrainer(nn.Module):

        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

+       if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
+           # Then we need to make sure clip is using the correct precision or else deepspeed will error
+           cast_type_map = {
+               "fp16": torch.half,
+               "bf16": torch.bfloat16,
+               "no": torch.float
+           }
+           precision_type = cast_type_map[accelerator.mixed_precision]
+           assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
+           clip = decoder.clip
+           clip.to(precision_type)
+
        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
        schedulers = list(self.accelerator.prepare(*schedulers))

        self.decoder = decoder

+       # prepare dataloaders
+
+       train_loader = val_loader = None
+       if exists(dataloaders):
+           train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
+
+       self.train_loader = train_loader
+       self.val_loader = val_loader
+
        # store optimizers

        for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
@@ -526,6 +564,17 @@ class DecoderTrainer(nn.Module):

        self.warmup_schedulers = warmup_schedulers

+   def validate_and_return_unet_number(self, unet_number = None):
+       if self.num_unets == 1:
+           unet_number = default(unet_number, 1)
+
+       assert exists(unet_number) and 1 <= unet_number <= self.num_unets
+       return unet_number
+
+   def num_steps_taken(self, unet_number = None):
+       unet_number = self.validate_and_return_unet_number(unet_number)
+       return self.steps[unet_number - 1].item()
+
    def save(self, path, overwrite = True, **kwargs):
        path = Path(path)
        assert not (path.exists() and not overwrite)
@@ -594,10 +643,7 @@ class DecoderTrainer(nn.Module):
        self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))

    def update(self, unet_number = None):
-       if self.num_unets == 1:
-           unet_number = default(unet_number, 1)
-
-       assert exists(unet_number) and 1 <= unet_number <= self.num_unets
+       unet_number = self.validate_and_return_unet_number(unet_number)
        index = unet_number - 1

        optimizer = getattr(self, f'optim{index}')
@@ -663,11 +709,13 @@ class DecoderTrainer(nn.Module):
        max_batch_size = None,
        **kwargs
    ):
-       if self.num_unets == 1:
-           unet_number = default(unet_number, 1)
+       unet_number = self.validate_and_return_unet_number(unet_number)

        total_loss = 0.

+       using_amp = self.accelerator.mixed_precision != 'no'
+
        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            with self.accelerator.autocast():
                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
@@ -1 +1 @@
-__version__ = '0.16.16'
+__version__ = '0.21.0'
test_data/0.tar (new binary file, not shown)
test_data/1.tar (new binary file, not shown)
test_data/2.tar (new binary file, not shown)
test_data/3.tar (new binary file, not shown)
test_data/4.tar (new binary file, not shown)
test_data/5.tar (new binary file, not shown)
test_data/6.tar (new binary file, not shown)
test_data/7.tar (new binary file, not shown)
test_data/8.tar (new binary file, not shown)
test_data/9.tar (new binary file, not shown)
@@ -132,7 +132,7 @@ def get_example_data(dataloader, device, n=5):
        break
    return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))

-def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
+def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
    """
    Takes example data and generates images from the embeddings
    Returns three lists: real images, generated images, and captions

@@ -160,6 +160,9 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
    samples = trainer.sample(**sample_params)
    generated_images = list(samples)
    captions = [text_prepend + txt for txt in txts]
+   if match_image_size:
+       generated_image_size = generated_images[0].shape[-1]
+       real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
    return real_images, generated_images, captions

def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
@@ -167,14 +170,6 @@ def generate_grid_samples(trainer, examples, condition_on_text_encodings=False,
    """
    Generates samples and uses torchvision to put them in a side by side grid for easy viewing
    """
    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
-
-   real_image_size = real_images[0].shape[-1]
-   generated_image_size = generated_images[0].shape[-1]
-
-   # training images may be larger than the generated one
-   if real_image_size > generated_image_size:
-       real_images = [resize_image_to(image, generated_image_size) for image in real_images]
-
    grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
    return grid_images, captions
@@ -279,6 +274,7 @@ def train(
    trainer = DecoderTrainer(
        decoder=decoder,
        accelerator=accelerator,
+       dataloaders=dataloaders,
        **kwargs
    )
@@ -289,9 +285,8 @@ def train(
    sample = 0
    samples_seen = 0
    val_sample = 0
-   step = lambda: int(trainer.step.item())

-   if tracker.loader is not None:
+   if tracker.can_recall:
        start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
        if next_task == 'train':
            sample = recalled_sample
@@ -304,6 +299,8 @@ def train(
    if not exists(unet_training_mask):
        # Then the unet mask should be true for all unets in the decoder
        unet_training_mask = [True] * trainer.num_unets
+   first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
+   step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
    assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"

    accelerator.print(print_ribbon("Generating Example Data", repeat=40))
@@ -326,7 +323,7 @@ def train(
        last_snapshot = sample

        if next_task == 'train':
-           for i, (img, emb, txt) in enumerate(dataloaders["train"]):
+           for i, (img, emb, txt) in enumerate(trainer.train_loader):
                # We want to count the total number of samples across all processes
                sample_length_tensor[0] = len(img)
                all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -419,7 +416,7 @@ def train(
            timer = Timer()
            accelerator.wait_for_everyone()
            i = 0
-           for i, (img, emb, txt) in enumerate(dataloaders["val"]):
+           for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
                val_sample_length_tensor[0] = len(img)
                all_samples = accelerator.gather(val_sample_length_tensor)
                total_samples = all_samples.sum().item()
@@ -524,6 +521,20 @@ def initialize_training(config: TrainDecoderConfig, config_path):
    # Set up accelerator for configurable distributed training
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

+   if accelerator.num_processes > 1:
+       # We are using distributed training and want to immediately ensure all can connect
+       accelerator.print("Waiting for all processes to connect...")
+       accelerator.wait_for_everyone()
+       accelerator.print("All processes online and connected")
+
+   # If we are in deepspeed fp16 mode, we must ensure learned variance is off
+   if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
+       raise ValueError("DeepSpeed fp16 mode does not support learned variance")
+
+   if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
+       # This is an invalid configuration until we figure out how to handle this
+       raise ValueError("DeepSpeed does not support multi-node distributed training")
+
    # Set up data
    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
@@ -546,7 +557,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):

    # Create the decoder model and print basic info
    decoder = config.decoder.create()
-   num_parameters = sum(p.numel() for p in decoder.parameters())
+   get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))

    # Create and initialize the tracker if we are the master
    tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
@@ -575,7 +586,10 @@ def initialize_training(config: TrainDecoderConfig, config_path):
    accelerator.print(print_ribbon("Loaded Config", repeat=40))
    accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
    accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
-   accelerator.print(f"Number of parameters: {num_parameters}")
+   accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
+   for i, unet in enumerate(decoder.unets):
+       accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")

    train(dataloaders, decoder, accelerator,
        tracker=tracker,
        inference_device=accelerator.device,