Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-14 12:44:28 +01:00)
Compare commits

3 commits:
- 164d9be444
- 5562ec6be2
- 89ff04cfe2
@@ -14,6 +14,12 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord

 There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.

+## Status
+
+- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and <a href="https://github.com/crowsonkb">Katherine's</a> own experiments, validate OpenAI's finding that the extra prior increases variety of generations.
+
+- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers
+
 ## Install

 ```bash
@@ -47,6 +47,14 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs

+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
 # decorators

 def cast_torch_tensor(fn):
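
For quick reference, `num_to_groups` splits a total count into chunks of at most `divisor`; a minimal check of the expected behaviour, restating the function from the hunk above:

```python
# restated from the hunk above: split `num` items into chunks of at most `divisor`
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

assert num_to_groups(25, 10) == [10, 10, 5]  # two full chunks plus a remainder chunk
assert num_to_groups(20, 10) == [10, 10]     # divides evenly, no remainder chunk
```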
@@ -195,7 +203,11 @@ class EMA(nn.Module):
     def update(self):
         self.step += 1

-        if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
+        if (self.step % self.update_every) != 0:
             return

+        if self.step <= self.update_after_step:
+            self.copy_params_from_model_to_ema()
+            return
+
         if not self.initted:
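
The reworked gate reads: skip all but every `update_every`-th step, and during the warmup window up to `update_after_step` pin the EMA weights to the online model rather than leaving them stale. A minimal sketch of that schedule with hypothetical standalone names (`copy_from_model` and `moving_average_update` are placeholders, not the repo's API):

```python
# minimal sketch of the EMA update schedule after this change (hypothetical helper names)
def ema_update(step, update_every, update_after_step, copy_from_model, moving_average_update):
    if (step % update_every) != 0:
        return                      # only touch the EMA every `update_every` steps
    if step <= update_after_step:
        copy_from_model()           # warmup: EMA mirrors the raw model weights exactly
        return
    moving_average_update()         # afterwards: the usual exponential moving average
```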
@@ -318,6 +330,22 @@ class DiffusionPriorTrainer(nn.Module):

 # decoder trainer

+def decoder_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        if self.decoder.unconditional:
+            batch_size = kwargs.get('batch_size')
+            batch_sizes = num_to_groups(batch_size, max_batch_size)
+            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
+        else:
+            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DecoderTrainer(nn.Module):
     def __init__(
         self,
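
The decorator lets `sample` accept a `max_batch_size`, so a large logical batch is generated in memory-bounded sub-batches and concatenated at the end. A self-contained sketch of the same chunking pattern, using illustrative stand-in names (`DummySampler` and `sample_in_chunks` are not part of the repo):

```python
# self-contained sketch of the chunking pattern behind decoder_sample_in_chunks
# (DummySampler and sample_in_chunks are illustrative stand-ins, not the repo's API)
import torch
from functools import wraps

def sample_in_chunks(fn):
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, batch_size, **kwargs):
        if max_batch_size is None:
            return fn(self, *args, batch_size = batch_size, **kwargs)
        groups, remainder = divmod(batch_size, max_batch_size)
        sizes = [max_batch_size] * groups + ([remainder] if remainder else [])
        # one underlying call per sub-batch, concatenated along the batch dimension
        return torch.cat([fn(self, *args, batch_size = b, **kwargs) for b in sizes], dim = 0)
    return inner

class DummySampler:
    @sample_in_chunks
    def sample(self, batch_size):
        return torch.randn(batch_size, 3, 8, 8)  # stand-in for decoder.sample output

images = DummySampler().sample(batch_size = 100, max_batch_size = 16)
assert images.shape[0] == 100  # sub-batches of [16]*6 + [4], stitched back together
```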
@@ -407,18 +435,17 @@ class DecoderTrainer(nn.Module):

     @torch.no_grad()
     @cast_torch_tensor
+    @decoder_sample_in_chunks
     def sample(self, *args, **kwargs):
-        if kwargs.pop('use_non_ema', False):
+        if kwargs.pop('use_non_ema', False) or not self.use_ema:
             return self.decoder.sample(*args, **kwargs)

-        if self.use_ema:
-            trainable_unets = self.decoder.unets
-            self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+        trainable_unets = self.decoder.unets
+        self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling

         output = self.decoder.sample(*args, **kwargs)

-        if self.use_ema:
-            self.decoder.unets = trainable_unets # restore original training unets
+        self.decoder.unets = trainable_unets # restore original training unets

-            # cast the ema_model unets back to original device
-            for ema in self.ema_unets:
+        # cast the ema_model unets back to original device
+        for ema in self.ema_unets:
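
The revised `sample` takes the plain path both when `use_non_ema = True` is passed and when the trainer was built without EMA; otherwise the EMA unets are swapped into the decoder for the duration of the call and the training unets restored afterwards. A generic, standalone sketch of that swap-and-restore pattern (a context-manager restatement with made-up names, not the code in this diff):

```python
# generic swap-and-restore pattern, restated as a context manager (names are illustrative)
from contextlib import contextmanager

@contextmanager
def ema_weights_swapped(holder, ema_unets):
    trainable_unets = holder.unets      # remember the current training unets
    holder.unets = ema_unets            # swap in the EMA unets for sampling
    try:
        yield holder
    finally:
        holder.unets = trainable_unets  # always restore the training unets afterwards
```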