make sure entire readme runs without errors

Phil Wang
2022-07-28 10:17:43 -07:00
parent 36fb46a95e
commit 80046334ad
4 changed files with 14 additions and 12 deletions

README.md

@@ -396,7 +396,7 @@ decoder = Decoder(
 ).cuda()
 
 for unet_number in (1, 2):
-    loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
+    loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
     loss.backward()
 
 # do above for many steps
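The fixed line assumes `images` and `text` tensors defined earlier in the README. A minimal sketch of such mock inputs, with shapes and the vocab size assumed rather than taken from this diff:

```python
import torch

# hypothetical mock data for the snippet above (shapes are assumptions)
images = torch.randn(4, 3, 256, 256).cuda()      # batch of 4 RGB images
text = torch.randint(0, 49408, (4, 256)).cuda()  # random token ids; 49408 assumed as the CLIP BPE vocab size
```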
@@ -861,25 +861,23 @@ unet1 = Unet(
     text_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
-    dim_mults=(1, 2, 4, 8)
+    dim_mults=(1, 2, 4, 8),
+    cond_on_text_encodings = True,
 ).cuda()
 
 unet2 = Unet(
     dim = 16,
     image_embed_dim = 512,
-    text_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
     dim_mults = (1, 2, 4, 8, 16),
-    cond_on_text_encodings = True
 ).cuda()
 
 decoder = Decoder(
     unet = (unet1, unet2),
     image_sizes = (128, 256),
     clip = clip,
-    timesteps = 1000,
-    condition_on_text_encodings = True
+    timesteps = 1000
 ).cuda()
 
 decoder_trainer = DecoderTrainer(
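The hunk cuts off at the opening line of the trainer construction. A rough sketch of how the truncated `DecoderTrainer(...)` call might continue, with all hyperparameter values being illustrative assumptions rather than part of this diff:

```python
# hypothetical continuation of the truncated construction above
decoder_trainer = DecoderTrainer(
    decoder,          # the Decoder built in the snippet above
    lr = 3e-4,        # learning rate (assumed value)
    wd = 1e-2,        # weight decay (assumed value)
    ema_beta = 0.99   # decay for the exponentially moving averaged unets (assumed value)
)
```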
@@ -904,8 +902,8 @@ for unet_number in (1, 2):
 # after much training
 # you can sample from the exponentially moving averaged unets as so
 
-mock_image_embed = torch.randn(4, 512).cuda()
-images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
+mock_image_embed = torch.randn(32, 512).cuda()
+images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
 ```
 
 ### Diffusion Prior Training
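The `for unet_number in (1, 2):` in the hunk header refers to the training loop that precedes the sampling lines but is not shown here. A sketch of that loop under the trainer API assumed by the README, where each unet of the cascade is trained separately:

```python
# hypothetical training loop implied by the hunk header (API assumed)
for unet_number in (1, 2):
    loss = decoder_trainer(
        images,
        text = text,
        unet_number = unet_number,  # train one unet of the cascade at a time
        max_batch_size = 4          # split into micro-batches (assumed kwarg)
    )
    decoder_trainer.update(unet_number)  # optimizer step plus EMA update for that unet
```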

dalle2_pytorch/dalle2_pytorch.py

@@ -1831,7 +1831,7 @@ class Unet(nn.Module):
             channels == self.channels and \
             cond_on_image_embeds == self.cond_on_image_embeds and \
             cond_on_text_encodings == self.cond_on_text_encodings and \
-            cond_on_lowres_noise == self.cond_on_lowres_noise and \
+            lowres_noise_cond == self.lowres_noise_cond and \
             channels_out == self.channels_out:
             return self
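The renamed attribute lives inside a guard that returns the same `Unet` when the requested configuration already matches, and otherwise rebuilds it. A minimal, self-contained sketch of that pattern using a stand-in class, not the library's actual method:

```python
class ConfigGuard:
    """Illustrative stand-in for the Unet configuration check above."""

    def __init__(self, **config):
        self._locals = config  # remember constructor kwargs (assumed bookkeeping)
        for key, value in config.items():
            setattr(self, key, value)

    def cast_model_parameters(self, **new_config):
        # reuse this instance when every requested value already matches
        if all(getattr(self, key) == value for key, value in new_config.items()):
            return self
        # otherwise rebuild with the updated configuration
        return self.__class__(**{**self._locals, **new_config})

unet_like = ConfigGuard(lowres_noise_cond = False, channels = 3, channels_out = 3)
assert unet_like.cast_model_parameters(channels = 3) is unet_like                   # unchanged: same object
assert unet_like.cast_model_parameters(lowres_noise_cond = True) is not unet_like   # rebuilt
```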

dalle2_pytorch/trainer.py

@@ -174,7 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,
         diffusion_prior,
-        accelerator,
+        accelerator = None,
         use_ema = True,
         lr = 3e-4,
         wd = 1e-2,
@@ -186,8 +186,12 @@ class DiffusionPriorTrainer(nn.Module):
): ):
super().__init__() super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior) assert isinstance(diffusion_prior, DiffusionPrior)
assert isinstance(accelerator, Accelerator)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
if not exists(accelerator):
accelerator = Accelerator(**accelerator_kwargs)
# assign some helpful member vars # assign some helpful member vars
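With `accelerator` now optional, the trainer builds its own `Accelerator` from any kwargs carrying the `accelerator_` prefix. A hedged usage sketch; `mixed_precision` is a real HuggingFace Accelerate argument, while everything else about the call is assumed from the diff:

```python
from accelerate import Accelerator

# hypothetical usage: no accelerator is passed, so one is constructed internally;
# the 'accelerator_' prefix is stripped and the remainder forwarded, i.e.
# accelerator_mixed_precision = 'fp16' becomes Accelerator(mixed_precision = 'fp16')
trainer = DiffusionPriorTrainer(
    diffusion_prior,  # a DiffusionPrior instance, defined elsewhere
    lr = 3e-4,
    accelerator_mixed_precision = 'fp16'
)

# passing a preconfigured Accelerator explicitly still works as before
trainer = DiffusionPriorTrainer(diffusion_prior, accelerator = Accelerator())
```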

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '1.2.1'
+__version__ = '1.2.2'