use cross attention for conditioning the unet on image embedding tokens (which also opens the door to conditioning on text encodings)

This commit is contained in:
Phil Wang
2022-04-14 10:10:04 -07:00
parent 95b018374a
commit 68e9883f59
3 changed files with 124 additions and 44 deletions

View File

@@ -101,7 +101,7 @@ clip = CLIP(
unet = Unet(
dim = 128,
image_embed_dim = 512,
- time_dim = 128,
+ cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
@@ -264,7 +264,7 @@ loss.backward()
unet = Unet(
dim = 128,
image_embed_dim = 512,
- time_dim = 128,
+ cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()