mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
remove text masking altogether in favor of deriving from text encodings (padded text encodings must be pad value of 0.)
This commit is contained in:
@@ -126,9 +126,9 @@ def report_cosine_sims(
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
|
||||
text_embedding, text_encodings = trainer.embed_text(text_data)
|
||||
text_cond = dict(
|
||||
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
|
||||
text_embed=text_embedding, text_encodings=text_encodings
|
||||
)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
@@ -146,15 +146,12 @@ def report_cosine_sims(
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
text_mask_shuffled = text_mask[rolled_idx]
|
||||
else:
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(
|
||||
text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled,
|
||||
mask=text_mask_shuffled,
|
||||
text_encodings=text_encodings_shuffled
|
||||
)
|
||||
|
||||
# prepare the text embedding
|
||||
|
||||
Reference in New Issue
Block a user