mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8260fc933a | ||
|
|
ebe01749ed | ||
|
|
63195cc2cb |
@@ -760,7 +760,7 @@ decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 1,
|
||||
timesteps = 1000,
|
||||
condition_on_text_encodings = True
|
||||
).cuda()
|
||||
|
||||
@@ -778,6 +778,12 @@ for unet_number in (1, 2):
|
||||
loss.backward()
|
||||
|
||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||
|
||||
# after much training
|
||||
# you can sample from the exponentially moving averaged unets as so
|
||||
|
||||
mock_image_embed = torch.randn(4, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
```
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
@@ -1540,7 +1540,13 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
def sample(
|
||||
self,
|
||||
image_embed,
|
||||
text = None,
|
||||
cond_scale = 1.,
|
||||
stop_at_unet_number = None
|
||||
):
|
||||
batch_size = image_embed.shape[0]
|
||||
|
||||
text_encodings = text_mask = None
|
||||
@@ -1552,7 +1558,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
img = None
|
||||
|
||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
||||
|
||||
@@ -1584,6 +1590,9 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
img = vae.decode(img)
|
||||
|
||||
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
|
||||
break
|
||||
|
||||
return img
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -144,6 +144,10 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
@property
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
def scale(self, loss, *, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
@@ -169,7 +173,26 @@ class DecoderTrainer(nn.Module):
|
||||
ema_unet = self.ema_unets[index]
|
||||
ema_unet.update()
|
||||
|
||||
def forward(self, x, *, unet_number, **kwargs):
|
||||
@torch.no_grad()
|
||||
def sample(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
trainable_unets = self.decoder.unets
|
||||
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
|
||||
output = self.decoder.sample(*args, **kwargs)
|
||||
|
||||
if self.use_ema:
|
||||
self.decoder.unets = trainable_unets # restore original training unets
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
unet_number,
|
||||
divisor = 1,
|
||||
**kwargs
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
||||
return self.scale(loss, unet_number = unet_number)
|
||||
return self.scale(loss / divisor, unet_number = unet_number)
|
||||
|
||||
Reference in New Issue
Block a user