mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 20:25:00 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
38cd62010c | ||
|
|
1cc288af39 | ||
|
|
a851168633 | ||
|
|
1ffeecd0ca |
@@ -943,7 +943,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
|
|||||||
|
|
||||||
# Create a dataloader directly.
|
# Create a dataloader directly.
|
||||||
dataloader = create_image_embedding_dataloader(
|
dataloader = create_image_embedding_dataloader(
|
||||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
@@ -1097,7 +1097,7 @@ This library would not have gotten to this working state without the help of
|
|||||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
|
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
|
||||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||||
- [ ] decoder needs one day worth of refactor for tech debt
|
- [ ] decoder needs one day worth of refactor for tech debt
|
||||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||||
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ Defines which evaluation metrics will be used to test the model.
|
|||||||
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
|
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
|
||||||
| Option | Required | Default | Description |
|
| Option | Required | Default | Description |
|
||||||
| ------ | -------- | ------- | ----------- |
|
| ------ | -------- | ------- | ----------- |
|
||||||
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
||||||
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
|
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
|
||||||
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
|
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
|
||||||
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
|
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
|
import random
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
@@ -1676,7 +1677,7 @@ class LowresConditioner(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
downsample_first = True,
|
downsample_first = True,
|
||||||
blur_sigma = 0.1,
|
blur_sigma = (0.1, 0.2),
|
||||||
blur_kernel_size = 3,
|
blur_kernel_size = 3,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1700,6 +1701,16 @@ class LowresConditioner(nn.Module):
|
|||||||
# when training, blur the low resolution conditional image
|
# when training, blur the low resolution conditional image
|
||||||
blur_sigma = default(blur_sigma, self.blur_sigma)
|
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||||
|
|
||||||
|
# allow for drawing a random sigma between lo and hi float values
|
||||||
|
if isinstance(blur_sigma, tuple):
|
||||||
|
blur_sigma = random.uniform(*blur_sigma)
|
||||||
|
|
||||||
|
# allow for drawing a random kernel size between lo and hi int values
|
||||||
|
if isinstance(blur_kernel_size, tuple):
|
||||||
|
kernel_size_lo, kernel_size_hi = blur_kernel_size
|
||||||
|
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
|
||||||
|
|
||||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||||
|
|
||||||
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
||||||
@@ -1725,7 +1736,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||||
clip_denoised = True,
|
clip_denoised = True,
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
|
|||||||
|
|
||||||
# Create a dataloader directly.
|
# Create a dataloader directly.
|
||||||
dataloader = create_image_embedding_dataloader(
|
dataloader = create_image_embedding_dataloader(
|
||||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# to give users a quick easy start to training DALL-E without doing BPE
|
# to give users a quick easy start to training DALL-E without doing BPE
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import youtokentome as yttm
|
|
||||||
|
|
||||||
import html
|
import html
|
||||||
import os
|
import os
|
||||||
@@ -11,6 +10,8 @@ import regex as re
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from dalle2_pytorch.utils import import_or_print_error
|
||||||
|
|
||||||
# OpenAI simple tokenizer
|
# OpenAI simple tokenizer
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
@@ -156,7 +157,9 @@ class YttmTokenizer:
|
|||||||
bpe_path = Path(bpe_path)
|
bpe_path = Path(bpe_path)
|
||||||
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
|
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
|
||||||
|
|
||||||
tokenizer = yttm.BPE(model = str(bpe_path))
|
self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
|
||||||
|
|
||||||
|
tokenizer = self.yttm.BPE(model = str(bpe_path))
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.vocab_size = tokenizer.vocab_size()
|
self.vocab_size = tokenizer.vocab_size()
|
||||||
|
|
||||||
@@ -167,7 +170,7 @@ class YttmTokenizer:
|
|||||||
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
|
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
|
||||||
|
|
||||||
def encode(self, texts):
|
def encode(self, texts):
|
||||||
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
|
encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
|
||||||
return list(map(torch.tensor, encoded))
|
return list(map(torch.tensor, encoded))
|
||||||
|
|
||||||
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from itertools import zip_longest
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from dalle2_pytorch.utils import import_or_print_error
|
||||||
|
|
||||||
# constants
|
# constants
|
||||||
|
|
||||||
DEFAULT_DATA_PATH = './.tracker-data'
|
DEFAULT_DATA_PATH = './.tracker-data'
|
||||||
@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
|
|||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
def import_or_print_error(pkg_name, err_str = None):
|
|
||||||
try:
|
|
||||||
return importlib.import_module(pkg_name)
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
if exists(err_str):
|
|
||||||
print(err_str)
|
|
||||||
exit()
|
|
||||||
|
|
||||||
# load state dict functions
|
# load state dict functions
|
||||||
|
|
||||||
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class EMA(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
beta = 0.9999,
|
beta = 0.99,
|
||||||
update_after_step = 1000,
|
update_after_step = 1000,
|
||||||
update_every = 10,
|
update_every = 10,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -17,3 +17,13 @@ class Timer:
|
|||||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||||
flank = symbol * repeat
|
flank = symbol * repeat
|
||||||
return f'{flank} {s} {flank}'
|
return f'{flank} {s} {flank}'
|
||||||
|
|
||||||
|
# import helpers
|
||||||
|
|
||||||
|
def import_or_print_error(pkg_name, err_str = None):
|
||||||
|
try:
|
||||||
|
return importlib.import_module(pkg_name)
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
if exists(err_str):
|
||||||
|
print(err_str)
|
||||||
|
exit()
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.6.4'
|
__version__ = '0.6.7'
|
||||||
|
|||||||
1
setup.py
1
setup.py
@@ -42,7 +42,6 @@ setup(
|
|||||||
'tqdm',
|
'tqdm',
|
||||||
'vector-quantize-pytorch',
|
'vector-quantize-pytorch',
|
||||||
'x-clip>=0.4.4',
|
'x-clip>=0.4.4',
|
||||||
'youtokentome',
|
|
||||||
'webdataset>=0.2.5',
|
'webdataset>=0.2.5',
|
||||||
'fsspec>=2022.1.0',
|
'fsspec>=2022.1.0',
|
||||||
'torchmetrics[image]>=0.8.0'
|
'torchmetrics[image]>=0.8.0'
|
||||||
|
|||||||
Reference in New Issue
Block a user