mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Distributed Training of the Decoder (#121)
* Converted decoder trainer to use accelerate * Fixed issue where metric evaluation would hang on distributed mode * Implemented functional saving Loading still fails due to some issue with the optimizer * Fixed issue with loading decoders * Fixed issue with tracker config * Fixed issue with amp Updated logging to be more logical * Saving checkpoint now saves position in training as well Fixed an issue with running out of gpu space due to loading weights into the gpu twice * Fixed ema for distributed training * Fixed isue where get_pkg_version was reintroduced * Changed decoder trainer to upload config as a file Fixed issue where loading best would error
This commit is contained in:
@@ -2099,7 +2099,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text_encodings = None,
|
||||
batch_size = 1,
|
||||
cond_scale = 1.,
|
||||
stop_at_unet_number = None
|
||||
stop_at_unet_number = None,
|
||||
distributed = False,
|
||||
):
|
||||
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
|
||||
|
||||
@@ -2118,7 +2119,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
|
||||
|
||||
with context:
|
||||
lowres_cond_img = None
|
||||
|
||||
Reference in New Issue
Block a user