mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
final update to dalle2 repository for a while - sampling from prior in chunks automatically with max_batch_size keyword given
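In short, the trainer's prior sampling calls now accept an optional max_batch_size keyword: the input batch is split into chunks of that size before sampling and the resulting embeddings are concatenated. A minimal sketch of the intended call, mirroring the README hunk below and assuming a trained DiffusionPriorTrainer named diffusion_prior_trainer is already in scope:

```python
import torch

# assumes `diffusion_prior_trainer` is an already-constructed, trained DiffusionPriorTrainer
text = torch.randint(0, 49408, (512, 256)).cuda()

# omit max_batch_size to push all 512 prompts through the prior's sampling loop at once;
# with max_batch_size = 4 it runs on chunks of 4 and the outputs are concatenated back together
image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4)  # (512, 512)
```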
@@ -820,8 +820,8 @@ clip = CLIP(
 
 # mock data
 
-text = torch.randint(0, 49408, (32, 256)).cuda()
-images = torch.randn(32, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (512, 256)).cuda()
+images = torch.randn(512, 3, 256, 256).cuda()
 
 # prior networks (with transformer)
 
@@ -854,7 +854,7 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
 # after much of the above three lines in a loop
 # you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
 
-image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
+image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
 ```
 
 ## Bonus
@@ -235,6 +235,16 @@ class EMA(nn.Module):
 
 # diffusion prior trainer
 
+def prior_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,
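The decorator added above leans on split_args_and_kwargs, a helper that already exists elsewhere in the codebase and is not part of this diff. The following is only a guess at its behavior, inferred from the call site (it appears to yield (fraction_of_batch, (chunked_args, chunked_kwargs)) tuples, of which the decorator uses only the second element); the repo's actual implementation may differ:

```python
import torch

def split_args_and_kwargs(*args, split_size = None, **kwargs):
    # illustrative reimplementation inferred from the call site above, not the repo's helper
    all_values = (*args, *kwargs.values())
    batch_size = next(v.shape[0] for v in all_values if torch.is_tensor(v))

    for start in range(0, batch_size, split_size):
        end = min(start + split_size, batch_size)

        def take(value):
            # slice tensors along the batch dimension, pass everything else through untouched
            return value[start:end] if torch.is_tensor(value) else value

        chunked_args = tuple(take(v) for v in args)
        chunked_kwargs = {k: take(v) for k, v in kwargs.items()}

        # the decorator unpacks (fraction, (args, kwargs)) and discards the fraction
        yield (end - start) / batch_size, (chunked_args, chunked_kwargs)
```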
@@ -295,11 +305,13 @@ class DiffusionPriorTrainer(nn.Module):
 
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
    def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
 
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
 
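As a sanity check of the pattern in isolation, here is a self-contained toy version of the chunked-sampling decorator. It swaps the repo's exists and split_args_and_kwargs helpers for a plain None check and tensor.split, so it illustrates the idea rather than reproducing the trainer's code:

```python
import torch
from functools import wraps

def sample_in_chunks(fn):
    # same shape as prior_sample_in_chunks, but chunking only the first tensor argument
    @wraps(fn)
    def inner(self, x, *args, max_batch_size = None, **kwargs):
        if max_batch_size is None:
            return fn(self, x, *args, **kwargs)

        outputs = [fn(self, chunk, *args, **kwargs) for chunk in x.split(max_batch_size, dim = 0)]
        return torch.cat(outputs, dim = 0)
    return inner

class ToySampler:
    @sample_in_chunks
    def sample(self, text):
        # stand-in for the diffusion prior: one fake 512-dim image embedding per prompt
        return torch.randn(text.shape[0], 512)

text = torch.randint(0, 49408, (512, 256))
image_embeds = ToySampler().sample(text, max_batch_size = 4)
print(image_embeds.shape)  # torch.Size([512, 512])
```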