mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-14 15:44:20 +01:00)
Compare commits

3 Commits

| Author | SHA1 | Date |
|---|---|---|
| | c7ea8748db | |
| | 13382885d9 | |
| | c3d4a7ffe4 | |
README.md (22 changed lines)
@@ -820,8 +820,8 @@ clip = CLIP(
 
 # mock data
 
-text = torch.randint(0, 49408, (32, 256)).cuda()
-images = torch.randn(32, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (512, 256)).cuda()
+images = torch.randn(512, 3, 256, 256).cuda()
 
 # prior networks (with transformer)
 
@@ -854,7 +854,7 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
 # after much of the above three lines in a loop
 # you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
 
-image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
+image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
 ```
 
 ## Bonus
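The `max_batch_size` argument added here only bounds peak memory during sampling; the result is identical to sampling the full batch at once. A minimal sketch of the semantics, with a hypothetical `sample_fn` standing in for the EMA prior's sampler:

```python
import torch

def sample_in_chunks(sample_fn, text, max_batch_size = None):
    # no cap given: sample the whole batch in one call
    if max_batch_size is None:
        return sample_fn(text)

    # cap given: sample max_batch_size rows at a time, then concatenate
    chunks = [sample_fn(chunk) for chunk in text.split(max_batch_size, dim = 0)]
    return torch.cat(chunks, dim = 0)

# with text of shape (512, 256) and max_batch_size = 4, this makes
# 128 sampler calls and still returns embeddings of shape (512, 512)
```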
@@ -867,7 +867,7 @@ ex.
 
 ```python
 import torch
-from dalle2_pytorch import Unet, Decoder
+from dalle2_pytorch import Unet, Decoder, DecoderTrainer
 
 # unet for the cascading ddpm
 
@@ -890,20 +890,24 @@ decoder = Decoder(
     unconditional = True
 ).cuda()
 
-# mock images (get a lot of this)
+# decoder trainer
+
+decoder_trainer = DecoderTrainer(decoder)
+
+# images (get a lot of this)
 
 images = torch.randn(1, 3, 512, 512).cuda()
 
 # feed images into decoder
 
 for i in (1, 2):
-    loss = decoder(images, unet_number = i)
-    loss.backward()
+    loss = decoder_trainer(images, unet_number = i)
+    decoder_trainer.update(unet_number = i)
 
-# do the above for many many many many steps
+# do the above for many many many many images
 # then it will learn to generate images
 
-images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
+images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
 ```
 
 ## Dataloaders
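Note what the trainer absorbs: the explicit `loss.backward()` disappears from the loop, and `update(unet_number = i)` takes over the optimizer step for the chosen unet. A rough sketch of this wrapper pattern, assuming one optimizer per unet; the names and internals here are illustrative, not the library's actual implementation:

```python
import torch
from torch import nn

class TrainerSketch(nn.Module):
    # illustrative only: wraps a model whose forward returns a loss,
    # keeping one optimizer per trainable sub-module
    def __init__(self, model, submodules, lr = 1e-4, wd = 1e-2):
        super().__init__()
        self.model = model
        self.optimizers = [
            torch.optim.AdamW(m.parameters(), lr = lr, weight_decay = wd)
            for m in submodules
        ]

    def forward(self, *args, unet_number = 1, **kwargs):
        loss = self.model(*args, unet_number = unet_number, **kwargs)
        loss.backward()  # backprop here, so the caller no longer has to
        return loss

    def update(self, unet_number = 1):
        opt = self.optimizers[unet_number - 1]  # unet_number is 1-indexed
        opt.step()
        opt.zero_grad()
```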
dalle2_pytorch/optimizer.py

@@ -7,7 +7,7 @@ def separate_weight_decayable_params(params):
 
 def get_optimizer(
     params,
-    lr = 2e-5,
+    lr = 1e-4,
     wd = 1e-2,
     betas = (0.9, 0.999),
     eps = 1e-8,
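The default learning rate moves from 2e-5 to 1e-4, matching the same change to `DecoderTrainer` below. The body of `get_optimizer` is not shown in this hunk; a plausible sketch of such a helper, given only the signature and the `separate_weight_decayable_params` context line:

```python
import torch

def get_optimizer_sketch(params, lr = 1e-4, wd = 1e-2, betas = (0.9, 0.999), eps = 1e-8):
    # sketch only: with weight decay enabled, AdamW applies decoupled decay;
    # without it, plain Adam suffices
    if wd > 0:
        return torch.optim.AdamW(params, lr = lr, betas = betas, eps = eps, weight_decay = wd)
    return torch.optim.Adam(params, lr = lr, betas = betas, eps = eps)
```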
dalle2_pytorch/trainer.py

@@ -235,6 +235,16 @@ class EMA(nn.Module):
 
 # diffusion prior trainer
 
+def prior_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,
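The decorator is a no-op unless `max_batch_size` is passed. Otherwise it leans on `split_args_and_kwargs`, a helper from the same codebase whose exact contract is not shown in this diff; judging from the call site, it must split every batched argument along dim 0 into `split_size` chunks and yield one `(args, kwargs)` pair per chunk. A self-contained sketch of that assumed behavior:

```python
import torch

def split_args_and_kwargs_sketch(*args, split_size, **kwargs):
    # assumed behavior, inferred from the call site above: split tensor
    # arguments along the batch dimension, pass non-tensors through unchanged
    tensors = [a for a in (*args, *kwargs.values()) if torch.is_tensor(a)]
    batch = tensors[0].shape[0]

    for start in range(0, batch, split_size):
        sl = slice(start, start + split_size)
        frac = min(split_size, batch - start) / batch  # the ignored first element
        chunked_args = tuple(a[sl] if torch.is_tensor(a) else a for a in args)
        chunked_kwargs = {k: v[sl] if torch.is_tensor(v) else v for k, v in kwargs.items()}
        yield frac, (chunked_args, chunked_kwargs)
```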
@@ -295,11 +305,13 @@ class DiffusionPriorTrainer(nn.Module):
 
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
 
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
 
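Because `@prior_sample_in_chunks` sits closest to each method, decorators apply bottom-up: chunking happens after `@cast_torch_tensor` has coerced the inputs to tensors and inside the no-grad context. Both EMA sampling entry points then accept the keyword transparently, which is exactly what the README change above exercises:

```python
# hypothetical call pattern after this change: both lines return the same
# embeddings, the second just holds at most 4 samples in the sampler at a time
image_embeds = diffusion_prior_trainer.sample(text)
image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4)
```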
@@ -351,7 +363,7 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         use_ema = True,
-        lr = 2e-5,
+        lr = 1e-4,
         wd = 1e-2,
         eps = 1e-8,
         max_grad_norm = None,