DecoderTrainer sample method uses the exponentially moving averaged

This commit is contained in:
Phil Wang
2022-04-30 14:55:34 -07:00
parent 63195cc2cb
commit ebe01749ed
3 changed files with 24 additions and 2 deletions

View File

@@ -760,7 +760,7 @@ decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1,
timesteps = 1000,
condition_on_text_encodings = True
).cuda()
@@ -778,6 +778,12 @@ for unet_number in (1, 2):
loss.backward()
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
# after much training
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
## CLI (wip)

View File

@@ -144,6 +144,10 @@ class DecoderTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
@@ -169,6 +173,18 @@ class DecoderTrainer(nn.Module):
ema_unet = self.ema_unets[index]
ema_unet.update()
@torch.no_grad()
def sample(self, *args, **kwargs):
if self.use_ema:
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
return output
def forward(
self,
x,

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.80',
version = '0.0.81',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',