Compare commits

10 Commits

| Author | SHA1 | Message | Date |
| ------ | ---- | ------- | ---- |
| Phil Wang | 276abf337b | fix and cleanup image size determination logic in decoder | 2022-05-22 22:28:45 -07:00 |
| Phil Wang | ae42d03006 | allow for saving of additional fields on save method in trainers, and return loaded objects from the load method | 2022-05-22 22:14:25 -07:00 |
| Phil Wang | 4d346e98d9 | allow for config driven creation of clip-less diffusion prior | 2022-05-22 20:36:20 -07:00 |
| Phil Wang | 2b1fd1ad2e | product management | 2022-05-22 19:23:40 -07:00 |
| zion | 82a2ef37d9 | Update README.md (#109): block in a section that links to available pre-trained models for those who are interested | 2022-05-22 19:22:30 -07:00 |
| Phil Wang | 5c397c9d66 | move neural network creations off the configuration file into the pydantic classes | 2022-05-22 19:18:18 -07:00 |
| Phil Wang | 0f4edff214 | derived value for image preprocessing belongs to the data config class | 2022-05-22 18:42:40 -07:00 |
| Phil Wang | 501a8c7c46 | small cleanup | 2022-05-22 15:39:38 -07:00 |
| Phil Wang | 4e49373fc5 | project management | 2022-05-22 15:27:40 -07:00 |
| Phil Wang | 49de72040c | fix decoder trainer optimizer loading (since there are multiple for each unet), also save and load step number correctly | 2022-05-22 15:21:00 -07:00 |
10 changed files with 173 additions and 89 deletions

View File

@@ -24,6 +24,11 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
 *ongoing at 21k steps*
+
+## Pre-Trained Models
+- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
+- Decoder 🚧
+- DALL-E 2 🚧
 ## Install
 ```bash
@@ -1078,6 +1083,8 @@ This library would not have gotten to this working state without the help of
 - [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 - [x] use pydantic for config driven training
 - [x] for both diffusion prior and decoder, all exponential moving averaged models need to be saved and restored as well (as well as the step number)
+- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
+- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1087,11 +1094,9 @@ This library would not have gotten to this working state without the help of
 - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
-- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
 - [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder unet or vqgan-vae training
 - [ ] decoder needs one day worth of refactor for tech debt
 - [ ] allow for unet to be able to condition non-cross attention style as well
-- [ ] for all model classes with hyperparameters that change the network architecture, make it a requirement that they expose a config property, and write a simple function that asserts that it restores the object correctly
 - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
 ## Citations

View File

@@ -6,9 +6,10 @@ For more complex configuration, we provide the option of using a configuration f
 The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
-**<ins>Unets</ins>:**
-Each member of this array defines a single unet that will be added to the decoder.
+**<ins>Unet</ins>:**
+This is the config for a single unet; unet configs are nested under the decoder config as the `unets` array.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
 | `dim` | Yes | N/A | The starting channels of the unet. |
@@ -22,6 +23,7 @@ Any parameter from the `Unet` constructor can also be given here.
 Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
+| `unets` | Yes | N/A | A list of unets, using the configuration above |
 | `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
 | `image_size` | Yes | N/A | Not used. Can be any number. |
 | `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |

View File

@@ -1,16 +1,16 @@
 {
-    "unets": [
-        {
-            "dim": 128,
-            "image_embed_dim": 768,
-            "cond_dim": 64,
-            "channels": 3,
-            "dim_mults": [1, 2, 4, 8],
-            "attn_dim_head": 32,
-            "attn_heads": 16
-        }
-    ],
     "decoder": {
+        "unets": [
+            {
+                "dim": 128,
+                "image_embed_dim": 768,
+                "cond_dim": 64,
+                "channels": 3,
+                "dim_mults": [1, 2, 4, 8],
+                "attn_dim_head": 32,
+                "attn_heads": 16
+            }
+        ],
         "image_sizes": [64],
         "channels": 3,
         "timesteps": 1000,

View File

@@ -1710,12 +1710,18 @@ class Decoder(BaseGaussianDiffusion):
         )
         self.unconditional = unconditional
 
-        assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-        assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+        # text conditioning
+
+        assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
+        self.condition_on_text_encodings = condition_on_text_encodings
+
+        # clip
+
         self.clip = None
         if exists(clip):
+            assert not unconditional, 'clip must not be given if doing unconditional image training'
             if isinstance(clip, CLIP):
                 clip = XClipAdapter(clip, **clip_adapter_overrides)
             elif isinstance(clip, CoCa):
@@ -1725,13 +1731,20 @@
             assert isinstance(clip, BaseClipAdapter)
             self.clip = clip
-            self.clip_image_size = clip.image_size
-            self.channels = clip.image_channels
-        else:
-            self.clip_image_size = image_size
-            self.channels = channels
-        self.condition_on_text_encodings = condition_on_text_encodings
+
+        # determine image size, with image_size and image_sizes taking precedence
+
+        if exists(image_size) or exists(image_sizes):
+            assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
+            image_size = default(image_size, lambda: image_sizes[-1])
+        elif exists(clip):
+            image_size = clip.image_size
+        else:
+            raise ValueError('either image_size, image_sizes, or clip must be given to decoder')
+
+        # channels
+
+        self.channels = channels
 
         # automatically take care of ensuring that first unet is unconditional
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet
@@ -1773,7 +1786,7 @@ class Decoder(BaseGaussianDiffusion):
         # unet image sizes
 
-        image_sizes = default(image_sizes, (self.clip_image_size,))
+        image_sizes = default(image_sizes, (image_size,))
         image_sizes = tuple(sorted(set(image_sizes)))
 
         assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1811,6 +1824,7 @@ class Decoder(BaseGaussianDiffusion):
         self.clip_x_start = clip_x_start
 
         # normalize and unnormalize image functions
 
         self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
         self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
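
The net effect of the reworked logic above: `image_size` / `image_sizes` now take precedence, and CLIP is only consulted as a fallback, so a clip-less decoder needs just one of the size arguments. A minimal construction sketch (the unet hyperparameters are illustrative, not prescribed by this change):

```python
from dalle2_pytorch import Unet, Decoder

# illustrative unet; any valid hyperparameters work
unet = Unet(dim = 128, image_embed_dim = 768, cond_dim = 64, dim_mults = (1, 2, 4, 8))

# no clip passed in: image_sizes alone now determines the resolution
decoder = Decoder(unet, image_sizes = (64,), channels = 3, timesteps = 1000)
```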

View File

@@ -3,15 +3,61 @@ from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
 
+from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
+
+# helper functions
+
 def exists(val):
     return val is not None
 
 def default(val, d):
     return val if exists(val) else d
 
+def ListOrTuple(inner_type):
+    return Union[List[inner_type], Tuple[inner_type]]
+
+# pydantic classes
+
+class DiffusionPriorNetworkConfig(BaseModel):
+    dim: int
+    depth: int
+    num_timesteps: int = None
+    num_time_embeds: int = 1
+    num_image_embeds: int = 1
+    num_text_embeds: int = 1
+    dim_head: int = 64
+    heads: int = 8
+    ff_mult: int = 4
+    norm_out: bool = True
+    attn_dropout: float = 0.
+    ff_dropout: float = 0.
+    final_proj: bool = True
+    normformer: bool = False
+    rotary_emb: bool = True
+
+class DiffusionPriorConfig(BaseModel):
+    # only clip-less diffusion prior config for now
+    net: DiffusionPriorNetworkConfig
+    image_embed_dim: int
+    image_size: int
+    image_channels: int = 3
+    timesteps: int = 1000
+    cond_drop_prob: float = 0.
+    loss_type: str = 'l2'
+    predict_x_start: bool = True
+    beta_schedule: str = 'cosine'
+
+    def create(self):
+        kwargs = self.dict()
+        diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
+        return DiffusionPrior(net = diffusion_prior_network, **kwargs)
+
+    class Config:
+        extra = "allow"
+
 class UnetConfig(BaseModel):
     dim: int
-    dim_mults: List[int]
+    dim_mults: ListOrTuple(int)
     image_embed_dim: int = None
     cond_dim: int = None
     channels: int = 3
@@ -22,13 +68,22 @@ class UnetConfig(BaseModel):
         extra = "allow"
 
 class DecoderConfig(BaseModel):
+    unets: ListOrTuple(UnetConfig)
     image_size: int = None
-    image_sizes: Union[List[int], Tuple[int]] = None
+    image_sizes: ListOrTuple(int) = None
     channels: int = 3
     timesteps: int = 1000
     loss_type: str = 'l2'
     beta_schedule: str = 'cosine'
     learned_variance: bool = True
+    image_cond_drop_prob: float = 0.1
+    text_cond_drop_prob: float = 0.5
+
+    def create(self):
+        decoder_kwargs = self.dict()
+        unet_configs = decoder_kwargs.pop('unets')
+        unets = [Unet(**config) for config in unet_configs]
+        return Decoder(unets, **decoder_kwargs)
 
     @validator('image_sizes')
     def check_image_sizes(cls, image_sizes, values):
@@ -64,23 +119,39 @@ class DecoderDataConfig(BaseModel):
     resample_train: bool = False
     preprocessing: Dict[str, Any] = {'ToTensor': True}
 
+    @property
+    def img_preproc(self):
+        def _get_transformation(transformation_name, **kwargs):
+            if transformation_name == "RandomResizedCrop":
+                return T.RandomResizedCrop(**kwargs)
+            elif transformation_name == "RandomHorizontalFlip":
+                return T.RandomHorizontalFlip()
+            elif transformation_name == "ToTensor":
+                return T.ToTensor()
+
+        transforms = []
+        for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
+            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
+            transforms.append(_get_transformation(transform_name, **transform_kwargs))
+        return T.Compose(transforms)
+
 class DecoderTrainConfig(BaseModel):
     epochs: int = 20
     lr: float = 1e-4
     wd: float = 0.01
     max_grad_norm: float = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6          # The number of example images to produce when sampling the train and test dataset
     device: str = 'cuda:0'
     epoch_samples: int = None         # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
     validation_samples: int = None    # Same as above but for validation.
     use_ema: bool = True
     ema_beta: float = 0.99
     amp: bool = False
     save_all: bool = False            # Whether to preserve all checkpoints
     save_latest: bool = True          # Whether to always save the latest checkpoint
     save_best: bool = True            # Whether to save the best checkpoint
-    unet_training_mask: List[bool] = None          # If None, use all unets
+    unet_training_mask: ListOrTuple(bool) = None   # If None, use all unets
 
 class DecoderEvaluateConfig(BaseModel):
     n_evaluation_samples: int = 1000
@@ -104,7 +175,6 @@ class DecoderLoadConfig(BaseModel):
     resume: bool = False              # If using wandb, whether to resume the run
 
 class TrainDecoderConfig(BaseModel):
-    unets: List[UnetConfig]
     decoder: DecoderConfig
     data: DecoderDataConfig
     train: DecoderTrainConfig
@@ -117,19 +187,3 @@
         with open(json_path) as f:
             config = json.load(f)
         return cls(**config)
-
-    @property
-    def img_preproc(self):
-        def _get_transformation(transformation_name, **kwargs):
-            if transformation_name == "RandomResizedCrop":
-                return T.RandomResizedCrop(**kwargs)
-            elif transformation_name == "RandomHorizontalFlip":
-                return T.RandomHorizontalFlip()
-            elif transformation_name == "ToTensor":
-                return T.ToTensor()
-
-        transforms = []
-        for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
-            transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
-            transforms.append(_get_transformation(transform_name, **transform_kwargs))
-        return T.Compose(transforms)
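
Taken together, model construction now hangs off the pydantic classes themselves. A usage sketch (the JSON path refers to the example config in the repo; the prior's field values are illustrative):

```python
from dalle2_pytorch.train_configs import TrainDecoderConfig, DiffusionPriorConfig

# decoder: load a JSON config, then let the config build the model and transforms
config = TrainDecoderConfig.from_json_path('train_decoder_config.example.json')
decoder = config.decoder.create()       # instantiates the unets and wraps them in a Decoder
img_preproc = config.data.img_preproc   # torchvision pipeline, now derived on the data config

# clip-less diffusion prior, driven purely by config
prior_config = DiffusionPriorConfig(
    net = dict(dim = 768, depth = 6),   # coerced into a DiffusionPriorNetworkConfig
    image_embed_dim = 768,
    image_size = 224
)
diffusion_prior = prior_config.create()
```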

View File

@@ -133,12 +133,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
         chunk_size_frac = chunk_size / batch_size
         yield chunk_size_frac, (chunked_args, chunked_kwargs)
 
-# print helpers
-
-def print_ribbon(s, symbol = '=', repeat = 40):
-    flank = symbol * repeat
-    return f'{flank} {s} {flank}'
-
 # saving and loading functions
 
 # for diffusion prior
@@ -196,7 +190,7 @@ class EMA(nn.Module):
         self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
 
         self.register_buffer('initted', torch.Tensor([False]))
-        self.register_buffer('step', torch.tensor([0.]))
+        self.register_buffer('step', torch.tensor([0]))
 
     def restore_ema_model_device(self):
         device = self.initted.device
@@ -292,9 +286,9 @@ class DiffusionPriorTrainer(nn.Module):
         self.max_grad_norm = max_grad_norm
 
-        self.register_buffer('step', torch.tensor([0.]))
+        self.register_buffer('step', torch.tensor([0]))
 
-    def save(self, path, overwrite = True):
+    def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
         path.parent.mkdir(parents = True, exist_ok = True)
@@ -303,7 +297,9 @@ class DiffusionPriorTrainer(nn.Module):
             scaler = self.scaler.state_dict(),
             optimizer = self.optimizer.state_dict(),
             model = self.diffusion_prior.state_dict(),
-            version = get_pkg_version()
+            version = get_pkg_version(),
+            step = self.step.item(),
+            **kwargs
         )
 
         if self.use_ema:
@@ -321,9 +317,10 @@ class DiffusionPriorTrainer(nn.Module):
             print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
 
         self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
+        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
 
         if only_model:
-            return
+            return loaded_obj
 
         self.scaler.load_state_dict(loaded_obj['scaler'])
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
@@ -332,6 +329,8 @@ class DiffusionPriorTrainer(nn.Module):
             assert 'ema' in loaded_obj
             self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
 
+        return loaded_obj
+
     def update(self):
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
@@ -453,18 +452,25 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('step', torch.tensor([0.]))
 
-    def save(self, path, overwrite = True):
+    def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
         path.parent.mkdir(parents = True, exist_ok = True)
 
         save_obj = dict(
-            scaler = self.scaler.state_dict(),
-            optimizer = self.optimizer.state_dict(),
             model = self.decoder.state_dict(),
-            version = get_pkg_version()
+            version = get_pkg_version(),
+            step = self.step.item(),
+            **kwargs
         )
 
+        for ind in range(0, self.num_unets):
+            scaler_key = f'scaler{ind}'
+            optimizer_key = f'optim{ind}'
+            scaler = getattr(self, scaler_key)
+            optimizer = getattr(self, optimizer_key)
+            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
+
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -480,17 +486,26 @@
             print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
 
         self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
+        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
 
         if only_model:
-            return
+            return loaded_obj
 
-        self.scaler.load_state_dict(loaded_obj['scaler'])
-        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        for ind in range(0, self.num_unets):
+            scaler_key = f'scaler{ind}'
+            optimizer_key = f'optim{ind}'
+            scaler = getattr(self, scaler_key)
+            optimizer = getattr(self, optimizer_key)
+            scaler.load_state_dict(loaded_obj[scaler_key])
+            optimizer.load_state_dict(loaded_obj[optimizer_key])
 
         if self.use_ema:
             assert 'ema' in loaded_obj
             self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
 
+        return loaded_obj
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
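
The reworked checkpoints now carry the per-unet scalers and optimizers, the step counter, and arbitrary extra fields, and `load` hands the checkpoint dict back to the caller. A sketch of the intended round trip (the path and extra fields are hypothetical):

```python
# given an already-constructed DecoderTrainer named `trainer`
trainer.save('./checkpoints/latest.pth', epoch = 3, best_loss = 0.12)  # extra kwargs are stored alongside

loaded_obj = trainer.load('./checkpoints/latest.pth')  # restores model, scalers, optimizers, step
print(loaded_obj['epoch'], loaded_obj['best_loss'])    # the additional fields come back out
```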

View File

@@ -1,5 +1,7 @@
 import time
 
+# time helpers
+
 class Timer:
     def __init__(self):
         self.reset()
@@ -9,3 +11,9 @@ class Timer:
     def elapsed(self):
         return time.time() - self.last_time
+
+# print helpers
+
+def print_ribbon(s, symbol = '=', repeat = 40):
+    flank = symbol * repeat
+    return f'{flank} {s} {flank}'
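
`print_ribbon` behaves the same from its new home in `dalle2_pytorch.utils`; for instance:

```python
from dalle2_pytorch.utils import print_ribbon

print(print_ribbon('Loaded Config', repeat = 10))
# prints: ========== Loaded Config ==========
```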

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.3',
+  version = '0.4.11',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

View File

@@ -1,9 +1,9 @@
 from dalle2_pytorch import Unet, Decoder
-from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
+from dalle2_pytorch.trainer import DecoderTrainer
 from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
 from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
 from dalle2_pytorch.train_configs import TrainDecoderConfig
-from dalle2_pytorch.utils import Timer
+from dalle2_pytorch.utils import Timer, print_ribbon
 import torchvision
 import torch
@@ -85,20 +85,6 @@
         "test_sampling": test_sampling_dataloader
     }
 
-def create_decoder(device, decoder_config, unets_config):
-    """Creates a sample decoder"""
-    unets = [Unet(**config.dict()) for config in unets_config]
-    decoder = Decoder(
-        unet=unets,
-        **decoder_config.dict()
-    )
-    decoder.to(device=device)
-    return decoder
-
 def get_dataset_keys(dataloader):
     """
     It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to do a process to recover it.
@@ -420,7 +406,7 @@ def initialize_training(config):
     dataloaders = create_dataloaders (
         available_shards=all_shards,
-        img_preproc = config.img_preproc,
+        img_preproc = config.data.img_preproc,
         train_prop = config.data.splits.train,
         val_prop = config.data.splits.val,
         test_prop = config.data.splits.test,
@@ -428,7 +414,7 @@ def initialize_training(config):
         **config.data.dict()
     )
 
-    decoder = create_decoder(device, config.decoder, config.unets)
+    decoder = config.decoder.create().to(device = device)
     num_parameters = sum(p.numel() for p in decoder.parameters())
 
     print(print_ribbon("Loaded Config", repeat=40))
     print(f"Number of parameters: {num_parameters}")

View File

@@ -9,10 +9,10 @@ from torch import nn
 
 from dalle2_pytorch.dataloaders import make_splits
 from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
-from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
+from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
 from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
-from dalle2_pytorch.utils import Timer
+from dalle2_pytorch.utils import Timer, print_ribbon
 from embedding_reader import EmbeddingReader