mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 20:35:24 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9340d33d5f |
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from enum import Enum
|
||||||
import importlib
|
import importlib
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
|
|
||||||
@@ -38,6 +39,7 @@ def load_local_state_dict(file_path, **kwargs):
|
|||||||
class BaseTracker(nn.Module):
|
class BaseTracker(nn.Module):
|
||||||
def __init__(self, data_path = DEFAULT_DATA_PATH):
|
def __init__(self, data_path = DEFAULT_DATA_PATH):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
assert data_path is not None, "Tracker must have a data_path to save local content"
|
||||||
self.data_path = Path(data_path)
|
self.data_path = Path(data_path)
|
||||||
self.data_path.mkdir(parents = True, exist_ok = True)
|
self.data_path.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
@@ -104,7 +106,7 @@ class WandbTracker(BaseTracker):
|
|||||||
Takes a tensor of images and a list of captions and logs them to wandb.
|
Takes a tensor of images and a list of captions and logs them to wandb.
|
||||||
"""
|
"""
|
||||||
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
||||||
self.log({ image_section: wandb_images }, **kwargs)
|
self.wandb.log({ image_section: wandb_images }, **kwargs)
|
||||||
|
|
||||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user