Mirror of https://github.com/lucidrains/DALLE2-pytorch.git
Add data flexibility to decoder trainer (#165)
* Added the ability to train the decoder with text embeddings
* Added the ability to train using embeddings generated on the fly with CLIP
* CLIP now generates embeddings for whatever is not precomputed
@@ -578,6 +578,18 @@ class DecoderTrainer(nn.Module):
         return output
 
+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_text(self, *args, **kwargs):
+        return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
+
+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_image(self, *args, **kwargs):
+        return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
+
     @cast_torch_tensor
     def forward(
         self,
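A minimal usage sketch, not part of the commit itself: it follows the construction pattern from the repository README with placeholder hyperparameters and toy data, and only exercises the two helpers added here plus a single training step. Variable names such as `images` and `text` are illustrative assumptions.

import torch
from dalle2_pytorch import CLIP, Unet, Decoder, DecoderTrainer

# CLIP, Unet, and Decoder built as in the README (sizes here are placeholders)
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
)

trainer = DecoderTrainer(decoder, lr = 3e-4, wd = 1e-2)

images = torch.rand(4, 3, 256, 256)           # toy image batch in [0, 1]
text   = torch.randint(0, 49408, (4, 256))    # toy batch of text token ids

# The new helpers wrap decoder.clip.embed_text / embed_image, so embeddings that
# were not precomputed can be generated on the fly; the exact return structure
# depends on the CLIP adapter (typically the embedding plus per-token encodings).
text_embed_out  = trainer.embed_text(text)
image_embed_out = trainer.embed_image(images)

# One training step, following the README's pattern: the trainer accepts raw
# images and lets CLIP produce whatever conditioning is not precomputed.
loss = trainer(images, unet_number = 1, max_batch_size = 4)
trainer.update(unet_number = 1)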