mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
small cleanup
This commit is contained in:
@@ -133,12 +133,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
|||||||
chunk_size_frac = chunk_size / batch_size
|
chunk_size_frac = chunk_size / batch_size
|
||||||
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||||
|
|
||||||
# print helpers
|
|
||||||
|
|
||||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
|
||||||
flank = symbol * repeat
|
|
||||||
return f'{flank} {s} {flank}'
|
|
||||||
|
|
||||||
# saving and loading functions
|
# saving and loading functions
|
||||||
|
|
||||||
# for diffusion prior
|
# for diffusion prior
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
|
# time helpers
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -9,3 +11,9 @@ class Timer:
|
|||||||
|
|
||||||
def elapsed(self):
|
def elapsed(self):
|
||||||
return time.time() - self.last_time
|
return time.time() - self.last_time
|
||||||
|
|
||||||
|
# print helpers
|
||||||
|
|
||||||
|
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||||
|
flank = symbol * repeat
|
||||||
|
return f'{flank} {s} {flank}'
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.4.5',
|
version = '0.4.6',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from dalle2_pytorch import Unet, Decoder
|
from dalle2_pytorch import Unet, Decoder
|
||||||
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
from dalle2_pytorch.trainer import DecoderTrainer
|
||||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||||
from dalle2_pytorch.utils import Timer
|
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ from torch import nn
|
|||||||
|
|
||||||
from dalle2_pytorch.dataloaders import make_splits
|
from dalle2_pytorch.dataloaders import make_splits
|
||||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
|
||||||
|
|
||||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||||
from dalle2_pytorch.utils import Timer
|
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||||
|
|
||||||
from embedding_reader import EmbeddingReader
|
from embedding_reader import EmbeddingReader
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user