mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
cleanup
This commit is contained in:
@@ -173,12 +173,12 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
|||||||
name = 'ViT-B/32'
|
name = 'ViT-B/32'
|
||||||
):
|
):
|
||||||
import clip
|
import clip
|
||||||
openai_clip, _ = clip.load(name)
|
openai_clip, preprocess = clip.load(name)
|
||||||
super().__init__(openai_clip)
|
super().__init__(openai_clip)
|
||||||
|
|
||||||
text_attention_final = self.find_layer('ln_final')
|
text_attention_final = self.find_layer('ln_final')
|
||||||
self.handle = text_attention_final.register_forward_hook(self._hook)
|
self.handle = text_attention_final.register_forward_hook(self._hook)
|
||||||
self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
self.clip_normalize = preprocess.transforms[-1]
|
||||||
self.cleared = False
|
self.cleared = False
|
||||||
|
|
||||||
def find_layer(self, layer):
|
def find_layer(self, layer):
|
||||||
|
|||||||
Reference in New Issue
Block a user