Compare commits

...

11 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Phil Wang | f4fe6c570d | allow for full customization of number of resnet blocks per down or upsampling layers in unet, as in imagen | 2022-05-26 08:33:31 -07:00 |
| Phil Wang | 645e207441 | credit assignment | 2022-05-26 08:16:03 -07:00 |
| Phil Wang | 00743b3a0b | update | 2022-05-26 08:12:25 -07:00 |
| Phil Wang | 01589aff6a | cite maxvit properly | 2022-05-26 07:12:25 -07:00 |
| Phil Wang | 7ecfd76cc0 | fix evaluation config splat in training decoder script | 2022-05-26 07:11:31 -07:00 |
| Phil Wang | 6161b61c55 | 0.5.4 | 2022-05-25 09:32:17 -07:00 |
| zion | 1ed0f9d80b | use deterministic optimizer params (#116) | 2022-05-25 09:31:43 -07:00 |
| Phil Wang | f326a95e26 | 0.5.3 | 2022-05-25 09:07:28 -07:00 |
| zion | d7a0a2ce4b | add more support for configuring prior (#113) | 2022-05-25 09:06:50 -07:00 |
| Phil Wang | f23fab7ef7 | switch over to scale shift conditioning, as it seems like Imagen and Glide used it and it may be important | 2022-05-24 21:46:12 -07:00 |
| Phil Wang | 857b9fbf1e | allow for one to stop grouping out weight decayable parameters, to debug optimizer state dict problem | 2022-05-24 21:42:32 -07:00 |
8 changed files with 224 additions and 58 deletions

View File

@@ -24,6 +24,8 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
 *ongoing at 21k steps*
+- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
 ## Pre-Trained Models
 - LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
 - Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧

@@ -1048,6 +1050,7 @@ This library would not have gotten to this working state without the help of
 - <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
 - <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying critical bugs
 - <a href="https://github.com/crowsonkb">Katherine</a> for her advice
+- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
 ... and many others. Thank you! 🙏

@@ -1140,7 +1143,7 @@ This library would not have gotten to this working state without the help of
 ```bibtex
 @inproceedings{Tu2022MaxViTMV,
     title  = {MaxViT: Multi-Axis Vision Transformer},
-    author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
+    author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
     year   = {2022}
 }
 ```

@@ -1196,7 +1199,7 @@ This library would not have gotten to this working state without the help of
 ```
 ```bibtex
 @misc{Saharia2022,
     title  = {Imagen: unprecedented photorealism × deep level of language understanding},
     author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
     year   = {2022}

View File

@@ -0,0 +1,70 @@
{
  "prior": {
    "clip": {
      "make": "x-clip",
      "model": "ViT-L/14",
      "base_model_kwargs": {
        "dim_text": 768,
        "dim_image": 768,
        "dim_latent": 768
      }
    },
    "net": {
      "dim": 768,
      "depth": 12,
      "num_timesteps": 1000,
      "num_time_embeds": 1,
      "num_image_embeds": 1,
      "num_text_embeds": 1,
      "dim_head": 64,
      "heads": 12,
      "ff_mult": 4,
      "norm_out": true,
      "attn_dropout": 0.0,
      "ff_dropout": 0.0,
      "final_proj": true,
      "normformer": true,
      "rotary_emb": true
    },
    "image_embed_dim": 768,
    "image_size": 224,
    "image_channels": 3,
    "timesteps": 1000,
    "cond_drop_prob": 0.1,
    "loss_type": "l2",
    "predict_x_start": true,
    "beta_schedule": "cosine",
    "condition_on_text_encodings": true
  },
  "data": {
    "image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
    "text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
    "meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
    "batch_size": 256,
    "splits": {
      "train": 0.9,
      "val": 1e-7,
      "test": 0.0999999
    }
  },
  "train": {
    "epochs": 1,
    "lr": 1.1e-4,
    "wd": 6.02e-2,
    "max_grad_norm": 0.5,
    "use_ema": true,
    "amp": false,
    "save_every": 10000
  },
  "load": {
    "source": null,
    "resume": false
  },
  "tracker": {
    "tracker_type": "wandb",
    "data_path": "./prior_checkpoints",
    "wandb_entity": "laion",
    "wandb_project": "diffusion-prior",
    "verbose": true
  }
}
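The new `TrainDiffusionPriorConfig.from_json_path` classmethod (added in the train configs diff further down) is the consumer of a config like the one above. A minimal loading sketch, assuming the JSON is saved as `prior_config.json` (hypothetical filename):

```python
# Minimal sketch: parse the example JSON above into typed pydantic configs
# and build the prior from them. 'prior_config.json' is a hypothetical path.
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig

config = TrainDiffusionPriorConfig.from_json_path('prior_config.json')

# nested sections come back as typed sub-configs
assert config.data.batch_size == 256
assert config.train.lr == 1.1e-4

diffusion_prior = config.prior.create()  # builds the x-clip adapter, the prior network, then DiffusionPrior
```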

View File

@@ -1107,13 +1107,20 @@ class Block(nn.Module):
         groups = 8
     ):
         super().__init__()
-        self.block = nn.Sequential(
-            nn.Conv2d(dim, dim_out, 3, padding = 1),
-            nn.GroupNorm(groups, dim_out),
-            nn.SiLU()
-        )
+        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
+        self.norm = nn.GroupNorm(groups, dim_out)
+        self.act = nn.SiLU()
 
-    def forward(self, x):
-        return self.block(x)
+    def forward(self, x, scale_shift = None):
+        x = self.project(x)
+        x = self.norm(x)
+
+        if exists(scale_shift):
+            scale, shift = scale_shift
+            x = x * (scale + 1) + shift
+
+        x = self.act(x)
+        return x
 
 class ResnetBlock(nn.Module):
     def __init__(

@@ -1132,7 +1139,7 @@ class ResnetBlock(nn.Module):
         if exists(time_cond_dim):
             self.time_mlp = nn.Sequential(
                 nn.SiLU(),
-                nn.Linear(time_cond_dim, dim_out)
+                nn.Linear(time_cond_dim, dim_out * 2)
             )
 
         self.cross_attn = None

@@ -1152,11 +1159,14 @@ class ResnetBlock(nn.Module):
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, cond = None, time_emb = None):
-        h = self.block1(x)
 
+        scale_shift = None
         if exists(self.time_mlp) and exists(time_emb):
             time_emb = self.time_mlp(time_emb)
-            h = rearrange(time_emb, 'b c -> b c 1 1') + h
+            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
+            scale_shift = time_emb.chunk(2, dim = 1)
+
+        h = self.block1(x, scale_shift = scale_shift)
 
         if exists(self.cross_attn):
             assert exists(cond)
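The three hunks above switch `Block`/`ResnetBlock` from additive time conditioning to scale-shift conditioning: the time MLP now projects to `dim_out * 2` channels, which are chunked into a per-channel scale and shift applied between the group norm and the activation. A self-contained sketch of just that arithmetic (tensor sizes are illustrative):

```python
# Scale-shift (FiLM-style) conditioning in isolation; sizes are illustrative.
import torch
from torch import nn
from einops import rearrange

batch, dim_out, time_cond_dim = 4, 64, 128

time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_cond_dim, dim_out * 2))

x = torch.randn(batch, dim_out, 32, 32)       # feature map after GroupNorm
time_emb = torch.randn(batch, time_cond_dim)  # time embedding

t = rearrange(time_mlp(time_emb), 'b c -> b c 1 1')
scale, shift = t.chunk(2, dim = 1)            # two (batch, dim_out, 1, 1) halves
x = x * (scale + 1) + shift                   # conditioned before the SiLU
```

The `+ 1` keeps the block close to an identity map while the projected scale is still near zero, which tends to stabilize early training.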
@@ -1336,6 +1346,7 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        num_resnet_blocks = 1,
         init_cross_embed_kernel_sizes = (3, 7, 15),
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),

@@ -1421,6 +1432,7 @@ class Unet(nn.Module):
         # resnet block klass
 
         resnet_groups = cast_tuple(resnet_groups, len(in_out))
+        num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
 
         assert len(resnet_groups) == len(in_out)

@@ -1436,7 +1448,7 @@ class Unet(nn.Module):
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
             is_first = ind == 0
             is_last = ind >= (num_resolutions - 1)
             layer_cond_dim = cond_dim if not is_first else None

@@ -1444,7 +1456,7 @@ class Unet(nn.Module):
             self.downs.append(nn.ModuleList([
                 ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 downsample_klass(dim_out) if not is_last else nn.Identity()
             ]))

@@ -1454,14 +1466,14 @@ class Unet(nn.Module):
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
 
-        for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None
 
             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 Upsample(dim_in)
             ]))

@@ -1618,10 +1630,13 @@ class Unet(nn.Module):
         hiddens = []
 
-        for block1, sparse_attn, block2, downsample in self.downs:
-            x = block1(x, c, t)
+        for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
+            x = init_block(x, c, t)
             x = sparse_attn(x)
-            x = block2(x, c, t)
+
+            for resnet_block in resnet_blocks:
+                x = resnet_block(x, c, t)
+
             hiddens.append(x)
             x = downsample(x)

@@ -1632,11 +1647,14 @@ class Unet(nn.Module):
         x = self.mid_block2(x, mid_c, t)
 
-        for block1, sparse_attn, block2, upsample in self.ups:
+        for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = block1(x, c, t)
+            x = init_block(x, c, t)
             x = sparse_attn(x)
-            x = block2(x, c, t)
+
+            for resnet_block in resnet_blocks:
+                x = resnet_block(x, c, t)
+
             x = upsample(x)
 
         return self.final_conv(x)
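With `num_resnet_blocks` cast to a per-resolution tuple, each down/upsampling stage now runs one init `ResnetBlock` followed by `layer_num_resnet_blocks` additional blocks, so depth is configurable per stage as the top commit message describes. A hedged usage sketch; every value below is illustrative rather than a recommended configuration:

```python
# Usage sketch for the new num_resnet_blocks argument (illustrative values).
# An int applies one count everywhere; a tuple sets the count per resolution
# and must match len(dim_mults).
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8)  # heavier at lower resolutions, as in Imagen
)
```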

View File

@@ -12,6 +12,7 @@ def get_optimizer(
     betas = (0.9, 0.999),
     eps = 1e-8,
     filter_by_requires_grad = False,
+    group_wd_params = True,
     **kwargs
 ):
     if filter_by_requires_grad:

@@ -20,12 +21,12 @@ def get_optimizer(
     if wd == 0:
         return Adam(params, lr = lr, betas = betas, eps = eps)
 
-    params = set(params)
-    wd_params, no_wd_params = separate_weight_decayable_params(params)
+    if group_wd_params:
+        wd_params, no_wd_params = separate_weight_decayable_params(params)
 
-    param_groups = [
-        {'params': list(wd_params)},
-        {'params': list(no_wd_params), 'weight_decay': 0},
-    ]
+        params = [
+            {'params': list(wd_params)},
+            {'params': list(no_wd_params), 'weight_decay': 0},
+        ]
 
-    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
+    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
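Two changes land here: the nondeterministic `set(params)` is removed (set iteration order varies between runs, which made optimizer state dicts hard to reproduce, per #116), and the weight-decay grouping becomes optional. A small sketch, assuming this version of the package is installed:

```python
# Sketch: grouping splits parameters into weight-decayable and non-decayable
# groups (the latter typically biases and norm parameters); passing
# group_wd_params = False skips the split, which helps when debugging
# optimizer state-dict mismatches (commit 857b9fbf1e).
from torch import nn
from dalle2_pytorch.optimizer import get_optimizer

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

opt_grouped = get_optimizer(list(model.parameters()), lr = 3e-4, wd = 1e-2)
opt_flat = get_optimizer(list(model.parameters()), lr = 3e-4, wd = 1e-2, group_wd_params = False)

print(len(opt_grouped.param_groups))  # 2 (decayable vs no-decay)
print(len(opt_flat.param_groups))     # 1
```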

View File

@@ -3,7 +3,18 @@ from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
 
-from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
+from x_clip import CLIP as XCLIP
+from coca_pytorch import CoCa
+
+from dalle2_pytorch.dalle2_pytorch import (
+    CoCaAdapter,
+    OpenAIClipAdapter,
+    Unet,
+    Decoder,
+    DiffusionPrior,
+    DiffusionPriorNetwork,
+    XClipAdapter,
+)
 
 # helper functions

@@ -16,7 +27,44 @@ def default(val, d):
 def ListOrTuple(inner_type):
     return Union[List[inner_type], Tuple[inner_type]]
 
-# pydantic classes
+# general pydantic classes
+
+class TrainSplitConfig(BaseModel):
+    train: float = 0.75
+    val: float = 0.15
+    test: float = 0.1
+
+    @root_validator
+    def validate_all(cls, fields):
+        actual_sum = sum([*fields.values()])
+        if actual_sum != 1.:
+            raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
+        return fields
+
+class TrackerConfig(BaseModel):
+    tracker_type: str = 'console'  # Decoder currently supports console and wandb
+    data_path: str = './models'    # The path where files will be saved locally
+    init_config: Dict[str, Any] = None
+    wandb_entity: str = ''         # Only needs to be set if tracker_type is wandb
+    wandb_project: str = ''
+    verbose: bool = False          # Whether to print console logging for non-console trackers
+
+# diffusion prior pydantic classes
+
+class AdapterConfig(BaseModel):
+    make: str = "openai"
+    model: str = "ViT-L/14"
+    base_model_kwargs: Dict[str, Any] = None
+
+    def create(self):
+        if self.make == "openai":
+            return OpenAIClipAdapter(self.model)
+        elif self.make == "x-clip":
+            return XClipAdapter(XCLIP(**self.base_model_kwargs))
+        elif self.make == "coca":
+            return CoCaAdapter(CoCa(**self.base_model_kwargs))
+        else:
+            raise AttributeError("No adapter with that name is available.")
 
 class DiffusionPriorNetworkConfig(BaseModel):
     dim: int

@@ -35,8 +83,12 @@ class DiffusionPriorNetworkConfig(BaseModel):
     normformer: bool = False
     rotary_emb: bool = True
 
+    def create(self):
+        kwargs = self.dict()
+        return DiffusionPriorNetwork(**kwargs)
+
 class DiffusionPriorConfig(BaseModel):
-    # only clip-less diffusion prior config for now
+    clip: AdapterConfig
     net: DiffusionPriorNetworkConfig
     image_embed_dim: int
     image_size: int

@@ -46,15 +98,52 @@ class DiffusionPriorConfig(BaseModel):
     loss_type: str = 'l2'
     predict_x_start: bool = True
     beta_schedule: str = 'cosine'
+    condition_on_text_encodings: bool = True
 
-    def create(self):
-        kwargs = self.dict()
-        diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
-        return DiffusionPrior(net = diffusion_prior_network, **kwargs)
-
     class Config:
         extra = "allow"
 
+    def create(self):
+        kwargs = self.dict()
+        clip = AdapterConfig(**kwargs.pop('clip')).create()
+        diffusion_prior_network = DiffusionPriorNetworkConfig(**kwargs.pop('net')).create()
+        return DiffusionPrior(net = diffusion_prior_network, clip=clip, **kwargs)
+
+class DiffusionPriorTrainConfig(BaseModel):
+    epochs: int = 1
+    lr: float = 1.1e-4
+    wd: float = 6.02e-2
+    max_grad_norm: float = 0.5
+    use_ema: bool = True
+    ema_beta: float = 0.99
+    amp: bool = False
+    save_every: int = 10000  # what steps to save on
+
+class DiffusionPriorDataConfig(BaseModel):
+    image_url: str  # path to embeddings folder
+    meta_url: str   # path to metadata (captions) for images
+    splits: TrainSplitConfig
+    batch_size: int = 64
+
+class DiffusionPriorLoadConfig(BaseModel):
+    source: str = None
+    resume: bool = False
+
+class TrainDiffusionPriorConfig(BaseModel):
+    prior: DiffusionPriorConfig
+    data: DiffusionPriorDataConfig
+    train: DiffusionPriorTrainConfig
+    load: DiffusionPriorLoadConfig
+    tracker: TrackerConfig
+
+    @classmethod
+    def from_json_path(cls, json_path):
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls(**config)
+
+# decoder pydantic classes
 
 class UnetConfig(BaseModel):
     dim: int
     dim_mults: ListOrTuple(int)

@@ -94,17 +183,6 @@ class DecoderConfig(BaseModel):
     class Config:
         extra = "allow"
 
-class TrainSplitConfig(BaseModel):
-    train: float = 0.75
-    val: float = 0.15
-    test: float = 0.1
-
-    @root_validator
-    def validate_all(cls, fields):
-        if sum([*fields.values()]) != 1.:
-            raise ValueError(f'{fields.keys()} must sum to 1.0')
-        return fields
-
 class DecoderDataConfig(BaseModel):
     webdataset_base_url: str  # path to a webdataset with jpg images
     embeddings_url: str       # path to .npy files with embeddings

@@ -160,14 +238,6 @@ class DecoderEvaluateConfig(BaseModel):
     KID: Dict[str, Any] = None
     LPIPS: Dict[str, Any] = None
 
-class TrackerConfig(BaseModel):
-    tracker_type: str = 'console'  # Decoder currently supports console and wandb
-    data_path: str = './models'    # The path where files will be saved locally
-    init_config: Dict[str, Any] = None
-    wandb_entity: str = ''         # Only needs to be set if tracker_type is wandb
-    wandb_project: str = ''
-    verbose: bool = False          # Whether to print console logging for non-console trackers
-
 class DecoderLoadConfig(BaseModel):
     source: str = None  # Supports file and wandb
     run_path: str = ''  # Used only if source is wandb
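The `create()` methods turn each pydantic config into its module, so a whole `DiffusionPrior` can now be built from plain dicts, CLIP adapter included. A sketch of the factory chain; the keyword values are illustrative, and the `"openai"` make will download the corresponding CLIP checkpoint:

```python
# Sketch of the config -> model factory chain added above (illustrative values,
# not a validated configuration).
from dalle2_pytorch.train_configs import DiffusionPriorConfig

prior_config = DiffusionPriorConfig(
    clip = {'make': 'openai', 'model': 'ViT-L/14'},  # coerced to AdapterConfig; 'x-clip' and 'coca' also dispatch here
    net = {'dim': 768, 'depth': 12, 'dim_head': 64, 'heads': 12, 'num_timesteps': 1000},
    image_embed_dim = 768,
    image_size = 224,
    image_channels = 3,
    timesteps = 1000
)

diffusion_prior = prior_config.create()  # adapter -> prior network -> DiffusionPrior
```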

View File

@@ -254,6 +254,7 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         amp = False,
+        group_wd_params = True,
         **kwargs
     ):
         super().__init__()

@@ -279,6 +280,7 @@ class DiffusionPriorTrainer(nn.Module):
             lr = lr,
             wd = wd,
             eps = eps,
+            group_wd_params = group_wd_params,
             **kwargs
         )

@@ -410,6 +412,7 @@ class DecoderTrainer(nn.Module):
         eps = 1e-8,
         max_grad_norm = 0.5,
         amp = False,
+        group_wd_params = True,
         **kwargs
     ):
         super().__init__()

@@ -435,6 +438,7 @@ class DecoderTrainer(nn.Module):
             lr = unet_lr,
             wd = unet_wd,
             eps = unet_eps,
+            group_wd_params = group_wd_params,
             **kwargs
         )
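Both trainers now accept the flag and pass it straight through to `get_optimizer`. A usage sketch on the prior trainer (hyperparameters illustrative; `diffusion_prior` stands in for a real `DiffusionPrior` instance, e.g. from a config's `.create()`):

```python
# Sketch: group_wd_params threads through the trainer into get_optimizer.
from dalle2_pytorch.trainer import DiffusionPriorTrainer

trainer = DiffusionPriorTrainer(
    diffusion_prior,          # a DiffusionPrior instance (placeholder here)
    lr = 1.1e-4,
    wd = 6.02e-2,
    max_grad_norm = 0.5,
    group_wd_params = False   # single flat param list, for state-dict debugging
)
```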

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.5.0',
+  version = '0.5.5',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

View File

@@ -347,7 +347,7 @@ def train(
     # Compute evaluation metrics
     if exists(evaluate_config):
         print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-        evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
+        evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
         tracker.log(evaluation, step=step, verbose=True)
 
     # Generate sample images
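The one-line fix above works because a pydantic `BaseModel` is not a mapping, so `**evaluate_config` raises a `TypeError`; `.dict()` converts it first. A tiny reproduction, with a hypothetical field name:

```python
# Why the fix works: pydantic (v1) models don't support ** unpacking directly.
from pydantic import BaseModel

class EvaluateConfig(BaseModel):
    n_evaluation_samples: int = 16  # hypothetical field

def evaluate_trainer_stub(n_evaluation_samples = 16, **kwargs):
    return n_evaluation_samples

cfg = EvaluateConfig()

print(evaluate_trainer_stub(**cfg.dict()))  # 16
# evaluate_trainer_stub(**cfg)  # TypeError: argument after ** must be a mapping
```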