mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-01-29 10:24:27 +01:00
Stable Video Diffusion
This commit is contained in:
@@ -29,25 +29,14 @@ model:
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [ 1, 2, 4 ]
|
||||
ch_mult: [1, 2, 4]
|
||||
num_res_blocks: 4
|
||||
attn_resolutions: [ ]
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
|
||||
decoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Decoder
|
||||
params:
|
||||
attn_type: none
|
||||
double_z: False
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [ 1, 2, 4 ]
|
||||
num_res_blocks: 4
|
||||
attn_resolutions: [ ]
|
||||
dropout: 0.0
|
||||
params: ${model.params.encoder_config.params}
|
||||
|
||||
data:
|
||||
target: sgm.data.dataset.StableDataModuleFromConfig
|
||||
@@ -55,18 +44,18 @@ data:
|
||||
train:
|
||||
datapipeline:
|
||||
urls:
|
||||
- "DATA-PATH"
|
||||
- DATA-PATH
|
||||
pipeline_config:
|
||||
shardshuffle: 10000
|
||||
sample_shuffle: 10000
|
||||
|
||||
decoders:
|
||||
- "pil"
|
||||
- pil
|
||||
|
||||
postprocessors:
|
||||
- target: sdata.mappers.TorchVisionImageTransforms
|
||||
params:
|
||||
key: 'jpg'
|
||||
key: jpg
|
||||
transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
model:
|
||||
base_learning_rate: 4.5e-6
|
||||
target: sgm.models.autoencoder.AutoencodingEngine
|
||||
params:
|
||||
input_key: jpg
|
||||
monitor: val/loss/rec
|
||||
disc_start_iter: 0
|
||||
|
||||
encoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Encoder
|
||||
params:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: true
|
||||
z_channels: 8
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
|
||||
decoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Decoder
|
||||
params: ${model.params.encoder_config.params}
|
||||
|
||||
regularizer_config:
|
||||
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
||||
|
||||
loss_config:
|
||||
target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
|
||||
params:
|
||||
perceptual_weight: 0.25
|
||||
disc_start: 20001
|
||||
disc_weight: 0.5
|
||||
learn_logvar: True
|
||||
|
||||
regularization_weights:
|
||||
kl_loss: 1.0
|
||||
|
||||
data:
|
||||
target: sgm.data.dataset.StableDataModuleFromConfig
|
||||
params:
|
||||
train:
|
||||
datapipeline:
|
||||
urls:
|
||||
- DATA-PATH
|
||||
pipeline_config:
|
||||
shardshuffle: 10000
|
||||
sample_shuffle: 10000
|
||||
|
||||
decoders:
|
||||
- pil
|
||||
|
||||
postprocessors:
|
||||
- target: sdata.mappers.TorchVisionImageTransforms
|
||||
params:
|
||||
key: jpg
|
||||
transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
size: 256
|
||||
interpolation: 3
|
||||
- target: torchvision.transforms.ToTensor
|
||||
- target: sdata.mappers.Rescaler
|
||||
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
|
||||
params:
|
||||
h_key: height
|
||||
w_key: width
|
||||
|
||||
loader:
|
||||
batch_size: 8
|
||||
num_workers: 4
|
||||
|
||||
|
||||
lightning:
|
||||
strategy:
|
||||
target: pytorch_lightning.strategies.DDPStrategy
|
||||
params:
|
||||
find_unused_parameters: True
|
||||
|
||||
modelcheckpoint:
|
||||
params:
|
||||
every_n_train_steps: 5000
|
||||
|
||||
callbacks:
|
||||
metrics_over_trainsteps_checkpoint:
|
||||
params:
|
||||
every_n_train_steps: 50000
|
||||
|
||||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
enable_autocast: False
|
||||
batch_frequency: 1000
|
||||
max_images: 8
|
||||
increase_log_steps: True
|
||||
|
||||
trainer:
|
||||
devices: 0,
|
||||
limit_val_batches: 50
|
||||
benchmark: True
|
||||
accumulate_grad_batches: 1
|
||||
val_check_interval: 10000
|
||||
Reference in New Issue
Block a user