mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
DecoderTrainer sample method uses the exponentially moving averaged
This commit is contained in:
@@ -760,7 +760,7 @@ decoder = Decoder(
|
|||||||
unet = (unet1, unet2),
|
unet = (unet1, unet2),
|
||||||
image_sizes = (128, 256),
|
image_sizes = (128, 256),
|
||||||
clip = clip,
|
clip = clip,
|
||||||
timesteps = 1,
|
timesteps = 1000,
|
||||||
condition_on_text_encodings = True
|
condition_on_text_encodings = True
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
@@ -778,6 +778,12 @@ for unet_number in (1, 2):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||||
|
|
||||||
|
# after much training
|
||||||
|
# you can sample from the exponentially moving averaged unets as so
|
||||||
|
|
||||||
|
mock_image_embed = torch.randn(4, 512).cuda()
|
||||||
|
images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||||
```
|
```
|
||||||
|
|
||||||
## CLI (wip)
|
## CLI (wip)
|
||||||
|
|||||||
@@ -144,6 +144,10 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unets(self):
|
||||||
|
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||||
|
|
||||||
def scale(self, loss, *, unet_number):
|
def scale(self, loss, *, unet_number):
|
||||||
assert 1 <= unet_number <= self.num_unets
|
assert 1 <= unet_number <= self.num_unets
|
||||||
index = unet_number - 1
|
index = unet_number - 1
|
||||||
@@ -169,6 +173,18 @@ class DecoderTrainer(nn.Module):
|
|||||||
ema_unet = self.ema_unets[index]
|
ema_unet = self.ema_unets[index]
|
||||||
ema_unet.update()
|
ema_unet.update()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self, *args, **kwargs):
|
||||||
|
if self.use_ema:
|
||||||
|
trainable_unets = self.decoder.unets
|
||||||
|
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||||
|
|
||||||
|
output = self.decoder.sample(*args, **kwargs)
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
self.decoder.unets = trainable_unets # restore original training unets
|
||||||
|
return output
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
|
|||||||
Reference in New Issue
Block a user