API Overview
```python
WiDiT(
    *,
    spatial_dim: int,                        # 2 (images) or 3 (volumes)
    patch_size: int | Sequence[int] = 2,     # per-axis tuple allowed
    in_channels: int = 1,
    hidden_size: int = 768,                  # even; divisible by num_heads
    depth: int = 12,
    num_heads: int = 12,
    window_size: int | Sequence[int] = 8,    # per-axis tuple allowed
    mlp_ratio: float = 4.0,
    learn_sigma: bool = True,
    use_conditioning: bool = True,           # expects a second image unless set to False
)

forward(
    input_tensor: torch.Tensor,              # (N, C, *spatial)
    timestep: torch.Tensor | None = None,    # (N,) or None
    *,                                       # keyword-only from here
    conditioned: torch.Tensor | None = None, # (N, C, *spatial) if use_conditioning=True
) -> torch.Tensor                            # (N, out_channels, *spatial)
```
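The output-shape contract above can be sketched without the library itself. This is a minimal pure-Python illustration (the helper names `out_channels` and `expected_output_shape` are hypothetical, not part of the `widit` API):

```python
def out_channels(in_channels: int, learn_sigma: bool) -> int:
    # learn_sigma=True doubles the channels: mean and sigma are predicted jointly
    return 2 * in_channels if learn_sigma else in_channels

def expected_output_shape(batch, in_channels, spatial, learn_sigma=True):
    # forward() returns (N, out_channels, *spatial); spatial extents are unchanged
    return (batch, out_channels(in_channels, learn_sigma), *spatial)
```

For example, a batch of two single-channel 64×64 images with `learn_sigma=True` yields an output of shape `(2, 2, 64, 64)`.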
Shapes & contracts
- `*spatial` is `(H, W)` for 2D and `(D, H, W)` for 3D.
- `patch_size` must evenly divide each spatial dimension.
- `window_size` can be an int or a per-axis tuple; internal padding ensures full windows (the padding is removed before returning).
- `hidden_size` must be even (it is split across the two patch embedders when `use_conditioning=True`) and divisible by `num_heads`.
- If `learn_sigma=True`, output channels = `2 * in_channels` (mean + sigma style).
- If `use_conditioning=True`, you must pass `conditioned=...` to `forward`. If `use_conditioning=False`, passing `conditioned` will raise an assertion.
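The window-padding contract can be made concrete with a short sketch. This is an illustration of the arithmetic only, under the assumption that each axis is padded up to the next multiple of its window extent; the helper names are hypothetical:

```python
def padded_extent(size: int, window: int) -> int:
    # Round size up to the next multiple of window: the model pads internally
    # so every window is full, then crops the padding before returning.
    return -(-size // window) * window  # ceil division via negation

def padding_per_axis(spatial, window_size):
    # window_size may be a single int (broadcast to all axes) or a per-axis tuple
    windows = (window_size,) * len(spatial) if isinstance(window_size, int) else tuple(window_size)
    return tuple(padded_extent(s, w) - s for s, w in zip(spatial, windows))
```

For instance, a `(60, 64)` image with `window_size=8` is padded by `(4, 0)` to reach `(64, 64)`; note that `patch_size` divisibility is still checked on the *original* extents.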
Conditioning
`timestep` is optional: pass `None` to disable AdaLN conditioning (the blocks reduce to standard LN + residual). If provided, the model uses `widit.timesteps.TimestepEmbedder` to produce a per-sample vector projected to the token dimension.
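As a rough sketch of what a sinusoidal timestep embedding computes, here is a pure-Python version following the common DiT-style layout (cosines in the first half, sines in the second, at geometrically spaced frequencies). The exact layout inside `widit.timesteps.TimestepEmbedder` may differ, and the subsequent MLP projection is omitted:

```python
import math

def sinusoidal_embedding(t: float, dim: int, max_period: float = 10000.0):
    # Geometrically spaced frequencies from 1 down to ~1/max_period,
    # then one cosine and one sine feature per frequency.
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]
```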
Building Blocks
These are used internally, but you can also import them for custom stacks.
- `widit.blocks.WiDiTBlock` – N-D windowed MSA + MLP with AdaLN-Zero
- `widit.blocks.WiDiTFinalLayer` – final projection head with AdaLN-Zero
- `widit.patch.PatchEmbed` – unified 2D/3D patch embedding (with `init_weights()`)
- `widit.timesteps.TimestepEmbedder` – sinusoidal → MLP conditioning (with `init_weights()`)
All of the above expose `init_weights()` so the model can initialize components cleanly (AdaLN-Zero policy for blocks & head; Xavier for projections; Normal for timestep MLP weights).
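To illustrate why the AdaLN-Zero policy matters, here is a minimal pure-Python sketch of one modulated residual step (scalar shift/scale/gate for simplicity; in the real blocks these are per-channel vectors regressed from the conditioning embedding). The function name is hypothetical:

```python
def adaln_zero_step(x, normed, shift, scale, gate):
    # AdaLN-Zero: modulate the normalized activations with a learned
    # shift/scale, then gate the residual branch before adding it back.
    modulated = [(1.0 + scale) * h + shift for h in normed]
    branch = modulated  # stand-in for the attention/MLP applied to the tokens
    return [xi + gate * bi for xi, bi in zip(x, branch)]
```

Because the gate is zero-initialized, each block is the identity at initialization (the residual branch contributes nothing), which is what lets the full stack start training stably.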