mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
fix ddim to use alpha_cumprod
This commit is contained in:
@@ -1259,7 +1259,7 @@ class DiffusionPrior(nn.Module):
|
||||
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
|
||||
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
|
||||
|
||||
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
|
||||
times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
|
||||
|
||||
times = list(reversed(times.int().tolist()))
|
||||
time_pairs = list(zip(times[:-1], times[1:]))
|
||||
@@ -1294,6 +1294,10 @@ class DiffusionPrior(nn.Module):
|
||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||
x_start = self.l2norm_clamp_embed(x_start)
|
||||
|
||||
if time_next < 0:
|
||||
image_embed = x_start
|
||||
continue
|
||||
|
||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
|
||||
@@ -2845,12 +2849,13 @@ class Decoder(nn.Module):
|
||||
inpaint_mask = None,
|
||||
inpaint_resample_times = 5
|
||||
):
|
||||
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
|
||||
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod, self.ddim_sampling_eta
|
||||
|
||||
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
|
||||
|
||||
times = list(reversed(times.int().tolist()))
|
||||
time_pairs = list(zip(times[:-1], times[1:]))
|
||||
time_pairs = list(filter(lambda t: t[0] > t[1], time_pairs))
|
||||
|
||||
is_inpaint = exists(inpaint_image)
|
||||
resample_times = inpaint_resample_times if is_inpaint else 1
|
||||
|
||||
Reference in New Issue
Block a user