diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 5b76322..7a516de 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -872,7 +872,7 @@ class DiffusionPriorNetwork(nn.Module): text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype) if not exists(mask): - mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool) + mask = torch.any(text_encodings != 0., dim = -1) # classifier free guidance @@ -1205,7 +1205,6 @@ class DiffusionPrior(nn.Module): if self.condition_on_text_encodings: assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified' - text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1)) text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask} # timestep conditioning from ddpm @@ -1819,6 +1818,9 @@ class Unet(nn.Module): if exists(text_encodings) and self.cond_on_text_encodings: assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.' + if not exists(text_mask): + text_mask = torch.any(text_encodings != 0., dim = -1) + text_tokens = self.text_to_cond(text_encodings) text_tokens = text_tokens[:, :self.max_text_len] @@ -1827,13 +1829,10 @@ class Unet(nn.Module): if remainder > 0: text_tokens = F.pad(text_tokens, (0, 0, 0, remainder)) + text_mask = F.pad(text_mask, (0, remainder), value = False) - if exists(text_mask): - if remainder > 0: - text_mask = F.pad(text_mask, (0, remainder), value = False) - - text_mask = rearrange(text_mask, 'b n -> b n 1') - text_keep_mask = text_mask & text_keep_mask + text_mask = rearrange(text_mask, 'b n -> b n 1') + text_keep_mask = text_mask & text_keep_mask null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working @@ -2440,9 +2439,6 @@ class Decoder(nn.Module): assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' - if self.condition_on_text_encodings: - text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1)) - img = None is_cuda = next(self.parameters()).is_cuda @@ -2526,9 +2522,6 @@ class Decoder(nn.Module): assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' - if self.condition_on_text_encodings: - text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1)) - lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None image = resize_image_to(image, target_image_size) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index e453371..8c306aa 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.21.0' +__version__ = '0.21.1'