Improved upsampler training (#181)

Sampling is now possible without the first decoder unet

Non-training unets are deleted in the decoder trainer since they are never used and it is harder merge the models is they have keys in this state dict

Fixed a mistake where clip was not re-added after saving
This commit is contained in:
Aidan Dempster
2022-07-19 22:07:50 -04:00
committed by GitHub
parent 4b912a38c6
commit 4145474bab
6 changed files with 104 additions and 49 deletions

View File

@@ -75,6 +75,8 @@ def cast_tuple(val, length = None, validate = True):
return out
def module_device(module):
if isinstance(module, nn.Identity):
return 'cpu' # It doesn't matter
return next(module.parameters()).device
def zero_init_(m):
@@ -2326,7 +2328,7 @@ class Decoder(nn.Module):
@property
def condition_on_text_encodings(self):
return any([unet.cond_on_text_encodings for unet in self.unets])
return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])
def get_unet(self, unet_number):
assert 0 < unet_number <= self.num_unets
@@ -2646,11 +2648,13 @@ class Decoder(nn.Module):
@eval_decorator
def sample(
self,
image = None,
image_embed = None,
text = None,
text_encodings = None,
batch_size = 1,
cond_scale = 1.,
start_at_unet_number = 1,
stop_at_unet_number = None,
distributed = False,
inpaint_image = None,
@@ -2671,14 +2675,22 @@ class Decoder(nn.Module):
assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
img = None
if start_at_unet_number > 1:
# Then we are not generating the first image and one must have been passed in
assert exists(image), 'image must be passed in if starting at unet number > 1'
assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
img = resize_image_to(image, prev_unet_output_size, nearest = True)
is_cuda = next(self.parameters()).is_cuda
num_unets = self.num_unets
cond_scale = cast_tuple(cond_scale, num_unets)
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
if unet_number < start_at_unet_number:
continue # It's the easiest way to do it
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
with context:
# prepare low resolution conditioning for upsamplers