remove text masking altogether in favor of deriving it from the text encodings (padded text encodings must use a pad value of 0.)

Phil Wang
2022-07-12 15:40:31 -07:00
parent bb3ff0ac67
commit e76e89f9eb
4 changed files with 28 additions and 41 deletions
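
The rationale behind the change: since padded positions in the text encodings are all-zero vectors, any consumer can recover the token mask from the encodings themselves rather than having it threaded through every call site. A minimal sketch of that derivation, assuming zero-padded encodings of shape (batch, seq_len, dim); the shapes and variable names below are illustrative, not taken from the diff:

import torch

# illustrative shapes: batch of 2, sequence length 6, embedding dim 4
text_encodings = torch.randn(2, 6, 4)
text_encodings[:, 4:] = 0.  # simulate zero-padded positions

# a position holds a real token iff any feature there is nonzero,
# so the boolean mask falls out of the encodings directly
text_mask = torch.any(text_encodings != 0., dim=-1)  # shape (2, 6)
# -> [[True, True, True, True, False, False],
#     [True, True, True, True, False, False]]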


@@ -126,9 +126,9 @@ def report_cosine_sims(
     # we are text conditioned, we produce an embedding from the tokenized text
     if text_conditioned:
-        text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
+        text_embedding, text_encodings = trainer.embed_text(text_data)
         text_cond = dict(
-            text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
+            text_embed=text_embedding, text_encodings=text_encodings
         )
     else:
         text_embedding = text_data
@@ -146,15 +146,12 @@ def report_cosine_sims(
     if text_conditioned:
         text_encodings_shuffled = text_encodings[rolled_idx]
-        text_mask_shuffled = text_mask[rolled_idx]
     else:
         text_encodings_shuffled = None
-        text_mask_shuffled = None
 
     text_cond_shuffled = dict(
         text_embed=text_embed_shuffled,
-        text_encodings=text_encodings_shuffled,
-        mask=text_mask_shuffled,
+        text_encodings=text_encodings_shuffled
     )
 
     # prepare the text embedding