diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 4541285..e5d4cec 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -173,12 +173,12 @@ class OpenAIClipAdapter(BaseClipAdapter): name = 'ViT-B/32' ): import clip - openai_clip, _ = clip.load(name) + openai_clip, preprocess = clip.load(name) super().__init__(openai_clip) text_attention_final = self.find_layer('ln_final') self.handle = text_attention_final.register_forward_hook(self._hook) - self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + self.clip_normalize = preprocess.transforms[-1] self.cleared = False def find_layer(self, layer):