mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
add setting to attend to all text encodings regardless of padding, for diffusion prior
This commit is contained in:
@@ -806,6 +806,7 @@ class DiffusionPriorNetwork(nn.Module):
        num_time_embeds = 1,
        num_image_embeds = 1,
        num_text_embeds = 1,
        attend_all_text_encodings = True,
        **kwargs
    ):
        super().__init__()
@@ -831,6 +832,8 @@ class DiffusionPriorNetwork(nn.Module):
        self.learned_query = nn.Parameter(torch.randn(dim))
        self.causal_transformer = CausalTransformer(dim = dim, **kwargs)

        self.attend_all_text_encodings = attend_all_text_encodings

    def forward_with_cond_scale(
        self,
        *args,
@@ -852,7 +855,6 @@ class DiffusionPriorNetwork(nn.Module):
        *,
        text_embed,
        text_encodings = None,
        mask = None,
        cond_drop_prob = 0.
    ):
        batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -871,6 +873,9 @@ class DiffusionPriorNetwork(nn.Module):
        if not exists(text_encodings):
            text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)

        if self.attend_all_text_encodings:
            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
        else:
            mask = torch.any(text_encodings != 0., dim = -1)

        # classifier free guidance
@@ -133,6 +133,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
    num_time_embeds: int = 1
    num_image_embeds: int = 1
    num_text_embeds: int = 1
    attend_all_text_encodings: bool = True
    dim_head: int = 64
    heads: int = 8
    ff_mult: int = 4
@@ -1 +1 @@
__version__ = '0.22.1'
__version__ = '0.22.2'
Reference in New Issue
Block a user