Compare commits


1 commit
0.0.6 ... 0.0.4

3 changed files with 7 additions and 9 deletions

README.md

@@ -212,7 +212,10 @@ Let's see the whole script below
 ```python
 import torch
-from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
+from dalle2_pytorch.dalle2_pytorch import DALLE2
+from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
+import torch
 
 clip = CLIP(
     dim_text = 512,
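The 0.0.4 side pulls DALLE2 from the inner dalle2_pytorch.dalle2_pytorch module while the other names stay on the package root (and ends up importing torch twice). Assuming the package's __init__.py re-exports DALLE2 from that submodule, which is not verified against this exact tag, both spellings resolve to the same class; a minimal sketch:

```python
# minimal sketch, assuming dalle2_pytorch/__init__.py re-exports DALLE2
# from dalle2_pytorch.dalle2_pytorch (an assumption, not confirmed here)
from dalle2_pytorch import DALLE2 as dalle2_via_package
from dalle2_pytorch.dalle2_pytorch import DALLE2 as dalle2_via_module

# if the re-export holds, both names point at the same class object
assert dalle2_via_package is dalle2_via_module
```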
@@ -301,8 +304,6 @@ images = dalle2(['cute puppy chasing after a squirrel'])
 Everything in this readme should run without error
 
-For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
-
 ## Training CLI (wip)
 
 <a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
@@ -364,5 +365,3 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
     primaryClass = {cs.LG}
 }
 ```
-
-*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

dalle2_pytorch/dalle2_pytorch.py

@@ -374,13 +374,12 @@ class DiffusionPrior(nn.Module):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
-        return l2norm(image_embed)
+        return image_embed
 
     def get_text_cond(self, text):
         text_encodings = self.clip.text_transformer(text)
         text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
-        text_embed = l2norm(text_embed)
         return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
 
     def q_mean_variance(self, x_start, t):
@@ -751,7 +750,7 @@ class Decoder(nn.Module):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
-        return l2norm(image_embed)
+        return image_embed
 
     def q_mean_variance(self, x_start, t):
         mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
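The only behavioral change in this file is dropping the final L2 normalization of the CLIP image and text embeddings, so the 0.0.4 side returns the raw latents. In dalle2-pytorch, l2norm is conventionally a thin wrapper around F.normalize over the last dimension; a minimal sketch of what the removed calls were doing, assuming that usual definition:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # rescale each vector along the last dimension to unit L2 norm
    return F.normalize(t, dim = -1)

embeds = torch.randn(4, 512)     # e.g. a batch of CLIP image embeddings
normed = l2norm(embeds)
print(normed.norm(dim = -1))     # ~1.0 for every row
```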

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.5',
+  version = '0.0.4',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
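To check which side of this compare is actually installed, the package metadata can be queried; a hypothetical check, assuming the distribution is published as dalle2-pytorch:

```python
# hypothetical version check; 'dalle2-pytorch' is the assumed distribution name
from importlib.metadata import version

print(version('dalle2-pytorch'))  # '0.0.4' on the right-hand side of this compare
```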