mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
unet_number on decoder trainer only needs to be passed in if there is greater than 1 unet, so that unconditional training of a single ddpm is seamless (experiment in progress locally)
This commit is contained in:
@@ -7,9 +7,11 @@ import numpy as np
|
||||
import torch
|
||||
import clip
|
||||
from torch import nn
|
||||
|
||||
from dalle2_pytorch.dataloaders import make_splits
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
Reference in New Issue
Block a user