From 1cc5d0afa7e922f360ccabbc59a60bcbf3818dd2 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 25 Aug 2022 10:37:02 -0700
Subject: [PATCH] upgrade to best downsample

---
 README.md                        | 10 ++++++++++
 dalle2_pytorch/dalle2_pytorch.py |  9 +++++++--
 dalle2_pytorch/trainer.py        |  2 +-
 dalle2_pytorch/version.py        |  2 +-
 4 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 5bae632..6a4052d 100644
--- a/README.md
+++ b/README.md
@@ -1285,4 +1285,14 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
 
+```bibtex
+@article{Sunkara2022NoMS,
+    title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
+    author = {Raja Sunkara and Tie Luo},
+    journal = {ArXiv},
+    year = {2022},
+    volume = {abs/2208.03641}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 5292d44..e87788d 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1479,9 +1479,14 @@ class PixelShuffleUpsample(nn.Module):
     def forward(self, x):
         return self.net(x)
 
-def Downsample(dim, *, dim_out = None):
+def Downsample(dim, dim_out = None):
+    # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
+    # named SP-conv in the paper, but basically a pixel unshuffle
     dim_out = default(dim_out, dim)
-    return nn.Conv2d(dim, dim_out, 4, 2, 1)
+    return nn.Sequential(
+        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
+        nn.Conv2d(dim * 4, dim_out, 1)
+    )
 
 class WeightStandardizedConv2d(nn.Conv2d):
     """
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 9b756e0..a6b6bbc 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -519,7 +519,7 @@ class DecoderTrainer(nn.Module):
             clip = decoder.clip
             clip.to(precision_type)
 
-        decoder, train_dataloader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders['train'], *optimizers))
+        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
 
         self.decoder = decoder
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index e5102d3..52af183 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.9.0'
+__version__ = '1.10.0'
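
For reviewers, a minimal standalone sketch of the new downsample block (SP-conv in the cited paper: a pixel unshuffle followed by a 1x1 convolution), assuming `torch` and `einops` are installed; the `default` helper from `dalle2_pytorch` is inlined so the snippet runs on its own: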
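```python
# Standalone sketch of the patched Downsample: pixel unshuffle (space-to-depth)
# followed by a 1x1 convolution, per https://arxiv.org/abs/2208.03641 (SP-conv).
# The `default` helper from dalle2_pytorch is inlined for self-containment.
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

def Downsample(dim, dim_out = None):
    dim_out = dim_out if dim_out is not None else dim
    return nn.Sequential(
        # fold each 2x2 spatial patch into channels: (b, c, h, w) -> (b, 4c, h/2, w/2)
        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
        # mix the stacked channels back down to dim_out
        nn.Conv2d(dim * 4, dim_out, 1)
    )

x = torch.randn(2, 64, 32, 32)
out = Downsample(64)(x)
print(out.shape)  # torch.Size([2, 64, 16, 16]) -- same output shape as the old strided conv
```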
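
Note the output shape matches the replaced `nn.Conv2d(dim, dim_out, 4, 2, 1)`, so the swap is drop-in; the difference, as argued in the Sunkara and Luo paper, is that the space-to-depth rearrange discards no information before the learned 1x1 channel mixing, whereas a strided convolution skips pixels.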