Compare commits

...

2 Commits

3 changed files with 26 additions and 10 deletions

View File

@@ -820,8 +820,8 @@ clip = CLIP(
# mock data
text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (512, 256)).cuda()
images = torch.randn(512, 3, 256, 256).cuda()
# prior networks (with transformer)
@@ -854,7 +854,7 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
# after much of the above three lines in a loop
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
```
## Bonus
@@ -867,7 +867,7 @@ ex.
```python
import torch
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch import Unet, Decoder, DecoderTrainer
# unet for the cascading ddpm
@@ -890,20 +890,24 @@ decoder = Decoder(
unconditional = True
).cuda()
# mock images (get a lot of this)
# decoder trainer
decoder_trainer = DecoderTrainer(decoder)
# images (get a lot of this)
images = torch.randn(1, 3, 512, 512).cuda()
# feed images into decoder
for i in (1, 2):
loss = decoder(images, unet_number = i)
loss.backward()
loss = decoder_trainer(images, unet_number = i)
decoder_trainer.update(unet_number = i)
# do the above for many many many many steps
# do the above for many many many many images
# then it will learn to generate images
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
```
## Dataloaders

View File

@@ -235,6 +235,16 @@ class EMA(nn.Module):
# diffusion prior trainer
def prior_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
return torch.cat(outputs, dim = 0)
return inner
class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
@@ -295,11 +305,13 @@ class DiffusionPriorTrainer(nn.Module):
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.46',
version = '0.3.0',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',