Compare commits

..

1 Commits

Author SHA1 Message Date
Phil Wang
f1739267e4 simplify more 2022-05-14 17:13:13 -07:00
2 changed files with 6 additions and 7 deletions

View File

@@ -265,9 +265,8 @@ class DiffusionPriorTrainer(nn.Module):
with autocast(enabled = self.amp): with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs) loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
loss = loss * chunk_size_frac total_loss += loss.item() * chunk_size_frac
total_loss += loss.item() self.scaler.scale(loss * chunk_size_frac).backward()
self.scaler.scale(loss).backward()
return total_loss return total_loss
@@ -389,8 +388,8 @@ class DecoderTrainer(nn.Module):
with autocast(enabled = self.amp): with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs) loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac total_loss += loss.item() * chunk_size_frac
total_loss += loss.item()
self.scale(loss, unet_number = unet_number).backward() self.scale(loss * chunk_size_frac, unet_number = unet_number).backward()
return total_loss return total_loss

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.2.28', version = '0.2.27',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',