@@ -100,6 +100,9 @@ def eval_decorator(fn):
         return out
     return inner
 
+def is_float_dtype(dtype):
+    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
+
 def is_list_str(x):
     if not isinstance(x, (list, tuple)):
         return False
@@ -386,6 +389,8 @@ class OpenClipAdapter(BaseClipAdapter):
         self.eos_id = 49407
 
         text_attention_final = self.find_layer('ln_final')
+        self._dim_latent = text_attention_final.weight.shape[0]
+
         self.handle = text_attention_final.register_forward_hook(self._hook)
         self.clip_normalize = preprocess.transforms[-1]
         self.cleared = False
@@ -405,7 +410,7 @@ class OpenClipAdapter(BaseClipAdapter):
 
     @property
     def dim_latent(self):
-        return 512
+        return self._dim_latent
 
     @property
     def image_size(self):
@@ -968,6 +973,8 @@ class DiffusionPriorNetwork(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_text_embeds)
         )
 
+        self.continuous_embedded_time = not exists(num_timesteps)
+
         self.to_time_embeds = nn.Sequential(
             nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
             Rearrange('b (n d) -> b n d', n = num_time_embeds)
@@ -1095,6 +1102,9 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
+        if self.continuous_embedded_time:
+            diffusion_timesteps = diffusion_timesteps.type(dtype)
+
         time_embed = self.to_time_embeds(diffusion_timesteps)
 
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -1432,7 +1442,7 @@ class DiffusionPrior(nn.Module):
         **kwargs
     ):
         assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
-        assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
+        assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
         assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
 
         if exists(image):
@@ -1538,6 +1548,8 @@ class SinusoidalPosEmb(nn.Module):
 
     def forward(self, x):
         dtype, device = x.dtype, x.device
+        assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
+
         half_dim = self.dim // 2
         emb = math.log(10000) / (half_dim - 1)
         emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
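A minimal sketch (not part of the diff) of why the new is_float_dtype guard and the .type(dtype) cast matter: SinusoidalPosEmb builds its frequency table with torch.exp over a tensor created in the input's dtype, and torch.exp is not defined for integer tensors, so raw integer DDPM timesteps must be cast to a float dtype before the continuous embedding path runs. The SinusoidalPosEmb body below is abbreviated and uses the standard sin/cos formulation; only the lines shown in the hunks above are taken from the diff itself.

    import math
    import torch
    from torch import nn

    def is_float_dtype(dtype):
        # same helper the diff adds at the top of dalle2_pytorch.py
        return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

    class SinusoidalPosEmb(nn.Module):
        # abbreviated sketch mirroring the lines in the final hunk
        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, x):
            dtype, device = x.dtype, x.device
            assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
            half_dim = self.dim // 2
            emb = math.log(10000) / (half_dim - 1)
            emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
            emb = x[:, None] * emb[None, :]
            return torch.cat((emb.sin(), emb.cos()), dim = -1)

    timesteps = torch.randint(0, 1000, (4,))   # integer timesteps, as sampled in a diffusion training loop
    pos_emb = SinusoidalPosEmb(16)
    time_embed = pos_emb(timesteps.float())    # cast first, as the new `.type(dtype)` line in DiffusionPriorNetwork does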