Compare commits

...

16 Commits

Author SHA1 Message Date
Phil Wang
b8af2210df make sure diffusion prior can be instantiated from pydantic class without clip 2022-05-26 08:47:30 -07:00
Phil Wang
f4fe6c570d allow for full customization of number of resnet blocks per down or upsampling layers in unet, as in imagen 2022-05-26 08:33:31 -07:00
Phil Wang
645e207441 credit assignment 2022-05-26 08:16:03 -07:00
Phil Wang
00743b3a0b update 2022-05-26 08:12:25 -07:00
Phil Wang
01589aff6a cite maxvit properly 2022-05-26 07:12:25 -07:00
Phil Wang
7ecfd76cc0 fix evaluation config splat in training decoder script 2022-05-26 07:11:31 -07:00
Phil Wang
6161b61c55 0.5.4 2022-05-25 09:32:17 -07:00
zion
1ed0f9d80b use deterministic optimizer params (#116) 2022-05-25 09:31:43 -07:00
Phil Wang
f326a95e26 0.5.3 2022-05-25 09:07:28 -07:00
zion
d7a0a2ce4b add more support for configuring prior (#113) 2022-05-25 09:06:50 -07:00
Phil Wang
f23fab7ef7 switch over to scale shift conditioning, as it seems like Imagen and Glide used it and it may be important 2022-05-24 21:46:12 -07:00
Phil Wang
857b9fbf1e allow for one to stop grouping out weight decayable parameters, to debug optimizer state dict problem 2022-05-24 21:42:32 -07:00
Phil Wang
8864fd0aa7 bring in the dynamic thresholding technique from the Imagen paper, which purportedly improves classifier free guidance for the cascading ddpm 2022-05-24 18:15:14 -07:00
Phil Wang
72bf159331 update 2022-05-24 08:25:40 -07:00
Phil Wang
e5e47cfecb link to aidan's test run 2022-05-23 12:41:46 -07:00
Phil Wang
fa533962bd just use an assert to make sure clip image channels is never different than the channels of the diffusion prior and decoder, if clip is given 2022-05-22 22:43:14 -07:00
8 changed files with 272 additions and 64 deletions

View File

@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
## Status
@@ -24,9 +24,11 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
*ongoing at 21k steps*
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder 🚧
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- DALL-E 2 🚧
## Install
@@ -1048,6 +1050,7 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
... and many others. Thank you! 🙏
@@ -1140,7 +1143,7 @@ This library would not have gotten to this working state without the help of
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022}
}
```
@@ -1195,4 +1198,12 @@ This library would not have gotten to this working state without the help of
}
```
```bibtex
@misc{Saharia2022,https://stability.ai/
title = {Imagen: unprecedented photorealism × deep level of language understanding},
author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
year = {2022}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -0,0 +1,70 @@
{
"prior": {
"clip": {
"make": "x-clip",
"model": "ViT-L/14",
"base_model_kwargs": {
"dim_text": 768,
"dim_image": 768,
"dim_latent": 768
}
},
"net": {
"dim": 768,
"depth": 12,
"num_timesteps": 1000,
"num_time_embeds": 1,
"num_image_embeds": 1,
"num_text_embeds": 1,
"dim_head": 64,
"heads": 12,
"ff_mult": 4,
"norm_out": true,
"attn_dropout": 0.0,
"ff_dropout": 0.0,
"final_proj": true,
"normformer": true,
"rotary_emb": true
},
"image_embed_dim": 768,
"image_size": 224,
"image_channels": 3,
"timesteps": 1000,
"cond_drop_prob": 0.1,
"loss_type": "l2",
"predict_x_start": true,
"beta_schedule": "cosine",
"condition_on_text_encodings": true
},
"data": {
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
"batch_size": 256,
"splits": {
"train": 0.9,
"val": 1e-7,
"test": 0.0999999
}
},
"train": {
"epochs": 1,
"lr": 1.1e-4,
"wd": 6.02e-2,
"max_grad_norm": 0.5,
"use_ema": true,
"amp": false,
"save_every": 10000
},
"load": {
"source": null,
"resume": false
},
"tracker": {
"tracker_type": "wandb",
"data_path": "./prior_checkpoints",
"wandb_entity": "laion",
"wandb_project": "diffusion-prior",
"verbose": true
}
}

View File

@@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
)
if exists(clip):
assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
@@ -1105,13 +1107,20 @@ class Block(nn.Module):
groups = 8
):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
nn.GroupNorm(groups, dim_out),
nn.SiLU()
)
def forward(self, x):
return self.block(x)
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.project(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
def __init__(
@@ -1130,7 +1139,7 @@ class ResnetBlock(nn.Module):
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_cond_dim, dim_out)
nn.Linear(time_cond_dim, dim_out * 2)
)
self.cross_attn = None
@@ -1150,11 +1159,14 @@ class ResnetBlock(nn.Module):
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, cond = None, time_emb = None):
h = self.block1(x)
scale_shift = None
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
h = rearrange(time_emb, 'b c -> b c 1 1') + h
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
if exists(self.cross_attn):
assert exists(cond)
@@ -1334,6 +1346,7 @@ class Unet(nn.Module):
init_dim = None,
init_conv_kernel_size = 7,
resnet_groups = 8,
num_resnet_blocks = 1,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1419,6 +1432,7 @@ class Unet(nn.Module):
# resnet block klass
resnet_groups = cast_tuple(resnet_groups, len(in_out))
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
assert len(resnet_groups) == len(in_out)
@@ -1434,7 +1448,7 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
@@ -1442,7 +1456,7 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
downsample_klass(dim_out) if not is_last else nn.Identity()
]))
@@ -1452,14 +1466,14 @@ class Unet(nn.Module):
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
Upsample(dim_in)
]))
@@ -1616,10 +1630,13 @@ class Unet(nn.Module):
hiddens = []
for block1, sparse_attn, block2, downsample in self.downs:
x = block1(x, c, t)
for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
x = init_block(x, c, t)
x = sparse_attn(x)
x = block2(x, c, t)
for resnet_block in resnet_blocks:
x = resnet_block(x, c, t)
hiddens.append(x)
x = downsample(x)
@@ -1630,11 +1647,14 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t)
for block1, sparse_attn, block2, upsample in self.ups:
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = block1(x, c, t)
x = init_block(x, c, t)
x = sparse_attn(x)
x = block2(x, c, t)
for resnet_block in resnet_blocks:
x = resnet_block(x, c, t)
x = upsample(x)
return self.final_conv(x)
@@ -1702,6 +1722,8 @@ class Decoder(BaseGaussianDiffusion):
vb_loss_weight = 0.001,
unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
use_dynamic_thres = False, # from the Imagen paper
dynamic_thres_percentile = 0.9
):
super().__init__(
beta_schedule = beta_schedule,
@@ -1721,6 +1743,7 @@ class Decoder(BaseGaussianDiffusion):
self.clip = None
if exists(clip):
assert not unconditional, 'clip must not be given if doing unconditional image training'
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides)
@@ -1823,6 +1846,11 @@ class Decoder(BaseGaussianDiffusion):
self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start
# dynamic thresholding settings, if clipping denoised during sampling
self.use_dynamic_thres = use_dynamic_thres
self.dynamic_thres_percentile = dynamic_thres_percentile
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
@@ -1865,7 +1893,21 @@ class Decoder(BaseGaussianDiffusion):
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
x_recon.clamp_(-1., 1.)
# s is the threshold amount
# static thresholding would just be s = 1
s = 1.
if self.use_dynamic_thres:
s = torch.quantile(
rearrange(x_recon, 'b ... -> b (...)').abs(),
self.dynamic_thres_percentile,
dim = -1
)
s.clamp_(min = 1.)
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)

View File

@@ -12,6 +12,7 @@ def get_optimizer(
betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,
**kwargs
):
if filter_by_requires_grad:
@@ -20,12 +21,12 @@ def get_optimizer(
if wd == 0:
return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
if group_wd_params:
wd_params, no_wd_params = separate_weight_decayable_params(params)
param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
params = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -3,7 +3,18 @@ from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
from x_clip import CLIP as XCLIP
from coca_pytorch import CoCa
from dalle2_pytorch.dalle2_pytorch import (
CoCaAdapter,
OpenAIClipAdapter,
Unet,
Decoder,
DiffusionPrior,
DiffusionPriorNetwork,
XClipAdapter,
)
# helper functions
@@ -16,7 +27,47 @@ def default(val, d):
def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]]
# pydantic classes
def SingularOrIterable(inner_type):
return Union[inner_type, ListOrTuple(inner_type)]
# general pydantic classes
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
actual_sum = sum([*fields.values()])
if actual_sum != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
return fields
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
# diffusion prior pydantic classes
class AdapterConfig(BaseModel):
make: str = "openai"
model: str = "ViT-L/14"
base_model_kwargs: Dict[str, Any] = None
def create(self):
if self.make == "openai":
return OpenAIClipAdapter(self.model)
elif self.make == "x-clip":
return XClipAdapter(XCLIP(**self.base_model_kwargs))
elif self.make == "coca":
return CoCaAdapter(CoCa(**self.base_model_kwargs))
else:
raise AttributeError("No adapter with that name is available.")
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
@@ -35,8 +86,12 @@ class DiffusionPriorNetworkConfig(BaseModel):
normformer: bool = False
rotary_emb: bool = True
def create(self):
kwargs = self.dict()
return DiffusionPriorNetwork(**kwargs)
class DiffusionPriorConfig(BaseModel):
# only clip-less diffusion prior config for now
clip: AdapterConfig = None
net: DiffusionPriorNetworkConfig
image_embed_dim: int
image_size: int
@@ -46,15 +101,59 @@ class DiffusionPriorConfig(BaseModel):
loss_type: str = 'l2'
predict_x_start: bool = True
beta_schedule: str = 'cosine'
def create(self):
kwargs = self.dict()
diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
return DiffusionPrior(net = diffusion_prior_network, **kwargs)
condition_on_text_encodings: bool = True
class Config:
extra = "allow"
def create(self):
kwargs = self.dict()
has_clip = exists(kwargs.pop('clip'))
kwargs.pop('net')
clip = None
if has_clip:
clip = self.clip.create()
diffusion_prior_network = self.net.create()
return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
class DiffusionPriorTrainConfig(BaseModel):
epochs: int = 1
lr: float = 1.1e-4
wd: float = 6.02e-2
max_grad_norm: float = 0.5
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_every: int = 10000 # what steps to save on
class DiffusionPriorDataConfig(BaseModel):
image_url: str # path to embeddings folder
meta_url: str # path to metadata (captions) for images
splits: TrainSplitConfig
batch_size: int = 64
class DiffusionPriorLoadConfig(BaseModel):
source: str = None
resume: bool = False
class TrainDiffusionPriorConfig(BaseModel):
prior: DiffusionPriorConfig
data: DiffusionPriorDataConfig
train: DiffusionPriorTrainConfig
load: DiffusionPriorLoadConfig
tracker: TrackerConfig
@classmethod
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
# decoder pydantic classes
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple(int)
@@ -94,17 +193,6 @@ class DecoderConfig(BaseModel):
class Config:
extra = "allow"
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
if sum([*fields.values()]) != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0')
return fields
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings
@@ -137,16 +225,16 @@ class DecoderDataConfig(BaseModel):
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: float = 1e-4
wd: float = 0.01
max_grad_norm: float = 0.5
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.99
ema_beta: float = 0.999
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
@@ -160,14 +248,6 @@ class DecoderEvaluateConfig(BaseModel):
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb

View File

@@ -254,6 +254,7 @@ class DiffusionPriorTrainer(nn.Module):
eps = 1e-6,
max_grad_norm = None,
amp = False,
group_wd_params = True,
**kwargs
):
super().__init__()
@@ -279,6 +280,7 @@ class DiffusionPriorTrainer(nn.Module):
lr = lr,
wd = wd,
eps = eps,
group_wd_params = group_wd_params,
**kwargs
)
@@ -410,6 +412,7 @@ class DecoderTrainer(nn.Module):
eps = 1e-8,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
**kwargs
):
super().__init__()
@@ -435,6 +438,7 @@ class DecoderTrainer(nn.Module):
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs
)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.4.11',
version = '0.5.6',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',

View File

@@ -347,7 +347,7 @@ def train(
# Compute evaluation metrics
if exists(evaluate_config):
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
tracker.log(evaluation, step=step, verbose=True)
# Generate sample images