mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
complete vit-vqgan from https://arxiv.org/abs/2110.04627
This commit is contained in:
@@ -644,6 +644,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||||
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
||||||
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||||
|
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||||
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||||
@@ -651,7 +652,6 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [ ] train on a toy task, offer in colab
|
- [ ] train on a toy task, offer in colab
|
||||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||||
- [ ] bring in tools to train vqgan-vae
|
- [ ] bring in tools to train vqgan-vae
|
||||||
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from torch.autograd import grad as torch_grad
|
|||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
from einops import rearrange, reduce, repeat
|
from einops import rearrange, reduce, repeat
|
||||||
|
from einops_exts import rearrange_many
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
|
||||||
from dalle2_pytorch.attention import QueryAttnUpsample
|
from dalle2_pytorch.attention import QueryAttnUpsample
|
||||||
|
|
||||||
@@ -146,6 +148,8 @@ class LayerNormChan(nn.Module):
|
|||||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||||
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||||
|
|
||||||
|
# discriminator
|
||||||
|
|
||||||
class Discriminator(nn.Module):
|
class Discriminator(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -179,6 +183,8 @@ class Discriminator(nn.Module):
|
|||||||
|
|
||||||
return self.to_logits(x)
|
return self.to_logits(x)
|
||||||
|
|
||||||
|
# positional encoding
|
||||||
|
|
||||||
class ContinuousPositionBias(nn.Module):
|
class ContinuousPositionBias(nn.Module):
|
||||||
""" from https://arxiv.org/abs/2111.09883 """
|
""" from https://arxiv.org/abs/2111.09883 """
|
||||||
|
|
||||||
@@ -213,6 +219,84 @@ class ContinuousPositionBias(nn.Module):
|
|||||||
bias = rearrange(rel_pos, 'i j h -> h i j')
|
bias = rearrange(rel_pos, 'i j h -> h i j')
|
||||||
return x + bias
|
return x + bias
|
||||||
|
|
||||||
|
# resnet encoder / decoder
|
||||||
|
|
||||||
|
class ResnetEncDec(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
*,
|
||||||
|
channels = 3,
|
||||||
|
layers = 4,
|
||||||
|
layer_mults = None,
|
||||||
|
num_resnet_blocks = 1,
|
||||||
|
resnet_groups = 16,
|
||||||
|
first_conv_kernel_size = 5,
|
||||||
|
use_attn = True,
|
||||||
|
attn_dim_head = 64,
|
||||||
|
attn_heads = 8,
|
||||||
|
attn_dropout = 0.,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
|
||||||
|
|
||||||
|
self.layers = layers
|
||||||
|
|
||||||
|
self.encoders = MList([])
|
||||||
|
self.decoders = MList([])
|
||||||
|
|
||||||
|
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||||
|
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||||
|
|
||||||
|
layer_dims = [dim * mult for mult in layer_mults]
|
||||||
|
dims = (dim, *layer_dims)
|
||||||
|
|
||||||
|
self.encoded_dim = dims[-1]
|
||||||
|
|
||||||
|
dim_pairs = zip(dims[:-1], dims[1:])
|
||||||
|
|
||||||
|
append = lambda arr, t: arr.append(t)
|
||||||
|
prepend = lambda arr, t: arr.insert(0, t)
|
||||||
|
|
||||||
|
if not isinstance(num_resnet_blocks, tuple):
|
||||||
|
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
|
||||||
|
|
||||||
|
if not isinstance(use_attn, tuple):
|
||||||
|
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||||
|
|
||||||
|
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
|
||||||
|
assert len(use_attn) == layers
|
||||||
|
|
||||||
|
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||||
|
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||||
|
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
|
||||||
|
|
||||||
|
if layer_use_attn:
|
||||||
|
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||||
|
|
||||||
|
for _ in range(layer_num_resnet_blocks):
|
||||||
|
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
|
||||||
|
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
|
||||||
|
|
||||||
|
if layer_use_attn:
|
||||||
|
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||||
|
|
||||||
|
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||||
|
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||||
|
|
||||||
|
def get_encoded_fmap_size(self, image_size):
|
||||||
|
return image_size // (2 ** self.layers)
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
for enc in self.encoders:
|
||||||
|
x = enc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
for dec in self.decoders:
|
||||||
|
x = dec(x)
|
||||||
|
return x
|
||||||
|
|
||||||
class GLUResBlock(nn.Module):
|
class GLUResBlock(nn.Module):
|
||||||
def __init__(self, chan, groups = 16):
|
def __init__(self, chan, groups = 16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -246,6 +330,7 @@ class ResBlock(nn.Module):
|
|||||||
return self.net(x) + x
|
return self.net(x) + x
|
||||||
|
|
||||||
# vqgan attention layer
|
# vqgan attention layer
|
||||||
|
|
||||||
class VQGanAttention(nn.Module):
|
class VQGanAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -290,6 +375,145 @@ class VQGanAttention(nn.Module):
|
|||||||
|
|
||||||
return out + residual
|
return out + residual
|
||||||
|
|
||||||
|
# ViT encoder / decoder
|
||||||
|
|
||||||
|
class RearrangeImage(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
n = x.shape[1]
|
||||||
|
w = h = int(sqrt(n))
|
||||||
|
return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
*,
|
||||||
|
heads = 8,
|
||||||
|
dim_head = 32
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
|
||||||
|
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||||
|
self.to_out = nn.Linear(inner_dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
||||||
|
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||||
|
|
||||||
|
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||||
|
attn = sim.softmax(dim = -1)
|
||||||
|
|
||||||
|
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||||
|
|
||||||
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
def FeedForward(dim, mult = 4):
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.LayerNorm(dim),
|
||||||
|
nn.Linear(dim, dim * mult, bias = False),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(dim * mult, dim, bias = False)
|
||||||
|
)
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
*,
|
||||||
|
layers,
|
||||||
|
dim_head = 32,
|
||||||
|
heads = 8,
|
||||||
|
ff_mult = 4
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
for _ in range(layers):
|
||||||
|
self.layers.append(nn.ModuleList([
|
||||||
|
Attention(dim = dim, dim_head = dim_head, heads = heads),
|
||||||
|
FeedForward(dim = dim, mult = ff_mult)
|
||||||
|
]))
|
||||||
|
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for attn, ff in self.layers:
|
||||||
|
x = attn(x) + x
|
||||||
|
x = ff(x) + x
|
||||||
|
|
||||||
|
return self.norm(x)
|
||||||
|
|
||||||
|
class ViTEncDec(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
channels = 3,
|
||||||
|
layers = 4,
|
||||||
|
patch_size = 8,
|
||||||
|
dim_head = 32,
|
||||||
|
heads = 8,
|
||||||
|
ff_mult = 4
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.encoded_dim = dim
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
input_dim = channels * (patch_size ** 2)
|
||||||
|
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||||
|
nn.Linear(input_dim, dim),
|
||||||
|
Transformer(
|
||||||
|
dim = dim,
|
||||||
|
dim_head = dim_head,
|
||||||
|
heads = heads,
|
||||||
|
ff_mult = ff_mult,
|
||||||
|
layers = layers
|
||||||
|
),
|
||||||
|
RearrangeImage(),
|
||||||
|
Rearrange('b h w c -> b c h w')
|
||||||
|
)
|
||||||
|
|
||||||
|
self.decoder = nn.Sequential(
|
||||||
|
Rearrange('b c h w -> b (h w) c'),
|
||||||
|
Transformer(
|
||||||
|
dim = dim,
|
||||||
|
dim_head = dim_head,
|
||||||
|
heads = heads,
|
||||||
|
ff_mult = ff_mult,
|
||||||
|
layers = layers
|
||||||
|
),
|
||||||
|
nn.Sequential(
|
||||||
|
nn.Linear(dim, dim * 4, bias = False),
|
||||||
|
nn.Tanh(),
|
||||||
|
nn.Linear(dim * 4, input_dim, bias = False),
|
||||||
|
),
|
||||||
|
RearrangeImage(),
|
||||||
|
Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_encoded_fmap_size(self, image_size):
|
||||||
|
return image_size // self.patch_size
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return self.encoder(x)
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return self.decoder(x)
|
||||||
|
|
||||||
|
# main vqgan-vae classes
|
||||||
|
|
||||||
class NullVQGanVAE(nn.Module):
|
class NullVQGanVAE(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -320,81 +544,43 @@ class VQGanVAE(nn.Module):
|
|||||||
image_size,
|
image_size,
|
||||||
channels = 3,
|
channels = 3,
|
||||||
layers = 4,
|
layers = 4,
|
||||||
layer_mults = None,
|
|
||||||
l2_recon_loss = False,
|
l2_recon_loss = False,
|
||||||
use_hinge_loss = True,
|
use_hinge_loss = True,
|
||||||
num_resnet_blocks = 1,
|
|
||||||
vgg = None,
|
vgg = None,
|
||||||
vq_codebook_size = 512,
|
vq_codebook_size = 512,
|
||||||
vq_decay = 0.8,
|
vq_decay = 0.8,
|
||||||
vq_commitment_weight = 1.,
|
vq_commitment_weight = 1.,
|
||||||
vq_kmeans_init = True,
|
vq_kmeans_init = True,
|
||||||
vq_use_cosine_sim = True,
|
vq_use_cosine_sim = True,
|
||||||
use_attn = True,
|
|
||||||
attn_dim_head = 64,
|
|
||||||
attn_heads = 8,
|
|
||||||
resnet_groups = 16,
|
|
||||||
attn_dropout = 0.,
|
|
||||||
first_conv_kernel_size = 5,
|
|
||||||
use_vgg_and_gan = True,
|
use_vgg_and_gan = True,
|
||||||
|
vae_type = 'resnet',
|
||||||
|
discr_layers = 4,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
|
|
||||||
|
|
||||||
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
|
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
|
||||||
|
encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
|
||||||
|
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.layers = layers
|
|
||||||
self.fmap_size = image_size // (layers ** 2)
|
|
||||||
self.codebook_size = vq_codebook_size
|
self.codebook_size = vq_codebook_size
|
||||||
|
|
||||||
self.encoders = MList([])
|
if vae_type == 'resnet':
|
||||||
self.decoders = MList([])
|
enc_dec_klass = ResnetEncDec
|
||||||
|
elif vae_type == 'vit':
|
||||||
|
enc_dec_klass = ViTEncDec
|
||||||
|
else:
|
||||||
|
raise ValueError(f'{vae_type} not valid')
|
||||||
|
|
||||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
self.enc_dec = enc_dec_klass(
|
||||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
dim = dim,
|
||||||
|
channels = channels,
|
||||||
layer_dims = [dim * mult for mult in layer_mults]
|
layers = layers,
|
||||||
dims = (dim, *layer_dims)
|
**encdec_kwargs
|
||||||
codebook_dim = layer_dims[-1]
|
)
|
||||||
|
|
||||||
self.encoded_dim = dims[-1]
|
|
||||||
|
|
||||||
dim_pairs = zip(dims[:-1], dims[1:])
|
|
||||||
|
|
||||||
append = lambda arr, t: arr.append(t)
|
|
||||||
prepend = lambda arr, t: arr.insert(0, t)
|
|
||||||
|
|
||||||
if not isinstance(num_resnet_blocks, tuple):
|
|
||||||
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
|
|
||||||
|
|
||||||
if not isinstance(use_attn, tuple):
|
|
||||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
|
||||||
|
|
||||||
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
|
|
||||||
assert len(use_attn) == layers
|
|
||||||
|
|
||||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
|
||||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
|
||||||
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
|
|
||||||
|
|
||||||
if layer_use_attn:
|
|
||||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
|
||||||
|
|
||||||
for _ in range(layer_num_resnet_blocks):
|
|
||||||
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
|
|
||||||
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
|
|
||||||
|
|
||||||
if layer_use_attn:
|
|
||||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
|
||||||
|
|
||||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
|
||||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
|
||||||
|
|
||||||
self.vq = VQ(
|
self.vq = VQ(
|
||||||
dim = codebook_dim,
|
dim = self.enc_dec.encoded_dim,
|
||||||
codebook_size = vq_codebook_size,
|
codebook_size = vq_codebook_size,
|
||||||
decay = vq_decay,
|
decay = vq_decay,
|
||||||
commitment_weight = vq_commitment_weight,
|
commitment_weight = vq_commitment_weight,
|
||||||
@@ -427,13 +613,21 @@ class VQGanVAE(nn.Module):
|
|||||||
|
|
||||||
# gan related losses
|
# gan related losses
|
||||||
|
|
||||||
|
layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
|
||||||
|
layer_dims = [dim * mult for mult in layer_mults]
|
||||||
|
dims = (dim, *layer_dims)
|
||||||
|
|
||||||
self.discr = Discriminator(dims = dims, channels = channels)
|
self.discr = Discriminator(dims = dims, channels = channels)
|
||||||
|
|
||||||
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
||||||
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoded_dim(self):
|
||||||
|
return self.enc_dec.encoded_dim
|
||||||
|
|
||||||
def get_encoded_fmap_size(self, image_size):
|
def get_encoded_fmap_size(self, image_size):
|
||||||
return image_size // (2 ** self.layers)
|
return self.enc_dec.get_encoded_fmap_size(image_size)
|
||||||
|
|
||||||
def copy_for_eval(self):
|
def copy_for_eval(self):
|
||||||
device = next(self.parameters()).device
|
device = next(self.parameters()).device
|
||||||
@@ -459,16 +653,13 @@ class VQGanVAE(nn.Module):
|
|||||||
return self.vq.codebook
|
return self.vq.codebook
|
||||||
|
|
||||||
def encode(self, fmap):
|
def encode(self, fmap):
|
||||||
for enc in self.encoders:
|
fmap = self.enc_dec.encode(fmap)
|
||||||
fmap = enc(fmap)
|
|
||||||
|
|
||||||
return fmap
|
return fmap
|
||||||
|
|
||||||
def decode(self, fmap, return_indices_and_loss = False):
|
def decode(self, fmap, return_indices_and_loss = False):
|
||||||
fmap, indices, commit_loss = self.vq(fmap)
|
fmap, indices, commit_loss = self.vq(fmap)
|
||||||
|
|
||||||
for dec in self.decoders:
|
fmap = self.enc_dec.decode(fmap)
|
||||||
fmap = dec(fmap)
|
|
||||||
|
|
||||||
if not return_indices_and_loss:
|
if not return_indices_and_loss:
|
||||||
return fmap
|
return fmap
|
||||||
|
|||||||
Reference in New Issue
Block a user