Prepare the Unet to be conditioned on an image embedding and, optionally, text encodings; also add a reminder to self to build conditional dropout for classifier-free guidance

This commit is contained in:
Phil Wang
2022-04-12 12:38:56 -07:00
parent 46dde54948
commit 59b8abe09e

View File

@@ -284,6 +284,8 @@ class Unet(nn.Module):
def __init__(
self,
dim,
*,
image_embed_dim,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
@@ -338,7 +340,15 @@ class Unet(nn.Module):
nn.Conv2d(dim, out_dim, 1)
)
def forward(self, x, time):
def forward(
self,
x,
*,
image_embed,
time,
text_encodings = None,
cond_prob_drop = 0.2
):
t = self.time_mlp(time) if exists(self.time_mlp) else None
hiddens = []