Compare commits

...

3 Commits

4 changed files with 44 additions and 6 deletions

View File

@@ -760,7 +760,7 @@ decoder = Decoder(
unet = (unet1, unet2), unet = (unet1, unet2),
image_sizes = (128, 256), image_sizes = (128, 256),
clip = clip, clip = clip,
timesteps = 1, timesteps = 1000,
condition_on_text_encodings = True condition_on_text_encodings = True
).cuda() ).cuda()
@@ -778,6 +778,12 @@ for unet_number in (1, 2):
loss.backward() loss.backward()
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
# after much training
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
``` ```
## CLI (wip) ## CLI (wip)

View File

@@ -1540,7 +1540,13 @@ class Decoder(BaseGaussianDiffusion):
@torch.no_grad() @torch.no_grad()
@eval_decorator @eval_decorator
def sample(self, image_embed, text = None, cond_scale = 1.): def sample(
self,
image_embed,
text = None,
cond_scale = 1.,
stop_at_unet_number = None
):
batch_size = image_embed.shape[0] batch_size = image_embed.shape[0]
text_encodings = text_mask = None text_encodings = text_mask = None
@@ -1552,7 +1558,7 @@ class Decoder(BaseGaussianDiffusion):
img = None img = None
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)): for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context() context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
@@ -1584,6 +1590,9 @@ class Decoder(BaseGaussianDiffusion):
img = vae.decode(img) img = vae.decode(img)
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
break
return img return img
def forward( def forward(

View File

@@ -144,6 +144,10 @@ class DecoderTrainer(nn.Module):
self.max_grad_norm = max_grad_norm self.max_grad_norm = max_grad_norm
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def scale(self, loss, *, unet_number): def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets assert 1 <= unet_number <= self.num_unets
index = unet_number - 1 index = unet_number - 1
@@ -169,7 +173,26 @@ class DecoderTrainer(nn.Module):
ema_unet = self.ema_unets[index] ema_unet = self.ema_unets[index]
ema_unet.update() ema_unet.update()
def forward(self, x, *, unet_number, **kwargs): @torch.no_grad()
def sample(self, *args, **kwargs):
if self.use_ema:
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
return output
def forward(
self,
x,
*,
unet_number,
divisor = 1,
**kwargs
):
with autocast(enabled = self.amp): with autocast(enabled = self.amp):
loss = self.decoder(x, unet_number = unet_number, **kwargs) loss = self.decoder(x, unet_number = unet_number, **kwargs)
return self.scale(loss, unet_number = unet_number) return self.scale(loss / divisor, unet_number = unet_number)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.79', version = '0.0.82',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',