make youtokentome an optional package, due to reported installation difficulties

Phil Wang
2022-06-01 09:25:35 -07:00
parent 1ffeecd0ca
commit a851168633
5 changed files with 19 additions and 13 deletions

dalle2_pytorch/tokenizer.py

@@ -2,7 +2,6 @@
 # to give users a quick easy start to training DALL-E without doing BPE
 import torch
-import youtokentome as yttm
 import html
 import os

@@ -11,6 +10,8 @@ import regex as re
 from functools import lru_cache
 from pathlib import Path
+from dalle2_pytorch.utils import import_or_print_error
 # OpenAI simple tokenizer
 @lru_cache()

@@ -156,7 +157,9 @@ class YttmTokenizer:
         bpe_path = Path(bpe_path)
         assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
-        tokenizer = yttm.BPE(model = str(bpe_path))
+        self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
+        tokenizer = self.yttm.BPE(model = str(bpe_path))
         self.tokenizer = tokenizer
         self.vocab_size = tokenizer.vocab_size()

@@ -167,7 +170,7 @@ class YttmTokenizer:
         return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
     def encode(self, texts):
-        encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
+        encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
         return list(map(torch.tensor, encoded))
     def tokenize(self, texts, context_length = 256, truncate_text = False):
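
The net effect of these hunks: youtokentome is no longer imported when the module loads, only when a YttmTokenizer is constructed. A minimal usage sketch of the new behavior (the module path `dalle2_pytorch.tokenizer` and the example file path are assumptions, not shown in this diff):

    # importing the package no longer requires youtokentome to be installed
    from dalle2_pytorch.tokenizer import YttmTokenizer   # assumed module path

    # the deferred import happens here: if youtokentome is missing,
    # import_or_print_error prints the `pip install youtokentome` hint and exits
    tokenizer = YttmTokenizer(bpe_path = './data/bpe.model')   # hypothetical path

    token_ids = tokenizer.tokenize(['a corgi wearing sunglasses'], context_length = 256)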

dalle2_pytorch/trackers.py

@@ -6,6 +6,8 @@ from itertools import zip_longest
 import torch
 from torch import nn
+from dalle2_pytorch.utils import import_or_print_error
 # constants
 DEFAULT_DATA_PATH = './.tracker-data'

@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
 def exists(val):
     return val is not None
-def import_or_print_error(pkg_name, err_str = None):
-    try:
-        return importlib.import_module(pkg_name)
-    except ModuleNotFoundError as e:
-        if exists(err_str):
-            print(err_str)
-        exit()
 # load state dict functions
 def load_wandb_state_dict(run_path, file_path, **kwargs):

dalle2_pytorch/utils.py

@@ -17,3 +17,13 @@ class Timer:
 def print_ribbon(s, symbol = '=', repeat = 40):
     flank = symbol * repeat
     return f'{flank} {s} {flank}'
+
+# import helpers
+def import_or_print_error(pkg_name, err_str = None):
+    try:
+        return importlib.import_module(pkg_name)
+    except ModuleNotFoundError as e:
+        if exists(err_str):
+            print(err_str)
+        exit()
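
As written, the moved helper assumes `importlib` is imported and an `exists` helper is defined in utils.py; neither appears in this hunk. A self-contained sketch of the same function with those pieces filled in as assumptions:

    import importlib

    def exists(val):
        return val is not None

    # resolve a module by name at call time; if it is not installed,
    # print the installation hint (when given) and stop the process
    def import_or_print_error(pkg_name, err_str = None):
        try:
            return importlib.import_module(pkg_name)
        except ModuleNotFoundError:
            if exists(err_str):
                print(err_str)
            exit()

    # usage: returns the module object if installed, otherwise exits
    # yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')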

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '0.6.5'
+__version__ = '0.6.6'

setup.py

@@ -42,7 +42,6 @@ setup(
     'tqdm',
     'vector-quantize-pytorch',
     'x-clip>=0.4.4',
-    'youtokentome',
     'webdataset>=0.2.5',
     'fsspec>=2022.1.0',
     'torchmetrics[image]>=0.8.0'
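
With youtokentome dropped from install_requires, the dependency becomes opt-in: users who want the BPE tokenizer install it themselves with `pip install youtokentome` (the command suggested by the runtime hint above), while everyone else avoids the reported build problems and only sees the printed hint if they actually construct a YttmTokenizer.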