Compare commits

..

2 Commits
0.3.4 ... 0.3.5

Author SHA1 Message Date
Phil Wang
430961cb97 it was correct the first time, my bad 2022-05-20 18:05:15 -07:00
Phil Wang
721f9687c1 fix wandb logging in tracker, and do some cleanup 2022-05-20 17:27:43 -07:00
2 changed files with 2 additions and 4 deletions

View File

@@ -1,6 +1,5 @@
import os import os
from pathlib import Path from pathlib import Path
from enum import Enum
import importlib import importlib
from itertools import zip_longest from itertools import zip_longest
@@ -39,7 +38,6 @@ def load_local_state_dict(file_path, **kwargs):
class BaseTracker(nn.Module): class BaseTracker(nn.Module):
def __init__(self, data_path = DEFAULT_DATA_PATH): def __init__(self, data_path = DEFAULT_DATA_PATH):
super().__init__() super().__init__()
assert data_path is not None, "Tracker must have a data_path to save local content"
self.data_path = Path(data_path) self.data_path = Path(data_path)
self.data_path.mkdir(parents = True, exist_ok = True) self.data_path.mkdir(parents = True, exist_ok = True)
@@ -106,7 +104,7 @@ class WandbTracker(BaseTracker):
Takes a tensor of images and a list of captions and logs them to wandb. Takes a tensor of images and a list of captions and logs them to wandb.
""" """
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)] wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
self.wandb.log({ image_section: wandb_images }, **kwargs) self.log({ image_section: wandb_images }, **kwargs)
def save_state_dict(self, state_dict, relative_path, **kwargs): def save_state_dict(self, state_dict, relative_path, **kwargs):
""" """

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.3.4', version = '0.3.5',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',