mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
no CLIP altogether for training DiffusionPrior
@@ -486,7 +486,10 @@ class DiffusionPrior(nn.Module):
         self,
         net,
         *,
-        clip,
+        clip = None,
+        image_embed_dim = None,
+        image_size = None,
+        image_channels = 3,
         timesteps = 1000,
         cond_drop_prob = 0.2,
         loss_type = "l1",
@@ -495,14 +498,18 @@ class DiffusionPrior(nn.Module):
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
     ):
         super().__init__()
-        assert isinstance(clip, CLIP)
-        freeze_model_and_make_eval_(clip)
-        self.clip = clip
+
+        if exists(clip):
+            assert isinstance(clip, CLIP)
+            freeze_model_and_make_eval_(clip)
+            self.clip = clip
+        else:
+            assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
+            self.clip = None

         self.net = net
-        self.image_embed_dim = clip.dim_latent
-        self.channels = clip.image_channels
-        self.image_size = clip.image_size
+        self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
+        self.channels = default(image_channels, lambda: clip.image_channels)

         self.cond_drop_prob = cond_drop_prob
         self.condition_on_text_encodings = condition_on_text_encodings
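With the constructor relaxed like this, a DiffusionPrior can in principle be built without any CLIP instance by supplying image_embed_dim directly. The sketch below is illustrative only: the DiffusionPriorNetwork hyperparameters and the 512-dim embedding size are assumptions rather than values taken from this commit, and it presumes no other part of the constructor still dereferences clip.

from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

# prior network operating directly on 512-dim image embeddings
# (dim / depth / heads here are assumed for illustration)
prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

# no `clip` passed in, so `image_embed_dim` must be supplied,
# otherwise the new assert on `image_embed_dim` fires
diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,
    timesteps = 1000,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = False  # no CLIP text encoder available to produce encodings
)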
@@ -559,6 +566,8 @@ class DiffusionPrior(nn.Module):

     @torch.no_grad()
     def get_image_embed(self, image):
+        assert exists(self.clip)
+
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
@@ -566,6 +575,8 @@ class DiffusionPrior(nn.Module):

     @torch.no_grad()
     def get_text_cond(self, text):
+        assert exists(self.clip)
+
         text_encodings = self.clip.text_transformer(text)
         text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
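The asserts added to get_image_embed and get_text_cond just make the CLIP-backed helpers fail fast instead of hitting an AttributeError on self.clip later; when no CLIP is attached, the caller is expected to feed the prior precomputed embeddings. A tiny continuation of the sketch above, purely to illustrate the guard:

import torch

images = torch.randn(4, 3, 256, 256)

try:
    # with clip = None this now trips `assert exists(self.clip)` immediately
    diffusion_prior.get_image_embed(images)
except AssertionError:
    print('no CLIP attached - compute image embeddings outside the prior')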