mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-21 06:44:21 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9646dfc0e6 | ||
|
|
62043acb2f |
@@ -2589,7 +2589,7 @@ class Decoder(nn.Module):
|
|||||||
if is_inpaint and not (is_last_timestep or is_last_resample_step):
|
if is_inpaint and not (is_last_timestep or is_last_resample_step):
|
||||||
# in repaint, you renoise and resample up to 10 times every step
|
# in repaint, you renoise and resample up to 10 times every step
|
||||||
time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
|
time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
|
||||||
img = noise_scheduler.q_sample_from_to(img, time_cond, time_next_cond)
|
img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)
|
||||||
|
|
||||||
if exists(inpaint_image):
|
if exists(inpaint_image):
|
||||||
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
||||||
|
|||||||
@@ -300,7 +300,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
# all processes need to load checkpoint. no restriction here
|
# all processes need to load checkpoint. no restriction here
|
||||||
if isinstance(path_or_state, str):
|
if isinstance(path_or_state, str):
|
||||||
path = Path(path)
|
path = Path(path_or_state)
|
||||||
assert path.exists()
|
assert path.exists()
|
||||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.0.4'
|
__version__ = '1.0.6'
|
||||||
|
|||||||
Reference in New Issue
Block a user