mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-16 22:04:20 +01:00
265 lines
6.7 KiB
YAML
265 lines
6.7 KiB
YAML
trainer:
|
|
target: trainer.TrainerSDTurboSR
|
|
|
|
sd_pipe:
|
|
target: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline
|
|
num_train_steps: 1000
|
|
enable_grad_checkpoint: True
|
|
compile: False
|
|
vae_split: 8
|
|
params:
|
|
pretrained_model_name_or_path: stabilityai/sd-turbo
|
|
cache_dir: weights
|
|
use_safetensors: True
|
|
torch_dtype: torch.float16
|
|
|
|
llpips:
|
|
target: latent_lpips.lpips.LPIPS
|
|
ckpt_path: weights/vgg16_sdturbo_lpips.pth
|
|
compile: False
|
|
params:
|
|
pretrained: False
|
|
net: vgg16
|
|
lpips: True
|
|
spatial: False
|
|
pnet_rand: False
|
|
pnet_tune: True
|
|
use_dropout: True
|
|
eval_mode: True
|
|
latent: True
|
|
in_chans: 4
|
|
verbose: True
|
|
|
|
model:
|
|
target: diffusers.models.autoencoders.NoisePredictor
|
|
ckpt_start_path: ~ # only used for training the intermediate model
|
|
ckpt_path: ~ # For initializing
|
|
compile: False
|
|
params:
|
|
in_channels: 3
|
|
down_block_types:
|
|
- AttnDownBlock2D
|
|
- AttnDownBlock2D
|
|
up_block_types:
|
|
- AttnUpBlock2D
|
|
- AttnUpBlock2D
|
|
block_out_channels:
|
|
- 256 # 192, 256
|
|
- 512 # 384, 512
|
|
layers_per_block:
|
|
- 3
|
|
- 3
|
|
act_fn: silu
|
|
latent_channels: 4
|
|
norm_num_groups: 32
|
|
sample_size: 128
|
|
mid_block_add_attention: True
|
|
resnet_time_scale_shift: default
|
|
temb_channels: 512
|
|
attention_head_dim: 64
|
|
freq_shift: 0
|
|
flip_sin_to_cos: True
|
|
double_z: True
|
|
|
|
discriminator:
|
|
target: diffusers.models.unets.unet_2d_condition_discriminator.UNet2DConditionDiscriminator
|
|
enable_grad_checkpoint: True
|
|
compile: False
|
|
params:
|
|
sample_size: 64
|
|
in_channels: 4
|
|
center_input_sample: False
|
|
flip_sin_to_cos: True
|
|
freq_shift: 0
|
|
down_block_types:
|
|
- DownBlock2D
|
|
- CrossAttnDownBlock2D
|
|
- CrossAttnDownBlock2D
|
|
mid_block_type: UNetMidBlock2DCrossAttn
|
|
up_block_types:
|
|
- CrossAttnUpBlock2D
|
|
- CrossAttnUpBlock2D
|
|
- UpBlock2D
|
|
only_cross_attention: False
|
|
block_out_channels:
|
|
- 128
|
|
- 256
|
|
- 512
|
|
layers_per_block:
|
|
- 1
|
|
- 2
|
|
- 2
|
|
downsample_padding: 1
|
|
mid_block_scale_factor: 1
|
|
dropout: 0.0
|
|
act_fn: silu
|
|
norm_num_groups: 32
|
|
norm_eps: 1e-5
|
|
cross_attention_dim: 1024
|
|
transformer_layers_per_block: 1
|
|
reverse_transformer_layers_per_block: ~
|
|
encoder_hid_dim: ~
|
|
encoder_hid_dim_type: ~
|
|
attention_head_dim:
|
|
- 8
|
|
- 16
|
|
- 16
|
|
num_attention_heads: ~
|
|
dual_cross_attention: False
|
|
use_linear_projection: False
|
|
class_embed_type: ~
|
|
addition_embed_type: text
|
|
addition_time_embed_dim: 256
|
|
num_class_embeds: ~
|
|
upcast_attention: ~
|
|
resnet_time_scale_shift: default
|
|
resnet_skip_time_act: False
|
|
resnet_out_scale_factor: 1.0
|
|
time_embedding_type: positional
|
|
time_embedding_dim: ~
|
|
time_embedding_act_fn: ~
|
|
timestep_post_act: ~
|
|
time_cond_proj_dim: ~
|
|
conv_in_kernel: 3
|
|
conv_out_kernel: 3
|
|
projection_class_embeddings_input_dim: 2560
|
|
attention_type: default
|
|
class_embeddings_concat: False
|
|
mid_block_only_cross_attention: ~
|
|
cross_attention_norm: ~
|
|
addition_embed_type_num_heads: 64
|
|
|
|
degradation:
|
|
sf: 4
|
|
# the first degradation process
|
|
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
|
resize_range: [0.15, 1.5]
|
|
gaussian_noise_prob: 0.5
|
|
noise_range: [1, 30]
|
|
poisson_scale_range: [0.05, 3.0]
|
|
gray_noise_prob: 0.4
|
|
jpeg_range: [30, 95]
|
|
|
|
# the second degradation process
|
|
second_order_prob: 0.5
|
|
second_blur_prob: 0.8
|
|
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
|
resize_range2: [0.3, 1.2]
|
|
gaussian_noise_prob2: 0.5
|
|
noise_range2: [1, 25]
|
|
poisson_scale_range2: [0.05, 2.5]
|
|
gray_noise_prob2: 0.4
|
|
jpeg_range2: [30, 95]
|
|
|
|
gt_size: 512
|
|
resize_back: False
|
|
use_sharp: False
|
|
|
|
data:
|
|
train:
|
|
type: realesrgan
|
|
params:
|
|
data_source:
|
|
source1:
|
|
root_path: /mnt/sfs-common/zsyue/database/FFHQ
|
|
image_path: images1024
|
|
moment_path: ~
|
|
text_path: ~
|
|
im_ext: png
|
|
length: 20000
|
|
source2:
|
|
root_path: /mnt/sfs-common/zsyue/database/LSDIR/train
|
|
image_path: images
|
|
moment_path: ~
|
|
text_path: ~
|
|
im_ext: png
|
|
max_token_length: 77 # 77
|
|
io_backend:
|
|
type: disk
|
|
blur_kernel_size: 21
|
|
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
|
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
|
sinc_prob: 0.1
|
|
blur_sigma: [0.2, 3.0]
|
|
betag_range: [0.5, 4.0]
|
|
betap_range: [1, 2.0]
|
|
|
|
blur_kernel_size2: 15
|
|
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
|
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
|
sinc_prob2: 0.1
|
|
blur_sigma2: [0.2, 1.5]
|
|
betag_range2: [0.5, 4.0]
|
|
betap_range2: [1, 2.0]
|
|
|
|
final_sinc_prob: 0.8
|
|
|
|
gt_size: ${degradation.gt_size}
|
|
use_hflip: True
|
|
use_rot: False
|
|
random_crop: True
|
|
val:
|
|
type: base
|
|
params:
|
|
dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/lq
|
|
transform_type: default
|
|
transform_kwargs:
|
|
mean: 0.0
|
|
std: 1.0
|
|
extra_dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/gt
|
|
extra_transform_type: default
|
|
extra_transform_kwargs:
|
|
mean: 0.0
|
|
std: 1.0
|
|
im_exts: png
|
|
length: 16
|
|
recursive: False
|
|
|
|
train:
|
|
# predict started inverser
|
|
start_mode: True
|
|
# learning rate
|
|
lr: 5e-5 # learning rate
|
|
lr_min: 5e-5 # minimum learning rate
|
|
lr_schedule: ~
|
|
warmup_iterations: 2000
|
|
# discriminator
|
|
lr_dis: 5e-5 # learning rate for discriminator
|
|
weight_decay_dis: 1e-3 # weight decay for discriminator
|
|
dis_init_iterations: 10000 # iterations used for updating the discriminator
|
|
dis_update_freq: 1
|
|
# dataloader
|
|
batch: 64
|
|
microbatch: 16
|
|
num_workers: 4
|
|
prefetch_factor: 2
|
|
use_text: True
|
|
# optimization settings
|
|
weight_decay: 0
|
|
ema_rate: 0.999
|
|
iterations: 200000 # total iterations
|
|
# logging
|
|
save_freq: 5000
|
|
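log_freq: [200, 5000] # [training loss, training images]; NOTE(review): comment previously also listed "val images" but only two values are given - validation logging appears to be set by the separate log_freq under validate; confirm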
|
|
local_logging: True # manually save images
|
|
tf_logging: False # tensorboard logging
|
|
# loss
|
|
loss_type: L2
|
|
loss_coef:
|
|
ldif: 1.0
|
|
timesteps: [200, 100]
|
|
num_inference_steps: 5
|
|
# mixed precision
|
|
use_amp: True
|
|
use_fsdp: False
|
|
# random seed
|
|
seed: 123456
|
|
global_seeding: False
|
|
noise_detach: False
|
|
|
|
validate:
|
|
batch: 2
|
|
use_ema: True
|
|
log_freq: 4 # logging frequency
|
|
val_y_channel: True
|