mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
provide option to l2norm the output of the diffusion prior
This commit is contained in:
@@ -29,6 +29,9 @@ from x_clip import CLIP
|
|||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
|
def identity(t, *args, **kwargs):
|
||||||
|
return t
|
||||||
|
|
||||||
def default(val, d):
|
def default(val, d):
|
||||||
if exists(val):
|
if exists(val):
|
||||||
return val
|
return val
|
||||||
@@ -635,12 +638,14 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
num_timesteps = None,
|
num_timesteps = None,
|
||||||
|
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||||
|
self.l2norm_output = l2norm_output
|
||||||
|
|
||||||
def forward_with_cond_scale(
|
def forward_with_cond_scale(
|
||||||
self,
|
self,
|
||||||
@@ -719,7 +724,8 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
|
|
||||||
pred_image_embed = tokens[..., -1, :]
|
pred_image_embed = tokens[..., -1, :]
|
||||||
|
|
||||||
return pred_image_embed
|
output_fn = l2norm if self.l2norm_output else identity
|
||||||
|
return output_fn(pred_image_embed)
|
||||||
|
|
||||||
class DiffusionPrior(BaseGaussianDiffusion):
|
class DiffusionPrior(BaseGaussianDiffusion):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user