mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
prepare unet to be conditioned on image embedding, optionally text encodings, and reminder for self to build conditional dropout for classifier free guidance
This commit is contained in:
@@ -284,6 +284,8 @@ class Unet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
image_embed_dim,
|
||||
out_dim = None,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
@@ -338,7 +340,15 @@ class Unet(nn.Module):
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
)
|
||||
|
||||
def forward(self, x, time):
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
image_embed,
|
||||
time,
|
||||
text_encodings = None,
|
||||
cond_prob_drop = 0.2
|
||||
):
|
||||
t = self.time_mlp(time) if exists(self.time_mlp) else None
|
||||
|
||||
hiddens = []
|
||||
|
||||
Reference in New Issue
Block a user