mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
Improved upsampler training (#181)
Sampling is now possible without the first decoder unet Non-training unets are deleted in the decoder trainer since they are never used and it is harder merge the models is they have keys in this state dict Fixed a mistake where clip was not re-added after saving
This commit is contained in:
@@ -75,6 +75,8 @@ def cast_tuple(val, length = None, validate = True):
|
||||
return out
|
||||
|
||||
def module_device(module):
|
||||
if isinstance(module, nn.Identity):
|
||||
return 'cpu' # It doesn't matter
|
||||
return next(module.parameters()).device
|
||||
|
||||
def zero_init_(m):
|
||||
@@ -2326,7 +2328,7 @@ class Decoder(nn.Module):
|
||||
|
||||
@property
|
||||
def condition_on_text_encodings(self):
|
||||
return any([unet.cond_on_text_encodings for unet in self.unets])
|
||||
return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= self.num_unets
|
||||
@@ -2646,11 +2648,13 @@ class Decoder(nn.Module):
|
||||
@eval_decorator
|
||||
def sample(
|
||||
self,
|
||||
image = None,
|
||||
image_embed = None,
|
||||
text = None,
|
||||
text_encodings = None,
|
||||
batch_size = 1,
|
||||
cond_scale = 1.,
|
||||
start_at_unet_number = 1,
|
||||
stop_at_unet_number = None,
|
||||
distributed = False,
|
||||
inpaint_image = None,
|
||||
@@ -2671,14 +2675,22 @@ class Decoder(nn.Module):
|
||||
assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
|
||||
|
||||
img = None
|
||||
if start_at_unet_number > 1:
|
||||
# Then we are not generating the first image and one must have been passed in
|
||||
assert exists(image), 'image must be passed in if starting at unet number > 1'
|
||||
assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
|
||||
prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
|
||||
img = resize_image_to(image, prev_unet_output_size, nearest = True)
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
num_unets = self.num_unets
|
||||
cond_scale = cast_tuple(cond_scale, num_unets)
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
|
||||
if unet_number < start_at_unet_number:
|
||||
continue # It's the easiest way to do it
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
||||
|
||||
with context:
|
||||
# prepare low resolution conditioning for upsamplers
|
||||
|
||||
Reference in New Issue
Block a user