Mirror of https://github.com/lucidrains/DALLE2-pytorch.git

Compare commits (4 commits)
- cc58f75474
- 3b2cf7b0bc
- 984d62a373
- 683dd98b96
@@ -360,6 +360,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         is_eos_id = (text == self.eos_id)
         text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
         text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        text_mask = text_mask & (text != 0)
         assert not self.cleared

         text_embed = self.clip.encode_text(text)
@@ -434,6 +435,7 @@ class OpenClipAdapter(BaseClipAdapter):
         is_eos_id = (text == self.eos_id)
         text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
         text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        text_mask = text_mask & (text != 0)
         assert not self.cleared

         text_embed = self.clip.encode_text(text)
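Both hunks above add the same masking line to the two CLIP adapters. A minimal standalone sketch of the logic (toy token ids; 49407 is the end-of-text id in OpenAI CLIP's tokenizer): the cumsum trick keeps every position up to and including the first EOS, and the added `text != 0` clause also drops zero-padding, which is what rescues the edge case where a sequence contains no EOS token at all.

```python
import torch
import torch.nn.functional as F

eos_id = 49407  # end-of-text id in OpenAI CLIP's tokenizer

# toy token ids: one caption terminated by EOS, one zero-padded without EOS
text = torch.tensor([
    [320, 1125, eos_id, 0, 0],
    [320, 1125, 7235,   0, 0],
])

is_eos_id = (text == eos_id)
# True strictly before the first EOS, False from the EOS onwards
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
# shift right by one so the EOS position itself stays unmasked
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
# the new line: zero is the padding id, so mask padded positions too;
# without it, the second row (no EOS) would be unmasked everywhere
text_mask = text_mask & (text != 0)

print(text_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True, False, False]])
```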
@@ -1122,7 +1124,7 @@ class DiffusionPriorNetwork(nn.Module):
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)

         if self.self_cond:
-            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+            learned_queries = torch.cat((self_cond, learned_queries), dim = -2)

         tokens = torch.cat((
             text_encodings,
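The one-line change here fixes the self-conditioning path of the diffusion prior: the old code concatenated `image_embed` with `self_cond` and threw the freshly built `learned_queries` away, so the network never saw its own query token. The new code prepends the self-conditioning token and keeps the learned query. A toy shape check, with assumed batch and embedding sizes:

```python
import torch

batch, dim = 2, 512  # assumed sizes, for illustration only

# one learned query token per sample, as in the repeat(...) line above
learned_query = torch.randn(dim)
learned_queries = learned_query.expand(batch, 1, dim)

# self-conditioning token: the prior's own prediction from a previous
# pass (or zeros when there is none)
self_cond = torch.zeros(batch, 1, dim)

# fixed version: prepend self_cond along the sequence axis,
# keeping the learned query as the final token
learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
print(learned_queries.shape)  # torch.Size([2, 2, 512])
```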
@@ -2494,7 +2496,7 @@ class Decoder(nn.Module):
         dynamic_thres_percentile = 0.95,
         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
         p2_loss_weight_k = 1,
-        ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
+        ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict
     ):
         super().__init__()
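This hunk flips the `Decoder` default from `ddim_sampling_eta = 1.` to `0.`, so DDIM sampling is deterministic out of the box. For reference, a sketch of how eta enters the per-step noise scale, following the DDIM paper (https://arxiv.org/abs/2010.02502); the function name and variables below are mine, not the repo's:

```python
import torch

def ddim_sigma(alpha_cumprod_t, alpha_cumprod_prev, eta):
    # per-step noise scale from the DDIM paper: eta = 0 zeroes the noise,
    # making every sampling step deterministic; eta = 1 recovers
    # DDPM-like stochastic sampling
    return eta * torch.sqrt(
        (1 - alpha_cumprod_prev) / (1 - alpha_cumprod_t)
        * (1 - alpha_cumprod_t / alpha_cumprod_prev)
    )

a_t, a_prev = torch.tensor(0.5), torch.tensor(0.7)
print(ddim_sigma(a_t, a_prev, eta = 0.))  # tensor(0.)     -> deterministic
print(ddim_sigma(a_t, a_prev, eta = 1.))  # tensor(0.4140) -> stochastic
```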
@@ -1 +1 @@
-__version__ = '1.11.2'
+__version__ = '1.12.2'

setup.py
@@ -27,7 +27,7 @@ setup(
     'accelerate',
     'click',
     'open-clip-torch>=2.0.0,<3.0.0',
-    'clip-anytorch>=2.4.0',
+    'clip-anytorch>=2.5.2',
     'coca-pytorch>=0.0.5',
     'ema-pytorch>=0.0.7',
     'einops>=0.4',