mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 22:14:20 +01:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e85e736f3 | ||
|
|
f5760bdb92 | ||
|
|
c453f468b1 | ||
|
|
98f0c17759 | ||
|
|
a5b9fd6ca8 | ||
|
|
4b994601ae | ||
|
|
fddf66e91e |
@@ -1072,7 +1072,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
|
- [ ] speed up inference, read up on papers (ddim or diffusion-gan, etc)
|
||||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||||
- [ ] build infilling
|
- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from tqdm import tqdm
|
from tqdm.auto import tqdm
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
@@ -1766,14 +1766,13 @@ class Decoder(nn.Module):
|
|||||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||||
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
|
||||||
clip_denoised = True,
|
clip_denoised = True,
|
||||||
clip_x_start = True,
|
clip_x_start = True,
|
||||||
clip_adapter_overrides = dict(),
|
clip_adapter_overrides = dict(),
|
||||||
learned_variance = True,
|
learned_variance = True,
|
||||||
learned_variance_constrain_frac = False,
|
learned_variance_constrain_frac = False,
|
||||||
vb_loss_weight = 0.001,
|
vb_loss_weight = 0.001,
|
||||||
unconditional = False,
|
unconditional = False, # set to True for generating images without conditioning
|
||||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||||
use_dynamic_thres = False, # from the Imagen paper
|
use_dynamic_thres = False, # from the Imagen paper
|
||||||
dynamic_thres_percentile = 0.9,
|
dynamic_thres_percentile = 0.9,
|
||||||
@@ -1852,8 +1851,8 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
one_unet = one_unet.cast_model_parameters(
|
one_unet = one_unet.cast_model_parameters(
|
||||||
lowres_cond = not is_first,
|
lowres_cond = not is_first,
|
||||||
cond_on_image_embeds = is_first and not unconditional,
|
cond_on_image_embeds = not unconditional and is_first,
|
||||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
cond_on_text_encodings = not unconditional and (is_first or one_unet.cond_on_text_encodings),
|
||||||
channels = unet_channels,
|
channels = unet_channels,
|
||||||
channels_out = unet_channels_out
|
channels_out = unet_channels_out
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ def get_example_file(fs, path, file_format):
|
|||||||
"""
|
"""
|
||||||
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
|
return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
|
||||||
|
|
||||||
def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handlers.reraise_exception):
|
def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
|
||||||
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
|
"""Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
|
||||||
previous_tar_url = None
|
previous_tar_url = None
|
||||||
current_embeddings = None
|
current_embeddings = None
|
||||||
@@ -56,7 +56,7 @@ def embedding_inserter(samples, embeddings_url, index_width, handler=wds.handler
|
|||||||
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
|
# We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop
|
||||||
if torch.count_nonzero(embedding) == 0:
|
if torch.count_nonzero(embedding) == 0:
|
||||||
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
|
raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
|
||||||
sample["npy"] = embedding
|
sample[sample_key] = embedding
|
||||||
yield sample
|
yield sample
|
||||||
except Exception as exn: # From wds implementation
|
except Exception as exn: # From wds implementation
|
||||||
if handler(exn):
|
if handler(exn):
|
||||||
@@ -84,18 +84,20 @@ def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.re
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
|
skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
|
||||||
|
|
||||||
def verify_keys(samples, handler=wds.handlers.reraise_exception):
|
def join_embeddings(samples, handler=wds.handlers.reraise_exception):
|
||||||
"""
|
"""
|
||||||
Requires that both the image and embedding are present in the sample
|
Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }
|
||||||
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
|
either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist
|
||||||
"""
|
"""
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
try:
|
try:
|
||||||
assert "jpg" in sample, f"Sample {sample['__key__']} missing image"
|
sample['emb'] = {}
|
||||||
assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?"
|
if 'text_emb' in sample:
|
||||||
|
sample['emb']['text'] = sample['text_emb']
|
||||||
|
if 'img_emb' in sample:
|
||||||
|
sample['emb']['img'] = sample['img_emb']
|
||||||
yield sample
|
yield sample
|
||||||
except Exception as exn: # From wds implementation
|
except Exception as exn: # From wds implementation
|
||||||
if handler(exn):
|
if handler(exn):
|
||||||
@@ -103,6 +105,23 @@ def verify_keys(samples, handler=wds.handlers.reraise_exception):
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
|
||||||
|
"""
|
||||||
|
Requires that both the image and embedding are present in the sample
|
||||||
|
This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.
|
||||||
|
"""
|
||||||
|
for sample in samples:
|
||||||
|
try:
|
||||||
|
for key in required_keys:
|
||||||
|
assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
|
||||||
|
yield sample
|
||||||
|
except Exception as exn: # From wds implementation
|
||||||
|
if handler(exn):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
key_verifier = wds.filters.pipelinefilter(verify_keys)
|
||||||
|
|
||||||
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
||||||
"""
|
"""
|
||||||
A fluid interface wrapper for DataPipline that returns image embedding pairs
|
A fluid interface wrapper for DataPipline that returns image embedding pairs
|
||||||
@@ -112,7 +131,8 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
urls,
|
urls,
|
||||||
embedding_folder_url=None,
|
img_embedding_folder_url=None,
|
||||||
|
text_embedding_folder_url=None,
|
||||||
index_width=None,
|
index_width=None,
|
||||||
img_preproc=None,
|
img_preproc=None,
|
||||||
extra_keys=[],
|
extra_keys=[],
|
||||||
@@ -136,7 +156,12 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
keys = ["jpg", "npy"] + extra_keys
|
keys = ["jpg", "emb"] + extra_keys
|
||||||
|
# if img_embedding_folder_url is not None:
|
||||||
|
# keys.append("img_emb")
|
||||||
|
# if text_embedding_folder_url is not None:
|
||||||
|
# keys.append("text_emb")
|
||||||
|
# keys.extend(extra_keys)
|
||||||
self.key_map = {key: i for i, key in enumerate(keys)}
|
self.key_map = {key: i for i, key in enumerate(keys)}
|
||||||
self.resampling = resample
|
self.resampling = resample
|
||||||
self.img_preproc = img_preproc
|
self.img_preproc = img_preproc
|
||||||
@@ -145,7 +170,7 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
|||||||
# Then this has an s3 link for the webdataset and we need extra packages
|
# Then this has an s3 link for the webdataset and we need extra packages
|
||||||
if shutil.which("s3cmd") is None:
|
if shutil.which("s3cmd") is None:
|
||||||
raise RuntimeError("s3cmd is required for s3 webdataset")
|
raise RuntimeError("s3cmd is required for s3 webdataset")
|
||||||
if "s3:" in embedding_folder_url:
|
if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
|
||||||
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
|
# Then the embeddings are being loaded from s3 and fsspec requires s3fs
|
||||||
try:
|
try:
|
||||||
import s3fs
|
import s3fs
|
||||||
@@ -160,17 +185,24 @@ class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
|
|||||||
if shuffle_shards:
|
if shuffle_shards:
|
||||||
self.append(wds.filters.shuffle(1000))
|
self.append(wds.filters.shuffle(1000))
|
||||||
|
|
||||||
if embedding_folder_url is not None:
|
if img_embedding_folder_url is not None:
|
||||||
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
|
# There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.
|
||||||
self.append(skip_unassociated_shards(embeddings_url=embedding_folder_url, handler=handler))
|
self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
|
||||||
|
if text_embedding_folder_url is not None:
|
||||||
|
self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
|
||||||
|
|
||||||
self.append(wds.tarfile_to_samples(handler=handler))
|
self.append(wds.tarfile_to_samples(handler=handler))
|
||||||
self.append(wds.decode("pilrgb", handler=handler))
|
self.append(wds.decode("pilrgb", handler=handler))
|
||||||
if embedding_folder_url is not None:
|
if img_embedding_folder_url is not None:
|
||||||
# Then we are loading embeddings for a remote source
|
# Then we are loading image embeddings for a remote source
|
||||||
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
||||||
self.append(insert_embedding(embeddings_url=embedding_folder_url, index_width=index_width, handler=handler))
|
self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
|
||||||
self.append(verify_keys)
|
if text_embedding_folder_url is not None:
|
||||||
|
# Then we are loading image embeddings for a remote source
|
||||||
|
assert index_width is not None, "Reading embeddings separately requires index width length to be given"
|
||||||
|
self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
|
||||||
|
self.append(join_embeddings)
|
||||||
|
self.append(key_verifier(required_keys=keys, handler=handler))
|
||||||
# Apply preprocessing
|
# Apply preprocessing
|
||||||
self.append(wds.map(self.preproc))
|
self.append(wds.map(self.preproc))
|
||||||
self.append(wds.to_tuple(*keys))
|
self.append(wds.to_tuple(*keys))
|
||||||
@@ -185,7 +217,8 @@ def create_image_embedding_dataloader(
|
|||||||
tar_url,
|
tar_url,
|
||||||
num_workers,
|
num_workers,
|
||||||
batch_size,
|
batch_size,
|
||||||
embeddings_url=None,
|
img_embeddings_url=None,
|
||||||
|
text_embeddings_url=None,
|
||||||
index_width=None,
|
index_width=None,
|
||||||
shuffle_num = None,
|
shuffle_num = None,
|
||||||
shuffle_shards = True,
|
shuffle_shards = True,
|
||||||
@@ -211,7 +244,8 @@ def create_image_embedding_dataloader(
|
|||||||
"""
|
"""
|
||||||
ds = ImageEmbeddingDataset(
|
ds = ImageEmbeddingDataset(
|
||||||
tar_url,
|
tar_url,
|
||||||
embeddings_url,
|
img_embedding_folder_url=img_embeddings_url,
|
||||||
|
text_embedding_folder_url=text_embeddings_url,
|
||||||
index_width=index_width,
|
index_width=index_width,
|
||||||
shuffle_shards=shuffle_shards,
|
shuffle_shards=shuffle_shards,
|
||||||
resample=resample_shards,
|
resample=resample_shards,
|
||||||
@@ -228,4 +262,4 @@ def create_image_embedding_dataloader(
|
|||||||
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
|
prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
shuffle=False
|
shuffle=False
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from dalle2_pytorch.dalle2_pytorch import (
|
|||||||
Decoder,
|
Decoder,
|
||||||
DiffusionPrior,
|
DiffusionPrior,
|
||||||
DiffusionPriorNetwork,
|
DiffusionPriorNetwork,
|
||||||
XClipAdapter,
|
XClipAdapter
|
||||||
)
|
)
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
@@ -170,6 +170,8 @@ class DecoderConfig(BaseModel):
|
|||||||
unets: ListOrTuple(UnetConfig)
|
unets: ListOrTuple(UnetConfig)
|
||||||
image_size: int = None
|
image_size: int = None
|
||||||
image_sizes: ListOrTuple(int) = None
|
image_sizes: ListOrTuple(int) = None
|
||||||
|
condition_on_text_encodings: bool = False
|
||||||
|
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||||
channels: int = 3
|
channels: int = 3
|
||||||
timesteps: int = 1000
|
timesteps: int = 1000
|
||||||
loss_type: str = 'l2'
|
loss_type: str = 'l2'
|
||||||
@@ -180,9 +182,16 @@ class DecoderConfig(BaseModel):
|
|||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
decoder_kwargs = self.dict()
|
decoder_kwargs = self.dict()
|
||||||
|
|
||||||
unet_configs = decoder_kwargs.pop('unets')
|
unet_configs = decoder_kwargs.pop('unets')
|
||||||
unets = [Unet(**config) for config in unet_configs]
|
unets = [Unet(**config) for config in unet_configs]
|
||||||
return Decoder(unets, **decoder_kwargs)
|
|
||||||
|
has_clip = exists(decoder_kwargs.pop('clip'))
|
||||||
|
clip = None
|
||||||
|
if has_clip:
|
||||||
|
clip = self.clip.create()
|
||||||
|
|
||||||
|
return Decoder(unets, clip=clip, **decoder_kwargs)
|
||||||
|
|
||||||
@validator('image_sizes')
|
@validator('image_sizes')
|
||||||
def check_image_sizes(cls, image_sizes, values):
|
def check_image_sizes(cls, image_sizes, values):
|
||||||
@@ -194,8 +203,9 @@ class DecoderConfig(BaseModel):
|
|||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|
||||||
class DecoderDataConfig(BaseModel):
|
class DecoderDataConfig(BaseModel):
|
||||||
webdataset_base_url: str # path to a webdataset with jpg images
|
webdataset_base_url: str # path to a webdataset with jpg images
|
||||||
embeddings_url: str # path to .npy files with embeddings
|
img_embeddings_url: Optional[str] # path to .npy files with embeddings
|
||||||
|
text_embeddings_url: Optional[str] # path to .npy files with embeddings
|
||||||
num_workers: int = 4
|
num_workers: int = 4
|
||||||
batch_size: int = 64
|
batch_size: int = 64
|
||||||
start_shard: int = 0
|
start_shard: int = 0
|
||||||
@@ -268,3 +278,26 @@ class TrainDecoderConfig(BaseModel):
|
|||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
return cls(**config)
|
return cls(**config)
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def check_has_embeddings(cls, values):
|
||||||
|
# Makes sure that enough information is provided to get the embeddings specified for training
|
||||||
|
data_config, decoder_config = values.get('data'), values.get('decoder')
|
||||||
|
if data_config is None or decoder_config is None:
|
||||||
|
# Then something else errored and we should just pass through
|
||||||
|
return values
|
||||||
|
using_text_embeddings = decoder_config.condition_on_text_encodings
|
||||||
|
using_clip = exists(decoder_config.clip)
|
||||||
|
img_emb_url = data_config.img_embeddings_url
|
||||||
|
text_emb_url = data_config.text_embeddings_url
|
||||||
|
if using_text_embeddings:
|
||||||
|
# Then we need some way to get the embeddings
|
||||||
|
assert using_clip or text_emb_url is not None, 'If condition_on_text_encodings is true, either clip or text_embeddings_url must be provided'
|
||||||
|
if using_clip:
|
||||||
|
if using_text_embeddings:
|
||||||
|
assert text_emb_url is None or img_emb_url is None, 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||||
|
else:
|
||||||
|
assert img_emb_url is None, 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||||
|
if text_emb_url:
|
||||||
|
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
|
||||||
|
return values
|
||||||
|
|||||||
@@ -451,6 +451,8 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||||
|
|
||||||
|
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||||
|
|
||||||
optimizers = []
|
optimizers = []
|
||||||
|
|
||||||
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
|
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
|
||||||
@@ -576,6 +578,18 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@cast_torch_tensor
|
||||||
|
@prior_sample_in_chunks
|
||||||
|
def embed_text(self, *args, **kwargs):
|
||||||
|
return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@cast_torch_tensor
|
||||||
|
@prior_sample_in_chunks
|
||||||
|
def embed_image(self, *args, **kwargs):
|
||||||
|
return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
|
||||||
|
|
||||||
@cast_torch_tensor
|
@cast_torch_tensor
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.11.4'
|
__version__ = '0.12.0'
|
||||||
|
|||||||
174
train_decoder.py
174
train_decoder.py
@@ -6,6 +6,7 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
|
|||||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||||
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||||
|
from clip import tokenize
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch
|
import torch
|
||||||
@@ -33,7 +34,8 @@ def exists(val):
|
|||||||
def create_dataloaders(
|
def create_dataloaders(
|
||||||
available_shards,
|
available_shards,
|
||||||
webdataset_base_url,
|
webdataset_base_url,
|
||||||
embeddings_url,
|
img_embeddings_url=None,
|
||||||
|
text_embeddings_url=None,
|
||||||
shard_width=6,
|
shard_width=6,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
@@ -63,14 +65,15 @@ def create_dataloaders(
|
|||||||
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
|
test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
|
||||||
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
|
val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
|
||||||
|
|
||||||
create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
|
create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
|
||||||
tar_url=tar_urls,
|
tar_url=tar_urls,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
batch_size=batch_size if not for_sampling else n_sample_images,
|
batch_size=batch_size if not for_sampling else n_sample_images,
|
||||||
embeddings_url=embeddings_url,
|
img_embeddings_url=img_embeddings_url,
|
||||||
|
text_embeddings_url=text_embeddings_url,
|
||||||
index_width=index_width,
|
index_width=index_width,
|
||||||
shuffle_num = None,
|
shuffle_num = None,
|
||||||
extra_keys= ["txt"] if with_text else [],
|
extra_keys= ["txt"],
|
||||||
shuffle_shards = shuffle,
|
shuffle_shards = shuffle,
|
||||||
resample_shards = resample,
|
resample_shards = resample,
|
||||||
img_preproc=img_preproc,
|
img_preproc=img_preproc,
|
||||||
@@ -79,8 +82,8 @@ def create_dataloaders(
|
|||||||
|
|
||||||
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
|
train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
|
||||||
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
|
train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
|
||||||
val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
|
val_dataloader = create_dataloader(val_urls, shuffle=False)
|
||||||
test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
|
test_dataloader = create_dataloader(test_urls, shuffle=False)
|
||||||
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
|
test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
|
||||||
return {
|
return {
|
||||||
"train": train_dataloader,
|
"train": train_dataloader,
|
||||||
@@ -104,42 +107,65 @@ def get_example_data(dataloader, device, n=5):
|
|||||||
Samples the dataloader and returns a zipped list of examples
|
Samples the dataloader and returns a zipped list of examples
|
||||||
"""
|
"""
|
||||||
images = []
|
images = []
|
||||||
embeddings = []
|
img_embeddings = []
|
||||||
|
text_embeddings = []
|
||||||
captions = []
|
captions = []
|
||||||
dataset_keys = get_dataset_keys(dataloader)
|
for img, emb, txt in dataloader:
|
||||||
has_caption = "txt" in dataset_keys
|
img_emb, text_emb = emb.get('img'), emb.get('text')
|
||||||
for data in dataloader:
|
if img_emb is not None:
|
||||||
if has_caption:
|
img_emb = img_emb.to(device=device, dtype=torch.float)
|
||||||
img, emb, txt = data
|
img_embeddings.extend(list(img_emb))
|
||||||
else:
|
else:
|
||||||
img, emb = data
|
# Then we add None img.shape[0] times
|
||||||
txt = [""] * emb.shape[0]
|
img_embeddings.extend([None]*img.shape[0])
|
||||||
|
if text_emb is not None:
|
||||||
|
text_emb = text_emb.to(device=device, dtype=torch.float)
|
||||||
|
text_embeddings.extend(list(text_emb))
|
||||||
|
else:
|
||||||
|
# Then we add None img.shape[0] times
|
||||||
|
text_embeddings.extend([None]*img.shape[0])
|
||||||
img = img.to(device=device, dtype=torch.float)
|
img = img.to(device=device, dtype=torch.float)
|
||||||
emb = emb.to(device=device, dtype=torch.float)
|
|
||||||
images.extend(list(img))
|
images.extend(list(img))
|
||||||
embeddings.extend(list(emb))
|
|
||||||
captions.extend(list(txt))
|
captions.extend(list(txt))
|
||||||
if len(images) >= n:
|
if len(images) >= n:
|
||||||
break
|
break
|
||||||
return list(zip(images[:n], embeddings[:n], captions[:n]))
|
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
|
||||||
|
|
||||||
def generate_samples(trainer, example_data, text_prepend=""):
|
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
|
||||||
"""
|
"""
|
||||||
Takes example data and generates images from the embeddings
|
Takes example data and generates images from the embeddings
|
||||||
Returns three lists: real images, generated images, and captions
|
Returns three lists: real images, generated images, and captions
|
||||||
"""
|
"""
|
||||||
real_images, embeddings, txts = zip(*example_data)
|
real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
|
||||||
embeddings_tensor = torch.stack(embeddings)
|
sample_params = {}
|
||||||
samples = trainer.sample(embeddings_tensor)
|
if img_embeddings[0] is None:
|
||||||
|
# Generate image embeddings from clip
|
||||||
|
imgs_tensor = torch.stack(real_images)
|
||||||
|
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
|
||||||
|
sample_params["image_embed"] = img_embeddings
|
||||||
|
else:
|
||||||
|
# Then we are using precomputed image embeddings
|
||||||
|
img_embeddings = torch.stack(img_embeddings)
|
||||||
|
sample_params["image_embed"] = img_embeddings
|
||||||
|
if condition_on_text_encodings:
|
||||||
|
if text_embeddings[0] is None:
|
||||||
|
# Generate text embeddings from text
|
||||||
|
tokenized_texts = tokenize(txts, truncate=True)
|
||||||
|
sample_params["text"] = tokenized_texts
|
||||||
|
else:
|
||||||
|
# Then we are using precomputed text embeddings
|
||||||
|
text_embeddings = torch.stack(text_embeddings)
|
||||||
|
sample_params["text_encodings"] = text_embeddings
|
||||||
|
samples = trainer.sample(**sample_params)
|
||||||
generated_images = list(samples)
|
generated_images = list(samples)
|
||||||
captions = [text_prepend + txt for txt in txts]
|
captions = [text_prepend + txt for txt in txts]
|
||||||
return real_images, generated_images, captions
|
return real_images, generated_images, captions
|
||||||
|
|
||||||
def generate_grid_samples(trainer, examples, text_prepend=""):
|
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
|
||||||
"""
|
"""
|
||||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||||
"""
|
"""
|
||||||
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
|
||||||
|
|
||||||
real_image_size = real_images[0].shape[-1]
|
real_image_size = real_images[0].shape[-1]
|
||||||
generated_image_size = generated_images[0].shape[-1]
|
generated_image_size = generated_images[0].shape[-1]
|
||||||
@@ -151,7 +177,7 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
|
|||||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||||
return grid_images, captions
|
return grid_images, captions
|
||||||
|
|
||||||
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||||
"""
|
"""
|
||||||
Computes evaluation metrics for the decoder
|
Computes evaluation metrics for the decoder
|
||||||
"""
|
"""
|
||||||
@@ -161,7 +187,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID
|
|||||||
if len(examples) == 0:
|
if len(examples) == 0:
|
||||||
print("No data to evaluate. Check that your dataloader has shards.")
|
print("No data to evaluate. Check that your dataloader has shards.")
|
||||||
return metrics
|
return metrics
|
||||||
real_images, generated_images, captions = generate_samples(trainer, examples)
|
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
|
||||||
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||||
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||||
@@ -250,6 +276,7 @@ def train(
|
|||||||
save_latest=True,
|
save_latest=True,
|
||||||
save_best=True,
|
save_best=True,
|
||||||
unet_training_mask=None,
|
unet_training_mask=None,
|
||||||
|
condition_on_text_encodings=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -258,8 +285,8 @@ def train(
|
|||||||
is_master = accelerator.process_index == 0
|
is_master = accelerator.process_index == 0
|
||||||
|
|
||||||
trainer = DecoderTrainer(
|
trainer = DecoderTrainer(
|
||||||
accelerator,
|
decoder=decoder,
|
||||||
decoder,
|
accelerator=accelerator,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -268,6 +295,7 @@ def train(
|
|||||||
validation_losses = []
|
validation_losses = []
|
||||||
next_task = 'train'
|
next_task = 'train'
|
||||||
sample = 0
|
sample = 0
|
||||||
|
samples_seen = 0
|
||||||
val_sample = 0
|
val_sample = 0
|
||||||
step = lambda: int(trainer.step.item())
|
step = lambda: int(trainer.step.item())
|
||||||
|
|
||||||
@@ -306,13 +334,22 @@ def train(
|
|||||||
last_snapshot = sample
|
last_snapshot = sample
|
||||||
|
|
||||||
if next_task == 'train':
|
if next_task == 'train':
|
||||||
for i, (img, emb) in enumerate(dataloaders["train"]):
|
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
|
||||||
# We want to count the total number of samples across all processes
|
# We want to count the total number of samples across all processes
|
||||||
sample_length_tensor[0] = len(img)
|
sample_length_tensor[0] = len(img)
|
||||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||||
total_samples = all_samples.sum().item()
|
total_samples = all_samples.sum().item()
|
||||||
sample += total_samples
|
sample += total_samples
|
||||||
img, emb = send_to_device((img, emb))
|
samples_seen += total_samples
|
||||||
|
img_emb = emb.get('img')
|
||||||
|
has_img_embedding = img_emb is not None
|
||||||
|
if has_img_embedding:
|
||||||
|
img_emb, = send_to_device((img_emb,))
|
||||||
|
text_emb = emb.get('text')
|
||||||
|
has_text_embedding = text_emb is not None
|
||||||
|
if has_text_embedding:
|
||||||
|
text_emb, = send_to_device((text_emb,))
|
||||||
|
img, = send_to_device((img,))
|
||||||
|
|
||||||
trainer.train()
|
trainer.train()
|
||||||
for unet in range(1, trainer.num_unets+1):
|
for unet in range(1, trainer.num_unets+1):
|
||||||
@@ -320,7 +357,20 @@ def train(
|
|||||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
forward_params = {}
|
||||||
|
if has_img_embedding:
|
||||||
|
forward_params['image_embed'] = img_emb
|
||||||
|
else:
|
||||||
|
# Forward pass automatically generates embedding
|
||||||
|
pass
|
||||||
|
if condition_on_text_encodings:
|
||||||
|
if has_text_embedding:
|
||||||
|
forward_params['text_encodings'] = text_emb
|
||||||
|
else:
|
||||||
|
# Then we need to pass the text instead
|
||||||
|
tokenized_texts = tokenize(txt, truncate=True)
|
||||||
|
forward_params['text'] = tokenized_texts
|
||||||
|
loss = trainer.forward(img, **forward_params, unet_number=unet)
|
||||||
trainer.update(unet_number=unet)
|
trainer.update(unet_number=unet)
|
||||||
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
|
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
|
||||||
|
|
||||||
@@ -334,14 +384,20 @@ def train(
|
|||||||
mask = unet_all_losses != 0
|
mask = unet_all_losses != 0
|
||||||
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
|
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
|
||||||
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
|
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
|
||||||
|
|
||||||
|
# gather decay rate on each UNet
|
||||||
|
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
|
||||||
|
|
||||||
log_data = {
|
log_data = {
|
||||||
"Epoch": epoch,
|
"Epoch": epoch,
|
||||||
"Sample": sample,
|
"Sample": sample,
|
||||||
"Step": i,
|
"Step": i,
|
||||||
"Samples per second": samples_per_sec,
|
"Samples per second": samples_per_sec,
|
||||||
|
"Samples Seen": samples_seen,
|
||||||
|
**ema_decay_list,
|
||||||
**loss_map
|
**loss_map
|
||||||
}
|
}
|
||||||
# print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")
|
|
||||||
if is_master:
|
if is_master:
|
||||||
tracker.log(log_data, step=step(), verbose=True)
|
tracker.log(log_data, step=step(), verbose=True)
|
||||||
|
|
||||||
@@ -358,7 +414,7 @@ def train(
|
|||||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||||
if exists(n_sample_images) and n_sample_images > 0:
|
if exists(n_sample_images) and n_sample_images > 0:
|
||||||
trainer.eval()
|
trainer.eval()
|
||||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||||
|
|
||||||
if epoch_samples is not None and sample >= epoch_samples:
|
if epoch_samples is not None and sample >= epoch_samples:
|
||||||
@@ -381,14 +437,35 @@ def train(
|
|||||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||||
total_samples = all_samples.sum().item()
|
total_samples = all_samples.sum().item()
|
||||||
val_sample += total_samples
|
val_sample += total_samples
|
||||||
img, emb = send_to_device((img, emb))
|
img_emb = emb.get('img')
|
||||||
|
has_img_embedding = img_emb is not None
|
||||||
|
if has_img_embedding:
|
||||||
|
img_emb, = send_to_device((img_emb,))
|
||||||
|
text_emb = emb.get('text')
|
||||||
|
has_text_embedding = text_emb is not None
|
||||||
|
if has_text_embedding:
|
||||||
|
text_emb, = send_to_device((text_emb,))
|
||||||
|
img, = send_to_device((img,))
|
||||||
|
|
||||||
for unet in range(1, len(decoder.unets)+1):
|
for unet in range(1, len(decoder.unets)+1):
|
||||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||||
# No need to evaluate an unchanging unet
|
# No need to evaluate an unchanging unet
|
||||||
continue
|
continue
|
||||||
|
|
||||||
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
|
forward_params = {}
|
||||||
|
if has_img_embedding:
|
||||||
|
forward_params['image_embed'] = img_emb.float()
|
||||||
|
else:
|
||||||
|
# Forward pass automatically generates embedding
|
||||||
|
pass
|
||||||
|
if condition_on_text_encodings:
|
||||||
|
if has_text_embedding:
|
||||||
|
forward_params['text_encodings'] = text_emb.float()
|
||||||
|
else:
|
||||||
|
# Then we need to pass the text instead
|
||||||
|
tokenized_texts = tokenize(txt, truncate=True)
|
||||||
|
forward_params['text'] = tokenized_texts
|
||||||
|
loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
|
||||||
average_val_loss_tensor[0, unet-1] += loss
|
average_val_loss_tensor[0, unet-1] += loss
|
||||||
|
|
||||||
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
||||||
@@ -415,7 +492,7 @@ def train(
|
|||||||
if next_task == 'eval':
|
if next_task == 'eval':
|
||||||
if exists(evaluate_config):
|
if exists(evaluate_config):
|
||||||
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
|
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
|
||||||
if is_master:
|
if is_master:
|
||||||
tracker.log(evaluation, step=step(), verbose=True)
|
tracker.log(evaluation, step=step(), verbose=True)
|
||||||
next_task = 'sample'
|
next_task = 'sample'
|
||||||
@@ -426,8 +503,8 @@ def train(
|
|||||||
# Generate examples and save the model if we are the master
|
# Generate examples and save the model if we are the master
|
||||||
# Generate sample images
|
# Generate sample images
|
||||||
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
|
test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
|
||||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||||
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
|
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
|
||||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||||
|
|
||||||
@@ -517,14 +594,35 @@ def initialize_training(config, config_path):
|
|||||||
# Create and initialize the tracker if we are the master
|
# Create and initialize the tracker if we are the master
|
||||||
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
|
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
|
||||||
|
|
||||||
|
has_img_embeddings = config.data.img_embeddings_url is not None
|
||||||
|
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||||
|
conditioning_on_text = config.decoder.condition_on_text_encodings
|
||||||
|
has_clip_model = config.decoder.clip is not None
|
||||||
|
data_source_string = ""
|
||||||
|
if has_img_embeddings:
|
||||||
|
data_source_string += "precomputed image embeddings"
|
||||||
|
elif has_clip_model:
|
||||||
|
data_source_string += "clip image embeddings generation"
|
||||||
|
else:
|
||||||
|
raise ValueError("No image embeddings source specified")
|
||||||
|
if conditioning_on_text:
|
||||||
|
if has_text_embeddings:
|
||||||
|
data_source_string += " and precomputed text embeddings"
|
||||||
|
elif has_clip_model:
|
||||||
|
data_source_string += " and clip text encoding generation"
|
||||||
|
else:
|
||||||
|
raise ValueError("No text embeddings source specified")
|
||||||
|
|
||||||
accelerator.print(print_ribbon("Loaded Config", repeat=40))
|
accelerator.print(print_ribbon("Loaded Config", repeat=40))
|
||||||
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
|
accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
|
||||||
|
accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
|
||||||
accelerator.print(f"Number of parameters: {num_parameters}")
|
accelerator.print(f"Number of parameters: {num_parameters}")
|
||||||
train(dataloaders, decoder, accelerator,
|
train(dataloaders, decoder, accelerator,
|
||||||
tracker=tracker,
|
tracker=tracker,
|
||||||
inference_device=accelerator.device,
|
inference_device=accelerator.device,
|
||||||
load_config=config.load,
|
load_config=config.load,
|
||||||
evaluate_config=config.evaluate,
|
evaluate_config=config.evaluate,
|
||||||
|
condition_on_text_encodings=config.decoder.condition_on_text_encodings,
|
||||||
**config.train.dict(),
|
**config.train.dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user