diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 7576698..48b332c 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -169,6 +169,11 @@ class BaseClipAdapter(nn.Module):
         self.clip = clip
         self.overrides = kwargs
 
+    def validate_and_resize_image(self, image):
+        image_size = image.shape[-1]
+        assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'
+        return resize_image_to(image, self.image_size)
+
     @property
     def dim_latent(self):
         raise NotImplementedError
@@ -219,7 +224,7 @@ class XClipAdapter(BaseClipAdapter):
 
     @torch.no_grad()
     def embed_image(self, image):
-        image = resize_image_to(image, self.image_size)
+        image = self.validate_and_resize_image(image)
         encoder_output = self.clip.visual_transformer(image)
         image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         image_embed = self.clip.to_visual_latent(image_cls)
@@ -254,7 +259,7 @@ class CoCaAdapter(BaseClipAdapter):
 
     @torch.no_grad()
     def embed_image(self, image):
-        image = resize_image_to(image, self.image_size)
+        image = self.validate_and_resize_image(image)
         image_embed, image_encodings = self.clip.embed_image(image)
         return EmbeddedImage(image_embed, image_encodings)
 
@@ -315,7 +320,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
     @torch.no_grad()
     def embed_image(self, image):
         assert not self.cleared
-        image = resize_image_to(image, self.image_size)
+        image = self.validate_and_resize_image(image)
         image = self.clip_normalize(image)
         image_embed = self.clip.encode_image(image)
         return EmbeddedImage(l2norm(image_embed.float()), None)
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 435d64b..cbe8c74 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.17.0'
+__version__ = '0.17.1'
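
A minimal sketch of the behavior this patch introduces: `embed_image` now asserts that an incoming image is at least as large as CLIP's expected resolution before resizing, rather than silently upsampling undersized inputs. The adapter choice and image sizes below are assumptions for illustration (ViT-B/32 expects 224px inputs); this is not part of the diff.

    import torch
    from dalle2_pytorch import OpenAIClipAdapter

    clip = OpenAIClipAdapter()             # loads OpenAI CLIP ViT-B/32 by default

    ok = torch.randn(4, 3, 256, 256)       # 256px >= 224px, gets resized down as before
    clip.embed_image(ok)

    too_small = torch.randn(4, 3, 64, 64)  # 64px < 224px, now fails fast
    try:
        clip.embed_image(too_small)
    except AssertionError as e:
        print(e)  # 'you are passing in an image of size 64 but CLIP requires ...'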