Compare commits

...

5 Commits

Author          SHA1        Message                                                                         Date
Phil Wang       a851168633  make youtokentome optional package, due to reported installation difficulties  2022-06-01 09:25:35 -07:00
Phil Wang       1ffeecd0ca  lower default ema beta value                                                    2022-05-31 11:55:21 -07:00
Phil Wang       3df899f7a4  patch                                                                           2022-05-31 09:03:43 -07:00
Aidan Dempster  09534119a1  Fixed non deterministic optimizer creation (#130)                               2022-05-31 09:03:20 -07:00
Phil Wang       6f8b90d4d7  add packaging package                                                           2022-05-30 11:45:00 -07:00
7 changed files with 27 additions and 18 deletions

View File

@@ -1,8 +1,10 @@
 from torch.optim import AdamW, Adam

 def separate_weight_decayable_params(params):
-    no_wd_params = set([param for param in params if param.ndim < 2])
-    wd_params = set(params) - no_wd_params
+    wd_params, no_wd_params = [], []
+    for param in params:
+        param_list = no_wd_params if param.ndim < 2 else wd_params
+        param_list.append(param)
     return wd_params, no_wd_params

 def get_optimizer(
@@ -25,8 +27,8 @@ def get_optimizer(
     wd_params, no_wd_params = separate_weight_decayable_params(params)

     params = [
-        {'params': list(wd_params)},
-        {'params': list(no_wd_params), 'weight_decay': 0},
+        {'params': wd_params},
+        {'params': no_wd_params, 'weight_decay': 0},
     ]

     return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
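This is the fix for the non-deterministic optimizer creation reported in #130: collecting parameters into Python sets makes their iteration order depend on object hashes, so the AdamW parameter groups could come out in a different order on every run; plain lists preserve the order in which the parameters were passed. A minimal, self-contained sketch of the new behaviour (the toy model below is illustrative, not from the repo):

import torch
from torch import nn
from torch.optim import AdamW

# toy model: the 2-D Linear weight gets weight decay, the 1-D bias and LayerNorm gain/bias do not
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

def separate_weight_decayable_params(params):
    # same logic as the patched helper: lists keep parameters in a stable, reproducible order
    wd_params, no_wd_params = [], []
    for param in params:
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

wd_params, no_wd_params = separate_weight_decayable_params(list(model.parameters()))

optimizer = AdamW([
    {'params': wd_params},
    {'params': no_wd_params, 'weight_decay': 0},
], lr = 3e-4, weight_decay = 1e-2)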

View File

@@ -2,7 +2,6 @@
 # to give users a quick easy start to training DALL-E without doing BPE

 import torch
-import youtokentome as yttm

 import html
 import os
@@ -11,6 +10,8 @@ import regex as re
 from functools import lru_cache
 from pathlib import Path

+from dalle2_pytorch.utils import import_or_print_error
+
 # OpenAI simple tokenizer

 @lru_cache()
@@ -156,7 +157,9 @@ class YttmTokenizer:
         bpe_path = Path(bpe_path)
         assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'

-        tokenizer = yttm.BPE(model = str(bpe_path))
+        self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
+
+        tokenizer = self.yttm.BPE(model = str(bpe_path))
         self.tokenizer = tokenizer
         self.vocab_size = tokenizer.vocab_size()

@@ -167,7 +170,7 @@ class YttmTokenizer:
         return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))

     def encode(self, texts):
-        encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
+        encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
         return list(map(torch.tensor, encoded))

     def tokenize(self, texts, context_length = 256, truncate_text = False):
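With the module-level `import youtokentome as yttm` removed, importing `dalle2_pytorch.tokenizer` no longer requires the package; `youtokentome` is only imported when a `YttmTokenizer` is actually constructed, and a missing install produces the printed hint instead of an `ImportError` at import time. A hedged usage sketch (the BPE model path below is illustrative):

from dalle2_pytorch.tokenizer import YttmTokenizer

# youtokentome is imported lazily inside __init__; without it installed this prints
# 'you need to install youtokentome by `pip install youtokentome`' and exits
tokenizer = YttmTokenizer(bpe_path = './data/bpe.model')   # path is illustrative

token_ids = tokenizer.tokenize(['a corgi wearing sunglasses'], context_length = 256)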

View File

@@ -6,6 +6,8 @@ from itertools import zip_longest
 import torch
 from torch import nn

+from dalle2_pytorch.utils import import_or_print_error
+
 # constants

 DEFAULT_DATA_PATH = './.tracker-data'
@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
 def exists(val):
     return val is not None

-def import_or_print_error(pkg_name, err_str = None):
-    try:
-        return importlib.import_module(pkg_name)
-    except ModuleNotFoundError as e:
-        if exists(err_str):
-            print(err_str)
-        exit()
-
 # load state dict functions

 def load_wandb_state_dict(run_path, file_path, **kwargs):

View File

@@ -178,7 +178,7 @@ class EMA(nn.Module):
     def __init__(
         self,
         model,
-        beta = 0.9999,
+        beta = 0.99,
         update_after_step = 1000,
         update_every = 10,
     ):
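The only change here is the default `beta`, dropped from 0.9999 to 0.99. Beta is the decay rate of the exponential moving average, so its effective averaging horizon is roughly 1 / (1 - beta): about 10,000 updates before, about 100 after, meaning the EMA weights now follow the online model much more closely. A minimal sketch of the update rule it controls (not the repo's EMA class):

import copy
import torch
from torch import nn

def ema_update(ema_model, online_model, beta = 0.99):
    # parameter-wise: ema <- beta * ema + (1 - beta) * online
    with torch.no_grad():
        for ema_p, online_p in zip(ema_model.parameters(), online_model.parameters()):
            ema_p.lerp_(online_p, 1 - beta)

online = nn.Linear(4, 4)
ema = copy.deepcopy(online)
ema_update(ema, online, beta = 0.99)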

View File

@@ -17,3 +17,13 @@ class Timer:
 def print_ribbon(s, symbol = '=', repeat = 40):
     flank = symbol * repeat
     return f'{flank} {s} {flank}'
+
+# import helpers
+
+def import_or_print_error(pkg_name, err_str = None):
+    try:
+        return importlib.import_module(pkg_name)
+    except ModuleNotFoundError as e:
+        if exists(err_str):
+            print(err_str)
+        exit()
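`import_or_print_error` now lives in `dalle2_pytorch.utils`, where both `tokenizer.py` and `trackers.py` import it, instead of being a private copy in the trackers module. It assumes `importlib` and the `exists` helper are already available earlier in `utils.py` (not visible in this hunk). A self-contained sketch of the same pattern, with the call site taken from the tokenizer diff above:

import importlib

def exists(val):
    return val is not None

def import_or_print_error(pkg_name, err_str = None):
    # return the module if installed, otherwise print the hint (if any) and stop
    try:
        return importlib.import_module(pkg_name)
    except ModuleNotFoundError:
        if exists(err_str):
            print(err_str)
        exit()

yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')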

View File

@@ -1 +1 @@
-__version__ = '0.6.2'
+__version__ = '0.6.6'

View File

@@ -32,6 +32,7 @@ setup(
     'embedding-reader',
     'kornia>=0.5.4',
     'numpy',
+    'packaging',
     'pillow',
     'pydantic',
     'resize-right>=0.0.2',
@@ -41,7 +42,6 @@
     'tqdm',
     'vector-quantize-pytorch',
     'x-clip>=0.4.4',
-    'youtokentome',
     'webdataset>=0.2.5',
     'fsspec>=2022.1.0',
     'torchmetrics[image]>=0.8.0'
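With 'youtokentome' dropped from install_requires, `pip install dalle2-pytorch` no longer tries to build it (the reported installation difficulties are the motivation given in the top commit); anyone who wants `YttmTokenizer` installs it separately via `pip install youtokentome`, as the lazy-import error message instructs. The new 'packaging' entry corresponds to the "add packaging package" commit.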