some cleanup

This commit is contained in:
Phil Wang
2022-05-09 16:50:21 -07:00
parent db805e73e1
commit 64f7be1926
5 changed files with 49 additions and 40 deletions

View File

@@ -6,7 +6,8 @@ import numpy as np
import torch
from torch import nn
from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, load_diffusion_model, save_diffusion_model
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler
@@ -153,7 +154,7 @@ def train(image_embed_dim,
os.makedirs(save_path)
# Get image and text embeddings from the servers
print("==============Downloading embeddings - image and text====================")
print_ribbon("Downloading embeddings - image and text")
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
num_data_points = text_reader.count