This commit is contained in:
Phil Wang
2022-04-18 12:43:10 -07:00
parent 960a79857b
commit c6bfd7fdc8

View File

@@ -197,10 +197,10 @@ clip = CLIP(
dim_image = 512, dim_image = 512,
dim_latent = 512, dim_latent = 512,
num_text_tokens = 49408, num_text_tokens = 49408,
text_enc_depth = 1, text_enc_depth = 6,
text_seq_len = 256, text_seq_len = 256,
text_heads = 8, text_heads = 8,
visual_enc_depth = 1, visual_enc_depth = 6,
visual_image_size = 256, visual_image_size = 256,
visual_patch_size = 32, visual_patch_size = 32,
visual_heads = 8 visual_heads = 8
@@ -209,14 +209,15 @@ clip = CLIP(
# 2 unets for the decoder (a la cascading DDPM) # 2 unets for the decoder (a la cascading DDPM)
unet1 = Unet( unet1 = Unet(
dim = 16, dim = 32,
image_embed_dim = 512, image_embed_dim = 512,
cond_dim = 128,
channels = 3, channels = 3,
dim_mults = (1, 2, 4, 8) dim_mults = (1, 2, 4, 8)
).cuda() ).cuda()
unet2 = Unet( unet2 = Unet(
dim = 16, dim = 32,
image_embed_dim = 512, image_embed_dim = 512,
cond_dim = 128, cond_dim = 128,
channels = 3, channels = 3,
@@ -228,8 +229,8 @@ unet2 = Unet(
decoder = Decoder( decoder = Decoder(
clip = clip, clip = clip,
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here) unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 100, timesteps = 1000,
cond_drop_prob = 0.2 cond_drop_prob = 0.2
).cuda() ).cuda()