Start a dedicated file for all attention-related modules; use attention-based upsampling in the U-Nets of DALL-E 2

This commit is contained in:
Phil Wang
2022-04-25 18:59:10 -07:00
parent 3b520dfa85
commit f75d49c781
4 changed files with 130 additions and 107 deletions

View File

@@ -17,6 +17,7 @@ from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
from dalle2_pytorch.attention import QueryAttnUpsample
# use x-clip
@@ -692,7 +693,7 @@ class DiffusionPrior(nn.Module):
# decoder
def Upsample(dim):
    """Return an upsampling module that doubles spatial resolution.

    Replaces the previous strided ``nn.ConvTranspose2d(dim, dim, 4, 2, 1)``
    with the attention-based ``QueryAttnUpsample`` (per this commit's intent).
    The diff artifact left two consecutive ``return`` statements, making the
    attention upsampler unreachable; only the intended return is kept here.

    Args:
        dim: number of channels, preserved by the upsampler.
    """
    return QueryAttnUpsample(dim)
def Downsample(dim):
    """Return a strided convolution that halves spatial resolution.

    Channel count is preserved (``dim`` in, ``dim`` out); a 4x4 kernel with
    stride 2 and padding 1 downsamples H and W by exactly 2.
    """
    return nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1)