Compare commits

..

1 Commits

11 changed files with 57 additions and 361 deletions

3
.gitignore vendored
View File

@@ -1,6 +1,3 @@
# default experiment tracker data
.tracker-data/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

View File

@@ -14,16 +14,6 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Status
- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and <a href="https://github.com/crowsonkb">Katherine's</a> own experiments, validate OpenAI's finding that the extra prior increases variety of generations.
- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
<img src="./samples/oxford.png" width="600px" />
*ongoing at 21k steps*
## Install
```bash
@@ -824,8 +814,8 @@ clip = CLIP(
# mock data
text = torch.randint(0, 49408, (512, 256)).cuda()
images = torch.randn(512, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()
# prior networks (with transformer)
@@ -858,7 +848,7 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
# after much of the above three lines in a loop
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
```
## Bonus
@@ -871,7 +861,7 @@ ex.
```python
import torch
from dalle2_pytorch import Unet, Decoder, DecoderTrainer
from dalle2_pytorch import Unet, Decoder
# unet for the cascading ddpm
@@ -894,24 +884,20 @@ decoder = Decoder(
unconditional = True
).cuda()
# decoder trainer
decoder_trainer = DecoderTrainer(decoder)
# images (get a lot of this)
# mock images (get a lot of this)
images = torch.randn(1, 3, 512, 512).cuda()
# feed images into decoder
for i in (1, 2):
loss = decoder_trainer(images, unet_number = i)
decoder_trainer.update(unet_number = i)
loss = decoder(images, unet_number = i)
loss.backward()
# do the above for many many many many images
# do the above for many many many many steps
# then it will learn to generate images
images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
```
## Dataloaders

View File

@@ -1697,8 +1697,7 @@ class Decoder(BaseGaussianDiffusion):
clip_adapter_overrides = dict(),
learned_variance = True,
vb_loss_weight = 0.001,
unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
unconditional = False
):
super().__init__(
beta_schedule = beta_schedule,
@@ -1807,10 +1806,6 @@ class Decoder(BaseGaussianDiffusion):
self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
@@ -1875,14 +1870,13 @@ class Decoder(BaseGaussianDiffusion):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device = device)
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
@@ -1899,17 +1893,16 @@ class Decoder(BaseGaussianDiffusion):
clip_denoised = clip_denoised
)
unnormalize_img = self.unnormalize_img(img)
unnormalize_img = unnormalize_zero_to_one(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
if not is_latent_diffusion:
x_start = self.normalize_img(x_start)
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
x_start = normalize_neg_one_to_one(x_start)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
# get x_t
@@ -1987,7 +1980,7 @@ class Decoder(BaseGaussianDiffusion):
batch_size = image_embed.shape[0]
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip)
assert exist(self.clip)
_, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
@@ -2023,8 +2016,7 @@ class Decoder(BaseGaussianDiffusion):
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
is_latent_diffusion = is_latent_diffusion
lowres_cond_img = lowres_cond_img
)
img = vae.decode(img)
@@ -2083,14 +2075,12 @@ class Decoder(BaseGaussianDiffusion):
image = aug(image)
lowres_cond_img = aug(lowres_cond_img, params = aug._params)
is_latent_diffusion = not isinstance(vae, NullVQGanVAE)
vae.eval()
with torch.no_grad():
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
# main class

View File

@@ -1,41 +0,0 @@
## Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
### Decoder: Image Embedding Dataset
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
Generating a dataset of this type:
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
Usage:
```python
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4,
batch_size=32,
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
shuffle_shards=True, # Shuffle the order the shards are read in
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb.shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
embedding_folder_url="path/or/url/to/embeddings/folder",
shard_width=4,
shuffle_shards=True,
resample=False
)
```

View File

@@ -3,7 +3,6 @@ import webdataset as wds
import torch
import numpy as np
import fsspec
import shutil
def get_shard(filename):
"""
@@ -21,7 +20,7 @@ def get_example_file(fs, path, file_format):
"""
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handlers.reraise_exception):
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
previous_tar_url = None
current_embeddings = None
@@ -51,12 +50,8 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler
previous_tar_url = tar_url
current_embeddings = load_corresponding_embeds(tar_url)
embedding_index = int(key[-index_width:])
embedding = current_embeddings[embedding_index]
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
if torch.count_nonzero(embedding) == 0:
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
sample["npy"] = embedding
embedding_index = int(key[shard_width:])
sample["npy"] = current_embeddings[embedding_index]
yield sample
except Exception as exn: # From wds implementation
if handler(exn):
@@ -65,28 +60,6 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler
break
insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
"""Finds if the is a corresponding embedding for the tarfile at { url: [URL] }"""
embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
embedding_files = embeddings_fs.ls(embeddings_path)
get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member
get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
for tarfile in tarfiles:
try:
webdataset_shard = get_tar_shard(tarfile["url"])
# If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one
if webdataset_shard in embedding_shards:
yield tarfile
except Exception as exn: # From wds implementation
if handler(exn):
continue
else:
break
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
def verify_keys(samples, handler=wds.handlers.reraise_exception):
"""
Requires that both the image and embedding are present in the sample
@@ -113,9 +86,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
self,
urls,
embedding_folder_url=None,
index_width=None,
img_preproc=None,
extra_keys=[],
shard_width=None,
handler=wds.handlers.reraise_exception,
resample=False,
shuffle_shards=True
@@ -126,31 +97,13 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
:param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
:param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
:param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard with this 4 and the last three digits are the index.
:param handler: A webdataset handler.
:param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
:param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.
"""
super().__init__()
keys = ["jpg", "npy"] + extra_keys
self.key_map = {key: i for i, key in enumerate(keys)}
self.resampling = resample
self.img_preproc = img_preproc
# If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up
if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
# Then this has an s3 link for the webdataset and we need extra packages
if shutil.which("s3cmd") is None:
raise RuntimeError("s3cmd is required for s3 webdataset")
if "s3:" in embedding_folder_url:
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
try:
import s3fs
except ImportError:
raise RuntimeError("s3fs is required to load embeddings from s3")
# Add the shardList and randomize or resample if requested
if resample:
assert not shuffle_shards, "Cannot both resample and shuffle"
@@ -159,43 +112,28 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
self.append(wds.SimpleShardList(urls))
if shuffle_shards:
self.append(wds.filters.shuffle(1000))
if embedding_folder_url is not None:
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
self.append(wds.split_by_node)
self.append(wds.split_by_worker)
self.append(wds.tarfile_to_samples(handler=handler))
self.append(wds.decode("pilrgb", handler=handler))
self.append(wds.decode("torchrgb"))
if embedding_folder_url is not None:
# Then we are loading embeddings for a remote source
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
assert shard_width is not None, "Reading embeddings separately requires shard length to be given"
self.append(insert_embedding(embeddings_url=embedding_folder_url, shard_width=shard_width, handler=handler))
self.append(verify_keys)
# Apply preprocessing
self.append(wds.map(self.preproc))
self.append(wds.to_tuple(*keys))
def preproc(self, sample):
"""Applies the preprocessing for images"""
if self.img_preproc is not None:
sample["jpg"] = self.img_preproc(sample["jpg"])
return sample
self.append(wds.to_tuple("jpg", "npy"))
def create_image_embedding_dataloader(
tar_url,
num_workers,
batch_size,
embeddings_url=None,
index_width=None,
shard_width=None,
shuffle_num = None,
shuffle_shards = True,
resample_shards = False,
img_preproc=None,
extra_keys=[],
handler=wds.handlers.reraise_exception#warn_and_continue
handler=wds.handlers.warn_and_continue
):
"""
Convenience function to create an image embedding dataseta and dataloader in one line
@@ -205,8 +143,8 @@ def create_image_embedding_dataloader(
:param batch_size: The batch size to use for the dataloader
:param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
:param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.
:param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index.
For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index.
:param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
:param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
:param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
@@ -215,11 +153,9 @@ def create_image_embedding_dataloader(
ds = ImageEmbeddingDataset(
tar_url,
embeddings_url,
index_width=index_width,
shard_width=shard_width,
shuffle_shards=shuffle_shards,
resample=resample_shards,
extra_keys=extra_keys,
img_preproc=img_preproc,
handler=handler
)
if shuffle_num is not None and shuffle_num > 0:

View File

@@ -1,59 +0,0 @@
from pathlib import Path
import torch
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image
# helpers functions
def cycle(dl):
while True:
for data in dl:
yield data
# dataset and dataloader
class Dataset(data.Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
self.transform = transforms.Compose([
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
def get_images_dataloader(
folder,
*,
batch_size,
image_size,
shuffle = True,
cycle_dl = True,
pin_memory = True
):
ds = Dataset(folder, image_size)
dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
if cycle_dl:
dl = cycle(dl)
return dl

View File

@@ -7,7 +7,7 @@ def separate_weight_decayable_params(params):
def get_optimizer(
params,
lr = 1e-4,
lr = 2e-5,
wd = 1e-2,
betas = (0.9, 0.999),
eps = 1e-8,

View File

@@ -1,47 +1,17 @@
import os
from pathlib import Path
from enum import Enum
import importlib
from itertools import zip_longest
import torch
from torch import nn
# constants
DEFAULT_DATA_PATH = './.tracker-data'
# helper functions
def exists(val):
return val is not None
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()
# load state dict functions
def load_wandb_state_dict(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return torch.load(file_reference.name)
def load_local_state_dict(file_path, **kwargs):
return torch.load(file_path)
# base class
class BaseTracker(nn.Module):
def __init__(self, data_path = DEFAULT_DATA_PATH):
def __init__(self):
super().__init__()
assert data_path is not None, "Tracker must have a data_path to save local content"
self.data_path = Path(data_path)
self.data_path.mkdir(parents = True, exist_ok = True)
def init(self, config, **kwargs):
raise NotImplementedError
@@ -49,27 +19,6 @@ class BaseTracker(nn.Module):
def log(self, log, **kwargs):
raise NotImplementedError
def log_images(self, images, **kwargs):
raise NotImplementedError
def save_state_dict(self, state_dict, relative_path, **kwargs):
raise NotImplementedError
def recall_state_dict(self, recall_source, *args, **kwargs):
"""
Loads a state dict from any source.
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
this should not be linked to any individual tracker.
"""
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
if recall_source == 'wandb':
return load_wandb_state_dict(*args, **kwargs)
elif recall_source == 'local':
return load_local_state_dict(*args, **kwargs)
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
# basic stdout class
class ConsoleTracker(BaseTracker):
@@ -79,39 +28,22 @@ class ConsoleTracker(BaseTracker):
def log(self, log, **kwargs):
print(log)
def log_images(self, images, **kwargs): # noop for logging images
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, str(self.data_path / relative_path))
# basic wandb class
class WandbTracker(BaseTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
def __init__(self):
super().__init__()
try:
import wandb
except ImportError as e:
print('`pip install wandb` to use the wandb experiment tracker')
raise e
os.environ["WANDB_SILENT"] = "true"
self.wandb = wandb
def init(self, **config):
self.wandb.init(**config)
def log(self, log, verbose=False, **kwargs):
if verbose:
print(log)
def log(self, log, **kwargs):
self.wandb.log(log, **kwargs)
def log_images(self, images, captions=[], image_section="images", **kwargs):
"""
Takes a tensor of images and a list of captions and logs them to wandb.
"""
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
self.wandb.log({ image_section: wandb_images }, **kwargs)
def save_state_dict(self, state_dict, relative_path, **kwargs):
"""
Saves a state_dict to disk and uploads it
"""
full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path

View File

@@ -47,14 +47,6 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
# decorators
def cast_torch_tensor(fn):
@@ -187,8 +179,8 @@ class EMA(nn.Module):
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
@@ -197,21 +189,14 @@ class EMA(nn.Module):
device = self.initted.device
self.ema_model.to(device)
def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict())
def update(self):
self.step += 1
if (self.step % self.update_every) != 0:
return
if self.step <= self.update_after_step:
self.copy_params_from_model_to_ema()
if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
return
if not self.initted:
self.copy_params_from_model_to_ema()
self.ema_model.state_dict(self.online_model.state_dict())
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
@@ -235,16 +220,6 @@ class EMA(nn.Module):
# diffusion prior trainer
def prior_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
return torch.cat(outputs, dim = 0)
return inner
class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
@@ -305,13 +280,11 @@ class DiffusionPriorTrainer(nn.Module):
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@@ -342,31 +315,15 @@ class DiffusionPriorTrainer(nn.Module):
# decoder trainer
def decoder_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
if self.decoder.unconditional:
batch_size = kwargs.get('batch_size')
batch_sizes = num_to_groups(batch_size, max_batch_size)
outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
else:
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
return torch.cat(outputs, dim = 0)
return inner
class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
use_ema = True,
lr = 1e-4,
lr = 2e-5,
wd = 1e-2,
eps = 1e-8,
max_grad_norm = 0.5,
max_grad_norm = None,
amp = False,
**kwargs
):
@@ -447,17 +404,15 @@ class DecoderTrainer(nn.Module):
@torch.no_grad()
@cast_torch_tensor
@decoder_sample_in_chunks
def sample(self, *args, **kwargs):
if kwargs.pop('use_non_ema', False) or not self.use_ema:
return self.decoder.sample(*args, **kwargs)
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
if self.use_ema:
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
self.decoder.unets = trainable_unets # restore original training unets
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:

Binary file not shown.

Before

Width:  |  Height:  |  Size: 985 KiB

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.3.4',
version = '0.2.41',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',