Mirror of https://github.com/lucidrains/DALLE2-pytorch.git
just force it so researcher can never pass in an image that is less than the size that is required for CLIP or CoCa
@@ -169,6 +169,11 @@ class BaseClipAdapter(nn.Module):
         self.clip = clip
         self.overrides = kwargs
 
+    def validate_and_resize_image(self, image):
+        image_size = image.shape[-1]
+        assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'
+        return resize_image_to(image, self.image_size)
+
     @property
     def dim_latent(self):
         raise NotImplementedError
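
The new helper checks that the incoming image is at least as large as the CLIP / CoCa input resolution before resizing, so an undersized image fails loudly instead of being silently upsampled. Below is a minimal standalone sketch of that behaviour, not the library code itself: F.interpolate stands in for resize_image_to, and the hard-coded image_size argument plays the role of the adapter's self.image_size attribute.

import torch
import torch.nn.functional as F

def validate_and_resize_image(image, image_size = 256):
    current_size = image.shape[-1]
    # refuse undersized inputs rather than quietly blowing them up
    assert current_size >= image_size, f'you are passing in an image of size {current_size} but CLIP requires the image size to be at least {image_size}'
    # downsample (or keep) to the expected square resolution
    return F.interpolate(image, size = (image_size, image_size), mode = 'bilinear', align_corners = False)

validate_and_resize_image(torch.randn(1, 3, 512, 512))   # ok, resized to 256 x 256
# validate_and_resize_image(torch.randn(1, 3, 64, 64))   # raises AssertionError
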
@@ -219,7 +224,7 @@ class XClipAdapter(BaseClipAdapter):
 
     @torch.no_grad()
     def embed_image(self, image):
-        image = resize_image_to(image, self.image_size)
+        image = self.validate_and_resize_image(image)
         encoder_output = self.clip.visual_transformer(image)
         image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         image_embed = self.clip.to_visual_latent(image_cls)
@@ -254,7 +259,7 @@ class CoCaAdapter(BaseClipAdapter):
 
     @torch.no_grad()
     def embed_image(self, image):
-        image = resize_image_to(image, self.image_size)
+        image = self.validate_and_resize_image(image)
         image_embed, image_encodings = self.clip.embed_image(image)
         return EmbeddedImage(image_embed, image_encodings)
 
@@ -315,7 +320,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
     @torch.no_grad()
     def embed_image(self, image):
         assert not self.cleared
-        image = resize_image_to(image, self.image_size)
+        image = self.validate_and_resize_image(image)
         image = self.clip_normalize(image)
         image_embed = self.clip.encode_image(image)
         return EmbeddedImage(l2norm(image_embed.float()), None)
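
All three adapters (XClipAdapter, CoCaAdapter, OpenAIClipAdapter) now funnel their embed_image input through the single guard defined on BaseClipAdapter, so the size check lives in exactly one place. A hypothetical miniature of that pattern, using stand-in names and a fake encoder purely for illustration:

import torch
import torch.nn.functional as F
from torch import nn

class TinyBaseAdapter(nn.Module):
    # stand-in for BaseClipAdapter: owns the one shared size guard
    def __init__(self, image_size = 224):
        super().__init__()
        self.image_size = image_size

    def validate_and_resize_image(self, image):
        assert image.shape[-1] >= self.image_size, 'image is smaller than the CLIP input resolution'
        return F.interpolate(image, size = (self.image_size, self.image_size), mode = 'bilinear', align_corners = False)

class TinyClipAdapter(TinyBaseAdapter):
    # stand-in for any of the concrete adapters
    @torch.no_grad()
    def embed_image(self, image):
        image = self.validate_and_resize_image(image)            # shared guard, inherited from the base
        return image.flatten(1).mean(dim = -1, keepdim = True)   # fake encoder in place of clip.encode_image

adapter = TinyClipAdapter()
print(adapter.embed_image(torch.randn(2, 3, 256, 256)).shape)    # torch.Size([2, 1])
# adapter.embed_image(torch.randn(2, 3, 64, 64))                 # raises AssertionError
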
@@ -1 +1 @@
-__version__ = '0.17.0'
+__version__ = '0.17.1'