Compare commits

..

3 Commits

3 changed files with 7 additions and 3 deletions

View File

@@ -1537,10 +1537,12 @@ class Unet(nn.Module):
# text encoding conditioning (optional)
self.text_to_cond = None
self.text_embed_dim = None
if cond_on_text_encodings:
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
self.text_embed_dim = text_embed_dim
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1769,6 +1771,8 @@ class Unet(nn.Module):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]
@@ -2043,7 +2047,7 @@ class Decoder(nn.Module):
self.noise_schedulers = nn.ModuleList([])
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
assert sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,

View File

@@ -234,7 +234,7 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sampling_timesteps: Optional[SingularOrIterable(int)] = None
sample_timesteps: Optional[SingularOrIterable(int)] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True

View File

@@ -1 +1 @@
__version__ = '0.19.0'
__version__ = '0.19.3'