mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
make it so even if text mask is omitted, it will be derived based on whether text encodings are all 0s or not, simplify dataloading
This commit is contained in:
@@ -220,6 +220,7 @@ class XClipAdapter(BaseClipAdapter):
|
|||||||
encoder_output = self.clip.text_transformer(text)
|
encoder_output = self.clip.text_transformer(text)
|
||||||
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||||
text_embed = self.clip.to_text_latent(text_cls)
|
text_embed = self.clip.to_text_latent(text_cls)
|
||||||
|
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
|
||||||
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
|
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -255,6 +256,7 @@ class CoCaAdapter(BaseClipAdapter):
|
|||||||
text = text[..., :self.max_text_len]
|
text = text[..., :self.max_text_len]
|
||||||
text_mask = text != 0
|
text_mask = text != 0
|
||||||
text_embed, text_encodings = self.clip.embed_text(text)
|
text_embed, text_encodings = self.clip.embed_text(text)
|
||||||
|
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
|
||||||
return EmbeddedText(text_embed, text_encodings, text_mask)
|
return EmbeddedText(text_embed, text_encodings, text_mask)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -314,6 +316,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
|
|
||||||
text_embed = self.clip.encode_text(text)
|
text_embed = self.clip.encode_text(text)
|
||||||
text_encodings = self.text_encodings
|
text_encodings = self.text_encodings
|
||||||
|
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
|
||||||
del self.text_encodings
|
del self.text_encodings
|
||||||
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
|
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
|
||||||
|
|
||||||
@@ -1197,6 +1200,7 @@ class DiffusionPrior(nn.Module):
|
|||||||
|
|
||||||
if self.condition_on_text_encodings:
|
if self.condition_on_text_encodings:
|
||||||
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
|
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
|
||||||
|
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
|
||||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||||
|
|
||||||
# timestep conditioning from ddpm
|
# timestep conditioning from ddpm
|
||||||
@@ -2410,6 +2414,9 @@ class Decoder(nn.Module):
|
|||||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||||
|
|
||||||
|
if self.condition_on_text_encodings:
|
||||||
|
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
|
||||||
|
|
||||||
img = None
|
img = None
|
||||||
is_cuda = next(self.parameters()).is_cuda
|
is_cuda = next(self.parameters()).is_cuda
|
||||||
|
|
||||||
@@ -2493,6 +2500,9 @@ class Decoder(nn.Module):
|
|||||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||||
|
|
||||||
|
if self.condition_on_text_encodings:
|
||||||
|
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
|
||||||
|
|
||||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||||
image = resize_image_to(image, target_image_size)
|
image = resize_image_to(image, target_image_size)
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.19.6'
|
__version__ = '0.20.0'
|
||||||
|
|||||||
Reference in New Issue
Block a user