diff --git a/dalle2_pytorch/tokenizer.py b/dalle2_pytorch/tokenizer.py
index 009ef04..7c01008 100644
--- a/dalle2_pytorch/tokenizer.py
+++ b/dalle2_pytorch/tokenizer.py
@@ -2,7 +2,6 @@
 # to give users a quick easy start to training DALL-E without doing BPE
 
 import torch
-import youtokentome as yttm
 
 import html
 import os
@@ -11,6 +10,8 @@ import regex as re
 from functools import lru_cache
 from pathlib import Path
 
+from dalle2_pytorch.utils import import_or_print_error
+
 # OpenAI simple tokenizer
 
 @lru_cache()
@@ -156,7 +157,9 @@ class YttmTokenizer:
         bpe_path = Path(bpe_path)
         assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
 
-        tokenizer = yttm.BPE(model = str(bpe_path))
+        self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
+
+        tokenizer = self.yttm.BPE(model = str(bpe_path))
         self.tokenizer = tokenizer
         self.vocab_size = tokenizer.vocab_size()
 
@@ -167,7 +170,7 @@ class YttmTokenizer:
         return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
 
     def encode(self, texts):
-        encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
+        encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
         return list(map(torch.tensor, encoded))
 
     def tokenize(self, texts, context_length = 256, truncate_text = False):
diff --git a/dalle2_pytorch/trackers.py b/dalle2_pytorch/trackers.py
index 9ae56ff..9204f2e 100644
--- a/dalle2_pytorch/trackers.py
+++ b/dalle2_pytorch/trackers.py
@@ -6,6 +6,8 @@ from itertools import zip_longest
 import torch
 from torch import nn
 
+from dalle2_pytorch.utils import import_or_print_error
+
 # constants
 
 DEFAULT_DATA_PATH = './.tracker-data'
@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
 def exists(val):
     return val is not None
 
-def import_or_print_error(pkg_name, err_str = None):
-    try:
-        return importlib.import_module(pkg_name)
-    except ModuleNotFoundError as e:
-        if exists(err_str):
-            print(err_str)
-        exit()
-
 # load state dict functions
 
 def load_wandb_state_dict(run_path, file_path, **kwargs):
diff --git a/dalle2_pytorch/utils.py b/dalle2_pytorch/utils.py
index 9d52be2..7208f3e 100644
--- a/dalle2_pytorch/utils.py
+++ b/dalle2_pytorch/utils.py
@@ -17,3 +17,13 @@ class Timer:
 def print_ribbon(s, symbol = '=', repeat = 40):
     flank = symbol * repeat
     return f'{flank} {s} {flank}'
+
+# import helpers
+
+def import_or_print_error(pkg_name, err_str = None):
+    try:
+        return importlib.import_module(pkg_name)
+    except ModuleNotFoundError as e:
+        if exists(err_str):
+            print(err_str)
+        exit()
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index e2f45ae..7c9e66e 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.6.5'
+__version__ = '0.6.6'
diff --git a/setup.py b/setup.py
index 199007f..ab2ba08 100644
--- a/setup.py
+++ b/setup.py
@@ -42,7 +42,6 @@ setup(
     'tqdm',
     'vector-quantize-pytorch',
     'x-clip>=0.4.4',
-    'youtokentome',
     'webdataset>=0.2.5',
     'fsspec>=2022.1.0',
     'torchmetrics[image]>=0.8.0'
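
Net effect of the patch above: youtokentome is dropped from the hard requirements in setup.py and is now imported lazily through import_or_print_error (moved from trackers.py into dalle2_pytorch/utils.py), so it is only needed once YttmTokenizer is actually constructed. A minimal usage sketch, assuming youtokentome has been installed separately and a trained BPE model exists at the placeholder path ./bpe.model (neither is part of this patch):

    # sketch only; './bpe.model' and the sample text are placeholders
    from dalle2_pytorch.tokenizer import YttmTokenizer   # importing the module no longer pulls in youtokentome

    tokenizer = YttmTokenizer(bpe_path = './bpe.model')  # youtokentome is imported here, or a pip hint is printed and the process exits
    token_ids = tokenizer.tokenize(['a photograph of a corgi'], context_length = 256)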