Add data flexibility to decoder trainer (#165)

* Added the ability to train decoder with text embeddings

* Added the ability to train using on the fly generated embeddings with clip

* Clip now generates embeddings for whatever is not precomputed
This commit is contained in:
Aidan Dempster
2022-06-25 22:05:20 -04:00
committed by GitHub
parent c453f468b1
commit f5760bdb92
4 changed files with 228 additions and 59 deletions

View File

@@ -13,7 +13,7 @@ from dalle2_pytorch.dalle2_pytorch import (
Decoder,
DiffusionPrior,
DiffusionPriorNetwork,
XClipAdapter,
XClipAdapter
)
# helper functions
@@ -170,6 +170,8 @@ class DecoderConfig(BaseModel):
unets: ListOrTuple(UnetConfig)
image_size: int = None
image_sizes: ListOrTuple(int) = None
condition_on_text_encodings: bool = False
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
@@ -180,9 +182,16 @@ class DecoderConfig(BaseModel):
def create(self):
decoder_kwargs = self.dict()
unet_configs = decoder_kwargs.pop('unets')
unets = [Unet(**config) for config in unet_configs]
return Decoder(unets, **decoder_kwargs)
has_clip = exists(decoder_kwargs.pop('clip'))
clip = None
if has_clip:
clip = self.clip.create()
return Decoder(unets, clip=clip, **decoder_kwargs)
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
@@ -194,8 +203,9 @@ class DecoderConfig(BaseModel):
extra = "allow"
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings
webdataset_base_url: str # path to a webdataset with jpg images
img_embeddings_url: Optional[str] # path to .npy files with embeddings
text_embeddings_url: Optional[str] # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
@@ -268,3 +278,26 @@ class TrainDecoderConfig(BaseModel):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
@root_validator
def check_has_embeddings(cls, values):
# Makes sure that enough information is provided to get the embeddings specified for training
data_config, decoder_config = values.get('data'), values.get('decoder')
if data_config is None or decoder_config is None:
# Then something else errored and we should just pass through
return values
using_text_embeddings = decoder_config.condition_on_text_encodings
using_clip = exists(decoder_config.clip)
img_emb_url = data_config.img_embeddings_url
text_emb_url = data_config.text_embeddings_url
if using_text_embeddings:
# Then we need some way to get the embeddings
assert using_clip or text_emb_url is not None, 'If condition_on_text_encodings is true, either clip or text_embeddings_url must be provided'
if using_clip:
if using_text_embeddings:
assert text_emb_url is None or img_emb_url is None, 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
else:
assert img_emb_url is None, 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
if text_emb_url:
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
return values