mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9646dfc0e6 | ||
|
|
62043acb2f |
@@ -2589,7 +2589,7 @@ class Decoder(nn.Module):
|
||||
if is_inpaint and not (is_last_timestep or is_last_resample_step):
|
||||
# in repaint, you renoise and resample up to 10 times every step
|
||||
time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
|
||||
img = noise_scheduler.q_sample_from_to(img, time_cond, time_next_cond)
|
||||
img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)
|
||||
|
||||
if exists(inpaint_image):
|
||||
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
||||
|
||||
@@ -300,7 +300,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
# all processes need to load checkpoint. no restriction here
|
||||
if isinstance(path_or_state, str):
|
||||
path = Path(path)
|
||||
path = Path(path_or_state)
|
||||
assert path.exists()
|
||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.0.4'
|
||||
__version__ = '1.0.6'
|
||||
|
||||
Reference in New Issue
Block a user