mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 02:04:19 +01:00
prepare unet to be conditioned on image embedding, optionally text encodings, and reminder for self to build conditional dropout for classifier free guidance
This commit is contained in:
@@ -284,6 +284,8 @@ class Unet(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
|
*,
|
||||||
|
image_embed_dim,
|
||||||
out_dim = None,
|
out_dim = None,
|
||||||
dim_mults=(1, 2, 4, 8),
|
dim_mults=(1, 2, 4, 8),
|
||||||
channels = 3,
|
channels = 3,
|
||||||
@@ -338,7 +340,15 @@ class Unet(nn.Module):
|
|||||||
nn.Conv2d(dim, out_dim, 1)
|
nn.Conv2d(dim, out_dim, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, time):
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
*,
|
||||||
|
image_embed,
|
||||||
|
time,
|
||||||
|
text_encodings = None,
|
||||||
|
cond_prob_drop = 0.2
|
||||||
|
):
|
||||||
t = self.time_mlp(time) if exists(self.time_mlp) else None
|
t = self.time_mlp(time) if exists(self.time_mlp) else None
|
||||||
|
|
||||||
hiddens = []
|
hiddens = []
|
||||||
|
|||||||
Reference in New Issue
Block a user