mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
it was correct the first time, my bad
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from enum import Enum
|
|
||||||
import importlib
|
import importlib
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
|
|
||||||
@@ -105,7 +104,7 @@ class WandbTracker(BaseTracker):
|
|||||||
Takes a tensor of images and a list of captions and logs them to wandb.
|
Takes a tensor of images and a list of captions and logs them to wandb.
|
||||||
"""
|
"""
|
||||||
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
||||||
self.wandb.log({ image_section: wandb_images }, **kwargs)
|
self.log({ image_section: wandb_images }, **kwargs)
|
||||||
|
|
||||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user