mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
remove unnecessary decoder setting, and if not unconditional, always make sure the first unet is condition-able on text
This commit is contained in:
@@ -1766,14 +1766,13 @@ class Decoder(nn.Module):
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
learned_variance = True,
|
||||
learned_variance_constrain_frac = False,
|
||||
vb_loss_weight = 0.001,
|
||||
unconditional = False,
|
||||
unconditional = False, # set to True for generating images without conditioning
|
||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||
use_dynamic_thres = False, # from the Imagen paper
|
||||
dynamic_thres_percentile = 0.9,
|
||||
@@ -1852,8 +1851,8 @@ class Decoder(nn.Module):
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first and not unconditional,
|
||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
||||
cond_on_image_embeds = not unconditional and is_first,
|
||||
cond_on_text_encodings = not unconditional and (is_first or one_unet.cond_on_text_encodings),
|
||||
channels = unet_channels,
|
||||
channels_out = unet_channels_out
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user