mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
bug fixes for text conditioning update (#175)
This commit is contained in:
@@ -596,9 +596,10 @@ def initialize_training(config, config_path):
|
||||
|
||||
has_img_embeddings = config.data.img_embeddings_url is not None
|
||||
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||
conditioning_on_text = config.decoder.condition_on_text_encodings
|
||||
conditioning_on_text = config.decoder.unets[0].cond_on_text_encodings
|
||||
has_clip_model = config.decoder.clip is not None
|
||||
data_source_string = ""
|
||||
|
||||
if has_img_embeddings:
|
||||
data_source_string += "precomputed image embeddings"
|
||||
elif has_clip_model:
|
||||
@@ -622,7 +623,7 @@ def initialize_training(config, config_path):
|
||||
inference_device=accelerator.device,
|
||||
load_config=config.load,
|
||||
evaluate_config=config.evaluate,
|
||||
condition_on_text_encodings=config.decoder.condition_on_text_encodings,
|
||||
condition_on_text_encodings=conditioning_on_text,
|
||||
**config.train.dict(),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user