be transparent

This commit is contained in:
Phil Wang
2022-04-13 10:32:11 -07:00
parent 1bf071af78
commit 3aa6f91e7a
2 changed files with 7 additions and 2 deletions

View File

@@ -363,7 +363,12 @@ class DiffusionPrior(nn.Module):
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
if self.predict_x0:
x_recon = self.net(x, t, **text_cond)
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
if clip_denoised:
x_recon.clamp_(-1., 1.)