no CLIP altogether for training DiffusionPrior

This commit is contained in:
Phil Wang
2022-04-26 10:23:34 -07:00
parent bdf5e9c009
commit c30544b73a
3 changed files with 68 additions and 8 deletions

View File

@@ -486,7 +486,10 @@ class DiffusionPrior(nn.Module):
self,
net,
*,
clip,
clip = None,
image_embed_dim = None,
image_size = None,
image_channels = 3,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = "l1",
@@ -495,14 +498,18 @@ class DiffusionPrior(nn.Module):
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
if exists(clip):
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
else:
assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
self.clip = None
self.net = net
self.image_embed_dim = clip.dim_latent
self.channels = clip.image_channels
self.image_size = clip.image_size
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
self.channels = default(image_channels, lambda: clip.image_channels)
self.cond_drop_prob = cond_drop_prob
self.condition_on_text_encodings = condition_on_text_encodings
@@ -559,6 +566,8 @@ class DiffusionPrior(nn.Module):
@torch.no_grad()
def get_image_embed(self, image):
assert exists(self.clip)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
@@ -566,6 +575,8 @@ class DiffusionPrior(nn.Module):
@torch.no_grad()
def get_text_cond(self, text):
assert exists(self.clip)
text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)