diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index 1f9ccec..a333637 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -289,7 +289,7 @@ class TrainDecoderConfig(BaseModel): # Then something else errored and we should just pass through return values - using_text_encodings = any([unet.cond_on_text_encodings for unet in decoder_config.unets]) + using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets]) using_clip = exists(decoder_config.clip) img_emb_url = data_config.img_embeddings_url text_emb_url = data_config.text_embeddings_url