Compare commits

...

2 Commits

Author SHA1 Message Date
Phil Wang
a851168633 make youtokentome optional package, due to reported installation difficulties 2022-06-01 09:25:35 -07:00
Phil Wang
1ffeecd0ca lower default ema beta value 2022-05-31 11:55:21 -07:00
6 changed files with 20 additions and 14 deletions

View File

@@ -2,7 +2,6 @@
# to give users a quick easy start to training DALL-E without doing BPE
import torch
import youtokentome as yttm
import html
import os
@@ -11,6 +10,8 @@ import regex as re
from functools import lru_cache
from pathlib import Path
from dalle2_pytorch.utils import import_or_print_error
# OpenAI simple tokenizer
@lru_cache()
@@ -156,7 +157,9 @@ class YttmTokenizer:
bpe_path = Path(bpe_path)
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
tokenizer = yttm.BPE(model = str(bpe_path))
self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
tokenizer = self.yttm.BPE(model = str(bpe_path))
self.tokenizer = tokenizer
self.vocab_size = tokenizer.vocab_size()
@@ -167,7 +170,7 @@ class YttmTokenizer:
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
def encode(self, texts):
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
return list(map(torch.tensor, encoded))
def tokenize(self, texts, context_length = 256, truncate_text = False):

View File

@@ -6,6 +6,8 @@ from itertools import zip_longest
import torch
from torch import nn
from dalle2_pytorch.utils import import_or_print_error
# constants
DEFAULT_DATA_PATH = './.tracker-data'
@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()
# load state dict functions
def load_wandb_state_dict(run_path, file_path, **kwargs):

View File

@@ -178,7 +178,7 @@ class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.9999,
beta = 0.99,
update_after_step = 1000,
update_every = 10,
):

View File

@@ -17,3 +17,13 @@ class Timer:
def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat
return f'{flank} {s} {flank}'
# import helpers
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()

View File

@@ -1 +1 @@
__version__ = '0.6.4'
__version__ = '0.6.6'

View File

@@ -42,7 +42,6 @@ setup(
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'youtokentome',
'webdataset>=0.2.5',
'fsspec>=2022.1.0',
'torchmetrics[image]>=0.8.0'