mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 15:44:20 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
301a97197f |
@@ -1004,9 +1004,9 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
|
|
||||||
# setup self conditioning
|
# setup self conditioning
|
||||||
|
|
||||||
self_cond = None
|
|
||||||
if self.self_cond:
|
if self.self_cond:
|
||||||
self_cond = default(self_cond, lambda: torch.zeros(batch, 1, self.dim, device = device, dtype = dtype))
|
self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
|
||||||
|
self_cond = rearrange(self_cond, 'b d -> b 1 d')
|
||||||
|
|
||||||
# in section 2.2, last paragraph
|
# in section 2.2, last paragraph
|
||||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||||
@@ -1287,7 +1287,7 @@ class DiffusionPrior(nn.Module):
|
|||||||
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
|
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
|
||||||
|
|
||||||
self_cond = None
|
self_cond = None
|
||||||
if self.net.self_cond and random.random() < 0.5:
|
if self.net.self_cond and random.random() < 1.5:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
|
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.6.1'
|
__version__ = '1.6.2'
|
||||||
|
|||||||
Reference in New Issue
Block a user