generate text mask within the unet and diffusion prior itself from the text encodings, if not given

This commit is contained in:
Phil Wang
2022-07-12 12:54:54 -07:00
parent e055793e5d
commit e0835acca9
2 changed files with 8 additions and 15 deletions

View File

@@ -872,7 +872,7 @@ class DiffusionPriorNetwork(nn.Module):
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
if not exists(mask):
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
mask = torch.any(text_encodings != 0., dim = -1)
# classifier free guidance
@@ -1205,7 +1205,6 @@ class DiffusionPrior(nn.Module):
if self.condition_on_text_encodings:
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
# timestep conditioning from ddpm
@@ -1819,6 +1818,9 @@ class Unet(nn.Module):
if exists(text_encodings) and self.cond_on_text_encodings:
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
if not exists(text_mask):
text_mask = torch.any(text_encodings != 0., dim = -1)
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]
@@ -1827,13 +1829,10 @@ class Unet(nn.Module):
if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
text_mask = F.pad(text_mask, (0, remainder), value = False)
if exists(text_mask):
if remainder > 0:
text_mask = F.pad(text_mask, (0, remainder), value = False)
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
@@ -2440,9 +2439,6 @@ class Decoder(nn.Module):
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
if self.condition_on_text_encodings:
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
img = None
is_cuda = next(self.parameters()).is_cuda
@@ -2526,9 +2522,6 @@ class Decoder(nn.Module):
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
if self.condition_on_text_encodings:
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)