diff --git a/README.md b/README.md
index cad0dbb..e08eaa2 100644
--- a/README.md
+++ b/README.md
@@ -628,6 +628,82 @@ images = dalle2(
 Now you'll just have to worry about training the Prior and the Decoder!
 
+## Inpainting
+
+Inpainting is also built into the `Decoder`. You simply have to pass in the `inpaint_image` and `inpaint_mask` (a boolean tensor where `True` indicates which regions of the inpaint image to keep).
+
+This repository uses the formulation put forth by Lugmayr et al. in RePaint.
+
+```python
+import torch
+from dalle2_pytorch import Unet, Decoder, CLIP
+
+# trained clip from step 1
+
+clip = CLIP(
+    dim_text = 512,
+    dim_image = 512,
+    dim_latent = 512,
+    num_text_tokens = 49408,
+    text_enc_depth = 6,
+    text_seq_len = 256,
+    text_heads = 8,
+    visual_enc_depth = 6,
+    visual_image_size = 256,
+    visual_patch_size = 32,
+    visual_heads = 8
+).cuda()
+
+# unet for the decoder (a la cascading DDPM - only one stage is used here)
+
+unet = Unet(
+    dim = 16,
+    image_embed_dim = 512,
+    cond_dim = 128,
+    channels = 3,
+    dim_mults = (1, 1, 1, 1)
+).cuda()
+
+
+# decoder, which contains the unet(s) and clip
+
+decoder = Decoder(
+    clip = clip,
+    unet = (unet,),          # insert unets in order of lowest to highest resolution (you can have as many stages as you want here)
+    image_sizes = (256,),    # resolution for each unet - these must be unique and in ascending order (matches with the unets passed in)
+    timesteps = 1000,
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
+).cuda()
+
+# mock images (get a lot of this)
+
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# feed images into decoder, specifying which unet you want to train
+# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
+
+loss = decoder(images, unet_number = 1)
+loss.backward()
+
+# do the above for many steps for each unet
+
+mock_image_embed = torch.randn(1, 512).cuda()
+
+# then to do inpainting
+
+inpaint_image = torch.randn(1, 3, 256, 256).cuda()   # (batch, channels, height, width)
+inpaint_mask = torch.ones(1, 256, 256).bool().cuda() # (batch, height, width)
+
+inpainted_images = decoder.sample(
+    image_embed = mock_image_embed,
+    inpaint_image = inpaint_image, # just pass in the inpaint image
+    inpaint_mask = inpaint_mask    # and the mask
+)
+
+inpainted_images.shape # (1, 3, 256, 256)
+```
+
 ## Experimental
 
 ### DALL-E2 with Latent Diffusion
 
@@ -1169,4 +1245,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@article{Lugmayr2022RePaintIU,
+    title   = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
+    author  = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2201.09865}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
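
For context on the RePaint formulation the added section cites: at each reverse diffusion step, the known region is taken from the forward-noised inpaint image and the unknown region from the model's denoising step, with the mask selecting between the two. Below is a minimal sketch of that per-step combine, assuming generic DDPM helpers `q_sample` (forward noising) and `p_sample` (one reverse step) - these names are illustrative stand-ins, not the actual dalle2-pytorch internals.

```python
import torch

# sketch of the RePaint combine step (Lugmayr et al. 2022)
# `q_sample` and `p_sample` are assumed stand-ins for a DDPM's
# forward-noising and single reverse-denoising functions

def repaint_combine_step(x_t, t, inpaint_image, inpaint_mask, q_sample, p_sample):
    # known region: noise the ground-truth image to the current timestep
    # via the forward process (q(x_{t-1} | x_0) in the RePaint paper)
    x_known = q_sample(inpaint_image, t)

    # unknown region: one reverse diffusion step from the current sample
    x_unknown = p_sample(x_t, t)

    # mask is (batch, height, width) with True = keep; broadcast over channels
    mask = inpaint_mask.unsqueeze(1).float()

    # keep noised ground truth where the mask is True, the model's sample elsewhere
    return mask * x_known + (1. - mask) * x_unknown
```

RePaint additionally resamples - jumping back a few timesteps and re-noising - to harmonize the boundary between known and generated regions; the sketch above shows only the masking step applied at each timestep.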