Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-20 22:34:21 +01:00

Compare commits (4 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 0692f1699f | |
| | 26c4534bc3 | |
| | 5e06cde4cb | |
| | a1a8a78f21 | |
@@ -212,10 +212,7 @@ Let's see the whole script below
 
 ```python
 import torch
-from dalle2_pytorch.dalle2_pytorch import DALLE2
-from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
+from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
 
-import torch
-
 clip = CLIP(
     dim_text = 512,
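
For readability, here is the import block the + side of this hunk leaves the README example with: everything, including DALLE2 itself, now comes from the package's top level, and the duplicate `import torch` is gone.

```python
# imports as they stand after this change (top-level package exports only)
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
```
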
@@ -304,6 +301,8 @@ images = dalle2(['cute puppy chasing after a squirrel'])
 
 Everything in this readme should run without error
 
+For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
+
 ## Training CLI (wip)
 
 <a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
@@ -365,3 +364,5 @@ Everything in this readme should run without error
     primaryClass = {cs.LG}
 }
 ```
+
+*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
@@ -374,12 +374,13 @@ class DiffusionPrior(nn.Module):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
-        return image_embed
+        return l2norm(image_embed)
 
     def get_text_cond(self, text):
         text_encodings = self.clip.text_transformer(text)
         text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
+        text_embed = l2norm(text_embed)
        return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
 
     def q_mean_variance(self, x_start, t):
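
The new return value calls an `l2norm` helper that is not part of this hunk. A minimal sketch of such a helper, assuming it simply L2-normalizes along the last (feature) dimension via `torch.nn.functional.normalize`, and not necessarily the repository's exact implementation:

```python
import torch.nn.functional as F

def l2norm(t):
    # scale each embedding vector to unit L2 norm along the feature dimension
    return F.normalize(t, dim = -1)
```
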
@@ -750,7 +751,7 @@ class Decoder(nn.Module):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
-        return image_embed
+        return l2norm(image_embed)
 
     def q_mean_variance(self, x_start, t):
         mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
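
Both the DiffusionPrior and Decoder hunks make `get_image_embed` return `l2norm(image_embed)`, so downstream code can rely on unit-length embeddings. A small self-contained check of that property, using a random stand-in tensor instead of a real CLIP output (the shapes here are illustrative, not the repository's):

```python
import torch
import torch.nn.functional as F

# stand-in batch of image embeddings (batch of 4, embedding dim 512 as in the README example)
image_embed = torch.randn(4, 512)

# l2norm-style normalization puts every embedding on the unit hypersphere
normalized = F.normalize(image_embed, dim = -1)
assert torch.allclose(normalized.norm(dim = -1), torch.ones(4), atol = 1e-5)
```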