Compare commits

..

2 Commits

Author SHA1 Message Date
Phil Wang
984d62a373 default ddim sampling eta to 0 2022-12-23 13:23:09 -08:00
Phil Wang
683dd98b96 extra insurance in case eos id is not there 2022-12-15 10:54:21 -08:00
2 changed files with 4 additions and 2 deletions

View File

@@ -360,6 +360,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (text != 0)
assert not self.cleared
text_embed = self.clip.encode_text(text)
@@ -434,6 +435,7 @@ class OpenClipAdapter(BaseClipAdapter):
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (text != 0)
assert not self.cleared
text_embed = self.clip.encode_text(text)
@@ -2494,7 +2496,7 @@ class Decoder(nn.Module):
dynamic_thres_percentile = 0.95,
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1,
ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict
):
super().__init__()

View File

@@ -1 +1 @@
__version__ = '1.11.2'
__version__ = '1.12.0'