allow for decoder conditioning with the text encodings from CLIP, if it is passed in. use lazy linear to avoid researchers having to worry about text encoding dimensions, but remove later if it does not work well

Phil Wang
2022-04-14 11:46:45 -07:00
parent 69e822b7f8
commit 9f55c24db6
3 changed files with 50 additions and 19 deletions


@@ -276,7 +276,7 @@ decoder = Decoder(
cond_drop_prob = 0.2
).cuda()
-loss = decoder(images)
+loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
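
The commit message mentions using a lazy linear layer so that the text encoding dimension does not have to be specified up front. As a rough illustration of that idea only (not the code from this commit), the sketch below uses PyTorch's `nn.LazyLinear`, which infers its input width on the first forward pass. The `TextConditioner` class, its `to_cond` attribute, and all tensor sizes are hypothetical names and values chosen for the example.

```python
import torch
from torch import nn

class TextConditioner(nn.Module):
    """Hypothetical helper: projects text encodings of any width to `dim`."""

    def __init__(self, dim):
        super().__init__()
        # LazyLinear defers creating its weight until the first forward pass,
        # when in_features is read from the last dimension of the input.
        self.to_cond = nn.LazyLinear(dim)

    def forward(self, text_encodings=None):
        # Text conditioning is optional: with no encodings there is nothing to project.
        if text_encodings is None:
            return None
        return self.to_cond(text_encodings)

conditioner = TextConditioner(dim=64)

# (batch, tokens, encoding dim); these sizes are arbitrary assumptions
text_encodings = torch.randn(2, 77, 512)

cond = conditioner(text_encodings)
print(cond.shape)                      # projected to the model dimension
print(conditioner.to_cond.in_features) # inferred from the first input
```

The trade-off is that the layer's parameters do not exist until the first call, so anything that inspects parameters beforehand (for example, building an optimizer) needs a dry forward pass first.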