mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa900213e7 | ||
|
|
cb26187450 |
@@ -1066,13 +1066,14 @@ class Unet(nn.Module):
|
||||
self,
|
||||
*,
|
||||
lowres_cond,
|
||||
channels
|
||||
channels,
|
||||
cond_on_image_embeds
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels:
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
|
||||
return self
|
||||
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
|
||||
return self.__class__(**updated_kwargs)
|
||||
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
|
||||
return self.__class__(**{**self._locals, **updated_kwargs})
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
@@ -1279,6 +1280,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first,
|
||||
channels = unet_channels
|
||||
)
|
||||
|
||||
|
||||
@@ -545,6 +545,7 @@ class VQGanVAE(nn.Module):
|
||||
l2_recon_loss = False,
|
||||
use_hinge_loss = True,
|
||||
vgg = None,
|
||||
vq_codebook_dim = 256,
|
||||
vq_codebook_size = 512,
|
||||
vq_decay = 0.8,
|
||||
vq_commitment_weight = 1.,
|
||||
@@ -579,6 +580,7 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
self.vq = VQ(
|
||||
dim = self.enc_dec.encoded_dim,
|
||||
codebook_dim = vq_codebook_dim,
|
||||
codebook_size = vq_codebook_size,
|
||||
decay = vq_decay,
|
||||
commitment_weight = vq_commitment_weight,
|
||||
|
||||
Reference in New Issue
Block a user