unet_number on the decoder trainer only needs to be passed in if there is more than one unet, so that unconditional training of a single ddpm is seamless (experiment in progress locally)

Phil Wang
2022-05-16 09:17:17 -07:00
parent 4a59dea4cf
commit bb151ca6b1
5 changed files with 14 additions and 6 deletions
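Rough illustration of the behavior described in the commit message: with a single-unet Decoder, the DecoderTrainer forward and update calls can omit unet_number entirely. This is a hedged sketch, not code from this commit; the Unet/Decoder constructor arguments and the unconditional flag follow the library's later README-style usage and are assumptions here, while the import from dalle2_pytorch.trainer matches the module path shown in the diff below.

# hedged sketch -- single-unet, unconditional ddpm, with unet_number omitted everywhere
# constructor arguments below are illustrative assumptions, not taken from this commit

import torch
from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer

unet = Unet(
    dim = 128,
    dim_mults = (1, 2, 4, 8)
)

decoder = Decoder(
    unet = unet,             # a single unet, so unet_number can default to it
    image_sizes = (64,),
    timesteps = 100,
    unconditional = True     # assumed flag for clip-free ddpm training
)

decoder_trainer = DecoderTrainer(decoder)

images = torch.randn(4, 3, 64, 64)

loss = decoder_trainer(images)   # no unet_number needed with one unet
decoder_trainer.update()         # likewise defaults to the sole unet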

@@ -7,9 +7,11 @@ import numpy as np
import torch
import clip
from torch import nn
from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
-from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
+from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from embedding_reader import EmbeddingReader