mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
prepare for ability to integrate other clips other than x-clip
This commit is contained in:
@@ -647,7 +647,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
||||||
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||||
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||||
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||||
|
|||||||
@@ -89,6 +89,59 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
|
|||||||
|
|
||||||
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
|
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
|
||||||
|
|
||||||
|
# clip related adapters
|
||||||
|
|
||||||
|
class BaseClipAdapter(nn.Module):
|
||||||
|
def __init__(self, clip):
|
||||||
|
super().__init__()
|
||||||
|
self.clip = clip
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dim_latent(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_size(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_channels(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def embed_text(self, text):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def embed_image(self, image):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class XClipAdapter(BaseClipAdapter):
|
||||||
|
@property
|
||||||
|
def dim_latent(self):
|
||||||
|
return self.clip.dim_latent
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_size(self):
|
||||||
|
return self.clip.image_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_channels(self):
|
||||||
|
return self.clip.image_channels
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def embed_text(self, text):
|
||||||
|
encoder_output = self.clip.text_transformer(text)
|
||||||
|
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||||
|
text_embed = self.clip.to_text_latent(text_cls)
|
||||||
|
return l2norm(text_embed), text_encodings
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def embed_image(self, image):
|
||||||
|
image = resize_image_to(image, self.image_size)
|
||||||
|
encoder_output = self.clip.visual_transformer(image)
|
||||||
|
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
|
||||||
|
image_embed = self.clip.to_visual_latent(image_cls)
|
||||||
|
return l2norm(image_embed), image_encodings
|
||||||
|
|
||||||
# classifier free guidance functions
|
# classifier free guidance functions
|
||||||
|
|
||||||
def prob_mask_like(shape, prob, device):
|
def prob_mask_like(shape, prob, device):
|
||||||
@@ -595,7 +648,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
if exists(clip):
|
if exists(clip):
|
||||||
assert isinstance(clip, CLIP)
|
assert isinstance(clip, CLIP)
|
||||||
freeze_model_and_make_eval_(clip)
|
freeze_model_and_make_eval_(clip)
|
||||||
self.clip = clip
|
self.clip = XClipAdapter(clip)
|
||||||
else:
|
else:
|
||||||
assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
|
assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
|
||||||
self.clip = None
|
self.clip = None
|
||||||
@@ -610,29 +663,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
self.predict_x_start = predict_x_start
|
self.predict_x_start = predict_x_start
|
||||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def get_image_embed(self, image):
|
|
||||||
assert exists(self.clip)
|
|
||||||
|
|
||||||
image_encoding = self.clip.visual_transformer(image)
|
|
||||||
image_cls = image_encoding[:, 0]
|
|
||||||
image_embed = self.clip.to_visual_latent(image_cls)
|
|
||||||
return l2norm(image_embed)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def get_text_cond(self, text):
|
|
||||||
assert exists(self.clip)
|
|
||||||
|
|
||||||
text_encodings = self.clip.text_transformer(text)
|
|
||||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
|
||||||
text_embed = self.clip.to_text_latent(text_cls)
|
|
||||||
text_embed = l2norm(text_embed)
|
|
||||||
|
|
||||||
if not self.condition_on_text_encodings:
|
|
||||||
return dict(text_embed = text_embed)
|
|
||||||
|
|
||||||
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
|
|
||||||
|
|
||||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||||
pred = self.net(x, t, **text_cond)
|
pred = self.net(x, t, **text_cond)
|
||||||
|
|
||||||
@@ -704,7 +734,13 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
batch_size = text.shape[0]
|
batch_size = text.shape[0]
|
||||||
image_embed_dim = self.image_embed_dim
|
image_embed_dim = self.image_embed_dim
|
||||||
|
|
||||||
text_cond = self.get_text_cond(text)
|
text_embed, text_encodings = self.clip.embed_text(text)
|
||||||
|
|
||||||
|
text_cond = dict(
|
||||||
|
text_embed = text_embed,
|
||||||
|
text_encodings = text_encodings,
|
||||||
|
mask = text != 0
|
||||||
|
)
|
||||||
|
|
||||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||||
text_embeds = text_cond['text_embed']
|
text_embeds = text_cond['text_embed']
|
||||||
@@ -736,18 +772,19 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
||||||
|
|
||||||
if exists(image):
|
if exists(image):
|
||||||
image_embed = self.get_image_embed(image)
|
image_embed, _ = self.clip.embed_image(image)
|
||||||
|
|
||||||
# calculate text conditionings, based on what is passed in
|
# calculate text conditionings, based on what is passed in
|
||||||
|
|
||||||
if exists(text):
|
if exists(text):
|
||||||
text_cond = self.get_text_cond(text)
|
text_embed, text_encodings = self.clip.embed_text(text)
|
||||||
else:
|
text_mask = text != 0
|
||||||
text_cond = dict(
|
|
||||||
text_embed = text_embed,
|
text_cond = dict(
|
||||||
text_encodings = text_encodings,
|
text_embed = text_embed,
|
||||||
mask = text_mask
|
text_encodings = text_encodings,
|
||||||
)
|
mask = text_mask
|
||||||
|
)
|
||||||
|
|
||||||
# timestep conditioning from ddpm
|
# timestep conditioning from ddpm
|
||||||
|
|
||||||
@@ -1208,7 +1245,9 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
loss_type = loss_type
|
loss_type = loss_type
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(clip, CLIP)
|
if isinstance(clip, CLIP):
|
||||||
|
clip = XClipAdapter(clip)
|
||||||
|
|
||||||
freeze_model_and_make_eval_(clip)
|
freeze_model_and_make_eval_(clip)
|
||||||
self.clip = clip
|
self.clip = clip
|
||||||
self.clip_image_size = clip.image_size
|
self.clip_image_size = clip.image_size
|
||||||
@@ -1290,10 +1329,6 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
yield
|
yield
|
||||||
unet.cpu()
|
unet.cpu()
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def get_text_encodings(self, text):
|
|
||||||
text_encodings = self.clip.text_transformer(text)
|
|
||||||
return text_encodings[:, 1:]
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_image_embed(self, image):
|
def get_image_embed(self, image):
|
||||||
@@ -1379,7 +1414,9 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||||
batch_size = image_embed.shape[0]
|
batch_size = image_embed.shape[0]
|
||||||
|
|
||||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
text_encodings = None
|
||||||
|
if exists(text):
|
||||||
|
_, text_encodings = self.clip.embed_text(text)
|
||||||
|
|
||||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||||
|
|
||||||
@@ -1442,9 +1479,11 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||||
|
|
||||||
if not exists(image_embed):
|
if not exists(image_embed):
|
||||||
image_embed = self.get_image_embed(image)
|
image_embed, _ = self.clip.embed_image(image)
|
||||||
|
|
||||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
text_encodings = None
|
||||||
|
if exists(text) and not exists(text_encodings):
|
||||||
|
_, text_encodings = self.clip.embed_text(text)
|
||||||
|
|
||||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user