mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
85ed77d512 | ||
|
|
fd53fa17db |
@@ -902,7 +902,7 @@ Please note that the script internally passes text_embed and image_embed to the
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
$ pyhon train_diffusion_prior.py
|
||||
$ python train_diffusion_prior.py
|
||||
```
|
||||
|
||||
The most significant parameters for the script are as follows:
|
||||
|
||||
@@ -765,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.time_embeddings(diffusion_timesteps)
|
||||
time_embed = rearrange(time_embed, 'b d -> b 1 d')
|
||||
@@ -776,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
text_encodings,
|
||||
text_embed,
|
||||
time_embed,
|
||||
image_embed,
|
||||
learned_queries
|
||||
), dim = -2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user