mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
makes more sense for blur augmentation to happen before the upsampling
This commit is contained in:
@@ -1075,14 +1075,14 @@ class LowresConditioner(nn.Module):
|
|||||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||||
|
|
||||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
# when training, blur the low resolution conditional image
|
# when training, blur the low resolution conditional image
|
||||||
blur_sigma = default(blur_sigma, self.blur_sigma)
|
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||||
|
|
||||||
|
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||||
|
|
||||||
return cond_fmap
|
return cond_fmap
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user