diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index a04c4d4..10aa288 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -133,12 +133,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs): chunk_size_frac = chunk_size / batch_size yield chunk_size_frac, (chunked_args, chunked_kwargs) -# print helpers - -def print_ribbon(s, symbol = '=', repeat = 40): - flank = symbol * repeat - return f'{flank} {s} {flank}' - # saving and loading functions # for diffusion prior diff --git a/dalle2_pytorch/utils.py b/dalle2_pytorch/utils.py index f7cbf86..9d52be2 100644 --- a/dalle2_pytorch/utils.py +++ b/dalle2_pytorch/utils.py @@ -1,5 +1,7 @@ import time +# time helpers + class Timer: def __init__(self): self.reset() @@ -9,3 +11,9 @@ class Timer: def elapsed(self): return time.time() - self.last_time + +# print helpers + +def print_ribbon(s, symbol = '=', repeat = 40): + flank = symbol * repeat + return f'{flank} {s} {flank}' diff --git a/setup.py b/setup.py index 09f49ce..b6e907b 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.4.5', + version = '0.4.6', license='MIT', description = 'DALL-E 2', author = 'Phil Wang', diff --git a/train_decoder.py b/train_decoder.py index e2dfebe..cccd81c 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -1,9 +1,9 @@ from dalle2_pytorch import Unet, Decoder -from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon +from dalle2_pytorch.trainer import DecoderTrainer from dalle2_pytorch.dataloaders import create_image_embedding_dataloader from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker from dalle2_pytorch.train_configs import TrainDecoderConfig -from dalle2_pytorch.utils import Timer +from dalle2_pytorch.utils import Timer, print_ribbon import torchvision import torch diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index c7eb1a1..3a625de 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -9,10 +9,10 @@ from torch import nn from dalle2_pytorch.dataloaders import make_splits from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter -from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon +from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker -from dalle2_pytorch.utils import Timer +from dalle2_pytorch.utils import Timer, print_ribbon from embedding_reader import EmbeddingReader