diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 5712c6a..89d64e9 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -149,7 +149,7 @@ class Transformer(nn.Module): return self.norm(x) -class PriorNetwork(nn.Module): +class DiffusionPriorNetwork(nn.Module): def __init__( self, dim,