mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc50c6b34e | ||
|
|
2b1fd1ad2e | ||
|
|
82a2ef37d9 |
@@ -24,6 +24,11 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
|
||||
|
||||
*ongoing at 21k steps*
|
||||
|
||||
## Pre-Trained Models
|
||||
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
|
||||
- Decoder 🚧
|
||||
- DALL-E 2 🚧
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -1079,6 +1084,7 @@ This library would not have gotten to this working state without the help of
|
||||
- [x] use pydantic for config drive training
|
||||
- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
|
||||
- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
|
||||
@@ -3,7 +3,7 @@ from torchvision import transforms as T
|
||||
from pydantic import BaseModel, validator, root_validator
|
||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
|
||||
|
||||
# helper functions
|
||||
|
||||
@@ -18,6 +18,43 @@ def ListOrTuple(inner_type):
|
||||
|
||||
# pydantic classes
|
||||
|
||||
class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim: int
|
||||
depth: int
|
||||
num_timesteps: int = None
|
||||
num_time_embeds: int = 1
|
||||
num_image_embeds: int = 1
|
||||
num_text_embeds: int = 1
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
ff_mult: int = 4
|
||||
norm_out: bool = True
|
||||
attn_dropout: float = 0.
|
||||
ff_dropout: float = 0.
|
||||
final_proj: bool = True
|
||||
normformer: bool = False
|
||||
rotary_emb: bool = True
|
||||
|
||||
class DiffusionPriorConfig(BaseModel):
|
||||
# only clip-less diffusion prior config for now
|
||||
net: DiffusionPriorNetworkConfig
|
||||
image_embed_dim: int
|
||||
image_size: int
|
||||
image_channels: int = 3
|
||||
timesteps: int = 1000
|
||||
cond_drop_prob: float = 0.
|
||||
loss_type: str = 'l2'
|
||||
predict_x_start: bool = True
|
||||
beta_schedule: str = 'cosine'
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
|
||||
return DiffusionPrior(net = diffusion_prior_network, **kwargs)
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
class UnetConfig(BaseModel):
|
||||
dim: int
|
||||
dim_mults: ListOrTuple(int)
|
||||
@@ -39,6 +76,8 @@ class DecoderConfig(BaseModel):
|
||||
loss_type: str = 'l2'
|
||||
beta_schedule: str = 'cosine'
|
||||
learned_variance: bool = True
|
||||
image_cond_drop_prob: float = 0.1
|
||||
text_cond_drop_prob: float = 0.5
|
||||
|
||||
def create(self):
|
||||
decoder_kwargs = self.dict()
|
||||
|
||||
Reference in New Issue
Block a user