mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
some cleanup
This commit is contained in:
@@ -6,7 +6,8 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from embedding_reader import EmbeddingReader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, load_diffusion_model, save_diffusion_model
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from torch.cuda.amp import autocast,GradScaler
|
||||
|
||||
@@ -153,7 +154,7 @@ def train(image_embed_dim,
|
||||
os.makedirs(save_path)
|
||||
|
||||
# Get image and text embeddings from the servers
|
||||
print("==============Downloading embeddings - image and text====================")
|
||||
print_ribbon("Downloading embeddings - image and text")
|
||||
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
||||
num_data_points = text_reader.count
|
||||
|
||||
Reference in New Issue
Block a user