mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 23:44:22 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
85ed77d512 | ||
|
|
fd53fa17db |
@@ -902,7 +902,7 @@ Please note that the script internally passes text_embed and image_embed to the
|
|||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ pyhon train_diffusion_prior.py
|
$ python train_diffusion_prior.py
|
||||||
```
|
```
|
||||||
|
|
||||||
The most significant parameters for the script are as follows:
|
The most significant parameters for the script are as follows:
|
||||||
|
|||||||
@@ -765,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
# but let's just do it right
|
# but let's just do it right
|
||||||
|
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||||
|
|
||||||
time_embed = self.time_embeddings(diffusion_timesteps)
|
time_embed = self.time_embeddings(diffusion_timesteps)
|
||||||
time_embed = rearrange(time_embed, 'b d -> b 1 d')
|
time_embed = rearrange(time_embed, 'b d -> b 1 d')
|
||||||
@@ -776,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
text_encodings,
|
text_encodings,
|
||||||
text_embed,
|
text_embed,
|
||||||
time_embed,
|
time_embed,
|
||||||
|
image_embed,
|
||||||
learned_queries
|
learned_queries
|
||||||
), dim = -2)
|
), dim = -2)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user