diff --git a/scripts/demo/sv4d_helpers.py b/scripts/demo/sv4d_helpers.py index d64266c..b0f2cbf 100755 --- a/scripts/demo/sv4d_helpers.py +++ b/scripts/demo/sv4d_helpers.py @@ -724,6 +724,7 @@ def run_img2vid( cond_view=None, decoding_t=None, cond_mv=True, + part_maps=False, ): options = version_dict["options"] H = version_dict["H"] @@ -792,6 +793,7 @@ def run_img2vid( force_cond_zero_embeddings=options.get("force_cond_zero_embeddings", None), return_latents=False, decoding_t=decoding_t, + part_maps=part_maps, ) return samples @@ -921,6 +923,7 @@ def do_sample( T=None, additional_batch_uc_fields=None, decoding_t=None, + part_maps=False, ): force_uc_zero_embeddings = default(force_uc_zero_embeddings, []) batch2model_input = default(batch2model_input, []) @@ -989,7 +992,10 @@ def do_sample( else: additional_model_inputs[k] = batch[k] - shape = (math.prod(num_samples), C, H // F, W // F) + if part_maps: + shape = (math.prod(num_samples), C * 2, H // F, W // F) + else: + shape = (math.prod(num_samples), C, H // F, W // F) randn = torch.randn(shape).to("cuda") def denoiser(input, sigma, c): diff --git a/scripts/sampling/configs/sp4d.yaml b/scripts/sampling/configs/sp4d.yaml new file mode 100755 index 0000000..cd2272d --- /dev/null +++ b/scripts/sampling/configs/sp4d.yaml @@ -0,0 +1,210 @@ +N_TIME: 4 +N_VIEW: 12 +N_FRAMES: 48 + +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.18215 + en_and_decode_n_samples_a_time: 8 + disable_first_stage_autocast: True + ckpt_path: checkpoints/sp4d.safetensors + dual_concat: True + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + + network_config: + target: sgm.modules.diffusionmodules.video_model.DualSpatialUNetWithCrossComm + params: + unet_config: + adm_in_channels: 1280 + attention_resolutions: [4, 2, 1] + channel_mult: [1, 2, 4, 4] + context_dim: 1024 + motion_context_dim: 4 + extra_ff_mix_layer: True + in_channels: 8 + legacy: False + model_channels: 320 + num_classes: sequential + num_head_channels: 64 + num_res_blocks: 2 + out_channels: 4 + replicate_time_mix_bug: True + spatial_transformer_attn_type: softmax-xformers + time_block_merge_factor: 0.0 + time_block_merge_strategy: learned_with_images + time_kernel_size: [3, 1, 1] + time_mix_legacy: False + transformer_depth: 1 + use_checkpoint: False + use_linear_in_transformer: True + use_spatial_context: True + use_spatial_transformer: True + separate_motion_merge_factor: True + use_motion_attention: True + use_3d_attention: True + use_camera_emb: True + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + + - input_key: cond_frames_without_noise + target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder + is_trainable: False + params: + n_cond_frames: ${N_TIME} + n_copies: 1 + open_clip_embedding_config: + target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: True + + - input_key: cond_frames + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + is_trainable: False + params: + is_ae: True + n_cond_frames: ${N_FRAMES} + n_copies: 1 + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + ddconfig: + attn_resolutions: [] + attn_type: vanilla-xformers + ch: 128 + ch_mult: [1, 2, 4, 4] + double_z: True + dropout: 0.0 + in_channels: 3 + num_res_blocks: 2 + out_ch: 3 + resolution: 256 + z_channels: 4 + embed_dim: 
4 + lossconfig: + target: torch.nn.Identity + monitor: val/rec_loss + sigma_cond_config: + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler + + - input_key: polar_rad + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 512 + + - input_key: azimuth_rad + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 512 + + - input_key: cond_view + is_trainable: False + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + is_ae: True + n_cond_frames: ${N_VIEW} + n_copies: 1 + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_resolutions: [] + attn_type: vanilla-xformers + ch: 128 + ch_mult: [1, 2, 4, 4] + double_z: True + dropout: 0.0 + in_channels: 3 + num_res_blocks: 2 + out_ch: 3 + resolution: 256 + z_channels: 4 + lossconfig: + target: torch.nn.Identity + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler + + - input_key: cond_motion + is_trainable: False + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + is_ae: True + n_cond_frames: ${N_TIME} + n_copies: 1 + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_resolutions: [] + attn_type: vanilla-xformers + ch: 128 + ch_mult: [1, 2, 4, 4] + double_z: True + dropout: 0.0 + in_channels: 3 + num_res_blocks: 2 + out_ch: 3 + resolution: 256 + z_channels: 4 + lossconfig: + target: torch.nn.Identity + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler + + first_stage_config: + target: sgm.models.autoencoder.AutoencodingEngine + params: + loss_config: + target: torch.nn.Identity + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + encoder_config: + target: torch.nn.Identity + decoder_config: + target: sgm.modules.diffusionmodules.model.DecoderDual + params: + attn_resolutions: [] + attn_type: vanilla-xformers + ch: 128 + ch_mult: [1, 2, 4, 4] + double_z: True + dropout: 0.0 + in_channels: 3 + num_res_blocks: 2 + out_ch: 3 + resolution: 256 + z_channels: 4 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 50 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + params: + sigma_max: 500.0 + guider_config: + target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider + params: + max_scale: 1.5 + min_scale: 1.5 + num_frames: ${N_FRAMES} + num_views: ${N_VIEW} + additional_cond_keys: [ cond_view, cond_motion ] diff --git a/scripts/sampling/simple_video_sample_sp4d.py b/scripts/sampling/simple_video_sample_sp4d.py new file mode 100755 index 0000000..7f5a30f --- /dev/null +++ b/scripts/sampling/simple_video_sample_sp4d.py @@ -0,0 +1,198 @@ +import os +import sys +from glob import glob +from typing import List, Optional + +from tqdm import tqdm + +sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../"))) +import numpy as np +import torch +from fire import Fire +from scripts.demo.sv4d_helpers import ( + load_model, + preprocess_video, + read_video, + run_img2vid, + save_video, +) +from sgm.modules.encoders.modules import 
VideoPredictionEmbedderWithEncoder
+
+sp4d_configs = {
+    "sp4d": {
+        "T": 4,  # number of frames per sample
+        "V": 12,  # number of views per sample
+        "model_config": "scripts/sampling/configs/sp4d.yaml",
+        "version_dict": {
+            "T": 48,
+            "options": {
+                "discretization": 1,
+                "cfg": 3.0,
+                "min_cfg": 1.5,
+                "num_views": 12,
+                "sigma_min": 0.002,
+                "sigma_max": 700.0,
+                "rho": 7.0,
+                "guider": 2,
+                "force_uc_zero_embeddings": [
+                    "cond_frames",
+                    "cond_frames_without_noise",
+                    "cond_view",
+                    "cond_motion",
+                ],
+                "additional_guider_kwargs": {
+                    "additional_cond_keys": ["cond_view", "cond_motion"]
+                },
+            },
+        },
+    },
+}
+
+
+def sample(
+    input_path: str = "assets/sv4d_videos/camel.gif",  # Can either be image file or folder with image files
+    model_path: Optional[str] = "checkpoints/sp4d.safetensors",
+    output_folder: Optional[str] = "outputs",
+    num_steps: Optional[int] = 50,
+    img_size: int = 576,  # image resolution
+    n_frames: int = 4,  # number of input and output video frames
+    seed: int = 23,
+    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
+    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
+    device: str = "cuda",
+    elevations_deg: Optional[List[float]] = 0.0,
+    azimuths_deg: Optional[List[float]] = None,
+    image_frame_ratio: Optional[float] = 0.9,
+    verbose: Optional[bool] = False,
+    remove_bg: bool = False,
+):
+    """
+    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
+    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
+    """
+    # Set model config
+    assert os.path.basename(model_path) in [
+        "sp4d.safetensors",
+    ]
+    sp4d_model = os.path.splitext(os.path.basename(model_path))[0]
+    config = sp4d_configs[sp4d_model]
+    print(sp4d_model, config)
+    T = config["T"]
+    V = config["V"]
+    model_config = config["model_config"]
+    version_dict = config["version_dict"]
+    F = 8  # vae factor to downsize image->latent
+    C = 4
+    H, W = img_size, img_size
+    n_views = V + 1  # number of output video views (1 input view + 12 novel views)
+    subsampled_views = np.arange(n_views)
+    version_dict["H"] = H
+    version_dict["W"] = W
+    version_dict["C"] = C
+    version_dict["f"] = F
+    version_dict["options"]["num_steps"] = num_steps
+
+    torch.manual_seed(seed)
+    output_folder = os.path.join(output_folder, sp4d_model)
+    os.makedirs(output_folder, exist_ok=True)
+
+    # Read input video frames i.e. images at view 0
+    print(f"Reading {input_path}")
+    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // (n_frames + 1)
+    processed_input_path = preprocess_video(
+        input_path,
+        remove_bg=remove_bg,
+        n_frames=n_frames,
+        W=W,
+        H=H,
+        output_folder=output_folder,
+        image_frame_ratio=image_frame_ratio,
+        base_count=base_count,
+    )
+    images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)
+    images_t0 = torch.zeros(n_views, 3, H, W).float().to(device)
+
+    # Get camera viewpoints
+    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
+        elevations_deg = [elevations_deg] * n_views
+    assert (
+        len(elevations_deg) == n_views
+    ), f"Please provide 1 value, or a list of {n_views} values for elevations_deg! Given {len(elevations_deg)}"
+    if azimuths_deg is None:
+        azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360
+    assert (
+        len(azimuths_deg) == n_views
+    ), f"Please provide a list of {n_views} values for azimuths_deg! 
Given {len(azimuths_deg)}" + polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg]) + azimuths_rad = np.array( + [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg] + ) + + # Initialize image matrix + img_matrix = [[None] * n_views for _ in range(n_frames)] + for i, v in enumerate(subsampled_views): + img_matrix[0][i] = images_t0[v].unsqueeze(0) + for t in range(n_frames): + img_matrix[t][0] = images_v0[t] + + # Load SV4D++ model + model, _ = load_model( + model_config, + device, + version_dict["T"], + num_steps, + verbose, + model_path, + ) + model.en_and_decode_n_samples_a_time = decoding_t + for emb in model.conditioner.embedders: + if isinstance(emb, VideoPredictionEmbedderWithEncoder): + emb.en_and_decode_n_samples_a_time = encoding_t + + # Sampling novel-view videos + v0 = 0 + view_indices = np.arange(V) + 1 + t0_list = range(0, n_frames - T + 1, T - 1) + for t0 in tqdm(t0_list): + if t0 + T > n_frames: + t0 = n_frames - T + frame_indices = t0 + np.arange(T) + print(f"Sampling frames {frame_indices}") + image = img_matrix[t0][v0] + cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0) + cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0) + polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + polars = (polars - polars_rad[v0] + torch.pi / 2) % (torch.pi * 2) + azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) + samples = run_img2vid( + version_dict, + model, + image, + seed, + polars, + azims, + cond_motion, + cond_view, + decoding_t, + cond_mv=False, + part_maps=True, + ) + samples = samples.view(T, V, 3, H, -1) + + for i, t in enumerate(frame_indices): + for j, v in enumerate(view_indices): + img_matrix[t][v] = samples[i, j][None] * 2 - 1 + + # Save output videos + for t in frame_indices: + vid_file = os.path.join(output_folder, f"{base_count:06d}_{t:03d}.mp4") + print(f"Saving {vid_file}") + save_video( + vid_file, + [img_matrix[t][v] for v in range(1, n_views) if img_matrix[t][v] is not None], + ) + + +if __name__ == "__main__": + Fire(sample) diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py index 2f3efd3..73294c6 100644 --- a/sgm/models/diffusion.py +++ b/sgm/models/diffusion.py @@ -38,6 +38,7 @@ class DiffusionEngine(pl.LightningModule): no_cond_log: bool = False, compile_model: bool = False, en_and_decode_n_samples_a_time: Optional[int] = None, + dual_concat: bool = False, ): super().__init__() self.log_keys = log_keys @@ -47,7 +48,7 @@ class DiffusionEngine(pl.LightningModule): ) model = instantiate_from_config(network_config) self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( - model, compile_model=compile_model + model, compile_model=compile_model, dual_concat=dual_concat ) self.denoiser = instantiate_from_config(denoiser_config) diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py index 4cf9d92..574a89f 100644 --- a/sgm/modules/diffusionmodules/model.py +++ b/sgm/modules/diffusionmodules/model.py @@ -746,3 +746,170 @@ class Decoder(nn.Module): if self.tanh_out: h = torch.tanh(h) return h + + +class DecoderDual(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if 
use_linear_attn:
+            attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.give_pre_end = give_pre_end
+        self.tanh_out = tanh_out
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        in_ch_mult = (1,) + tuple(ch_mult)
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+        logpy.info(
+            "Working with z of shape {} = {} dimensions.".format(
+                self.z_shape, np.prod(self.z_shape)
+            )
+        )
+
+        make_attn_cls = self._make_attn()
+        make_resblock_cls = self._make_resblock()
+        make_conv_cls = self._make_conv()
+
+        # z to block_in (processes a single latent)
+        self.conv_in = torch.nn.Conv2d(
+            z_channels, block_in, kernel_size=3, stride=1, padding=1
+        )
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = make_resblock_cls(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+        self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
+        self.mid.block_2 = make_resblock_cls(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(
+                    make_resblock_cls(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn_cls(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = make_conv_cls(
+            block_in, out_ch, kernel_size=3, stride=1, padding=1
+        )
+
+    def _make_attn(self) -> Callable:
+        return make_attn
+
+    def _make_resblock(self) -> Callable:
+        return ResnetBlock
+
+    def _make_conv(self) -> Callable:
+        return torch.nn.Conv2d
+
+    def get_last_layer(self, **kwargs):
+        return self.conv_out.weight
+
+    def forward(self, z, **kwargs):
+        """
+        Expects z of shape (B, 2 * z_channels, H, W):
+        - the first half of the channels is the first latent, the second half is the second latent
+        - the two halves are decoded separately and the results are concatenated along the width dimension
+        """
+        # Sanity check: the input must have 2 * z_channels channels
+        assert (
+            z.shape[1] == 2 * self.z_shape[1]
+        ), f"Expected {2 * self.z_shape[1]} channels, got {z.shape[1]}"
+
+        # Split the latent into two parts
+        z1, z2 = torch.chunk(z, 2, dim=1)  # split along the channel dimension (C)
+
+        # Decode each latent separately
+        img1 = self.decode_single(z1, **kwargs)
+        img2 = self.decode_single(z2, **kwargs)
+
+        # Concatenate the two decoded images side by side
+        output = torch.cat([img1, img2], dim=-1)  # cat along the width dimension
+
+        return output
+
+    def decode_single(self, z, **kwargs):
+        """Decode a single latent into one image."""
+        self.last_z_shape = z.shape
+
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h, None, **kwargs)
+        h = self.mid.attn_1(h, **kwargs)
+        h = self.mid.block_2(h, None, **kwargs)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](h, None, **kwargs)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h, **kwargs)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # 
end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h, **kwargs) + if self.tanh_out: + h = torch.tanh(h) + + return h + \ No newline at end of file diff --git a/sgm/modules/diffusionmodules/video_model.py b/sgm/modules/diffusionmodules/video_model.py index 9cee066..ab6326f 100644 --- a/sgm/modules/diffusionmodules/video_model.py +++ b/sgm/modules/diffusionmodules/video_model.py @@ -13,6 +13,7 @@ from ...modules.spacetime_attention import ( from ...util import default from .util import AlphaBlender, get_alpha +import torch class VideoResBlock(ResBlock): def __init__( @@ -1252,3 +1253,157 @@ class SpatialUNetModelWithTime(nn.Module): ) h = h.type(x.dtype) return self.out(h) + + +class CrossNetworkLayer(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.fusion_conv = nn.Sequential( + nn.Conv2d(feature_dim * 2, feature_dim, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(feature_dim, feature_dim, kernel_size=1), + ) + + def forward(self, h1: torch.Tensor, h2: torch.Tensor): + """ + h1, h2: (B, C, H, W) + return: (out1, out2), (B, C, H, W) + """ + fused_input = torch.cat([h1, h2], dim=1) # (B, 2C, H, W) + fused_output = self.fusion_conv(fused_input) # (B, C, H, W) + out1 = fused_output + h1 + out2 = fused_output + h2 + return out1, out2 + + +class DualSpatialUNetWithCrossComm(nn.Module): + def __init__(self, unet_config): + super().__init__() + self.num_classes = unet_config["num_classes"] + self.model_channels = unet_config["model_channels"] + + self.net1 = SpatialUNetModelWithTime(**unet_config) + self.net2 = SpatialUNetModelWithTime(**unet_config) + + self.input_cross_layers = nn.ModuleList() + for block in self.net1.input_blocks: + out_ch = self._get_block_out_channels(block) + self.input_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch)) + + middle_out_ch = self._get_block_out_channels(self.net1.middle_block) + self.middle_cross = CrossNetworkLayer(feature_dim=middle_out_ch) + + self.output_cross_layers = nn.ModuleList() + for block in self.net1.output_blocks: + out_ch = self._get_block_out_channels(block) + self.output_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch)) + + def _get_block_out_channels(self, block: nn.Module) -> int: + mod_list = list(block.children()) + for m in reversed(mod_list): + if hasattr(m, "out_channels"): + return m.out_channels + + if isinstance( + m, + (SpatialTransformer, PostHocSpatialTransformerWithTimeMixingAndMotion), + ): + return m.in_channels + + if isinstance(m, nn.Conv2d): + return m.out_channels + + raise ValueError(f"Cannot determine out_channels from block: {block}") + + def forward( + self, + x: th.Tensor, + timesteps: th.Tensor, + context: Optional[th.Tensor] = None, + y: Optional[th.Tensor] = None, + cam: Optional[th.Tensor] = None, + time_context: Optional[th.Tensor] = None, + num_video_frames: Optional[int] = None, + image_only_indicator: Optional[th.Tensor] = None, + cond_view: Optional[th.Tensor] = None, + cond_motion: Optional[th.Tensor] = None, + time_step: Optional[int] = None, + ): + + # ============ encoder ============ + h1, h2 = x[:, : x.shape[1] // 2], x[:, x.shape[1] // 2 :] + + encoder_feats1 = [] + encoder_feats2 = [] + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional -> no, relax this TODO" + + t_emb = timestep_embedding( + timesteps, self.model_channels, repeat_only=False + ) # 21 x 320 + + emb = self.net1.time_embed(t_emb) # 21 x 1280 + time = str(timesteps[0].data.cpu().numpy()) + + 
if self.num_classes is not None: + assert y.shape[0] == h1.shape[0] + emb = emb + self.net1.label_emb(y) # 21 x 1280 + + filtered_args = { + "emb": emb, + "context": context, + "cam": cam, + "cond_view": cond_view, + "cond_motion": cond_motion, + "time_context": time_context, + "num_video_frames": num_video_frames, + "image_only_indicator": image_only_indicator, + "time_step": time_step, + } + + for i, (block1, block2) in enumerate( + zip(self.net1.input_blocks, self.net2.input_blocks) + ): + h1 = block1(h1, name="encoder_{}_{}".format(time, i), **filtered_args) + h2 = block2(h2, name="encoder_{}_{}".format(time, i), **filtered_args) + + # cross + h1, h2 = self.input_cross_layers[i](h1, h2) + + encoder_feats1.append(h1) + encoder_feats2.append(h2) + + # ============ middle block ============ + h1 = self.net1.middle_block( + h1, name="middle_{}_0".format(time, i), **filtered_args + ) + h2 = self.net2.middle_block( + h2, name="middle_{}_0".format(time, i), **filtered_args + ) + + # cross + h1, h2 = self.middle_cross(h1, h2) + + # ============ decoder ============ + for i, (block1, block2) in enumerate( + zip(self.net1.output_blocks, self.net2.output_blocks) + ): + skip1 = encoder_feats1.pop() + skip2 = encoder_feats2.pop() + h1 = torch.cat([h1, skip1], dim=1) + h2 = torch.cat([h2, skip2], dim=1) + + h1 = block1(h1, name="decoder_{}_{}".format(time, i), **filtered_args) + h2 = block2(h2, name="decoder_{}_{}".format(time, i), **filtered_args) + + # cross + h1, h2 = self.output_cross_layers[i](h1, h2) + + # ============ output ============ + out1 = self.net1.out(h1) # shape: (B, out_channels, H, W) + out2 = self.net2.out(h2) # same shape + out = torch.cat([out1, out2], dim=1) + + return out \ No newline at end of file diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py index 23c7d07..385545c 100644 --- a/sgm/modules/diffusionmodules/wrappers.py +++ b/sgm/modules/diffusionmodules/wrappers.py @@ -6,7 +6,7 @@ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" class IdentityWrapper(nn.Module): - def __init__(self, diffusion_model, compile_model: bool = False): + def __init__(self, diffusion_model, compile_model: bool = False, dual_concat: bool = False): super().__init__() compile = ( torch.compile @@ -15,6 +15,7 @@ class IdentityWrapper(nn.Module): else lambda x: x ) self.diffusion_model = compile(diffusion_model) + self.dual_concat = dual_concat def forward(self, *args, **kwargs): return self.diffusion_model(*args, **kwargs) @@ -24,7 +25,14 @@ class OpenAIWrapper(IdentityWrapper): def forward( self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs ) -> torch.Tensor: - x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) + if self.dual_concat: + x_1 = x[:, : x.shape[1] // 2] + x_2 = x[:, x.shape[1] // 2 :] + x_1 = torch.cat((x_1, c.get("concat", torch.Tensor([]).type_as(x_1))), dim=1) + x_2 = torch.cat((x_2, c.get("concat", torch.Tensor([]).type_as(x_2))), dim=1) + x = torch.cat((x_1, x_2), dim=1) + else: + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) if "cond_view" in c: return self.diffusion_model( x,