Training

import torch
from torch.optim import AdamW
from widit.models import WiDiT

device = "cuda" if torch.cuda.is_available() else "cpu"

model = WiDiT(
    spatial_dim=2,
    in_channels=3,
    hidden_size=256,
    depth=6,
    num_heads=8,
    patch_size=2,
    window_size=8,
    learn_sigma=True,
    use_conditioning=True,
).to(device)

opt = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

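for step in range(100):
    # Random stand-in batch; replace with real data.
    x    = torch.randn(8, 3, 128, 96, device=device)
    cond = torch.randn_like(x)                                   # conditioning image, same shape as x
    t    = torch.randint(0, 1000, (x.shape[0],), device=device)  # one timestep per sample

    y = model(x, t, conditioned=cond)          # (N, 6, H, W) here (mean+sigma for C=3)
    target = torch.randn_like(y)               # placeholder target; use your real training target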

    loss = torch.nn.functional.mse_loss(y, target)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
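
With learn_sigma=True the output has 2*C channels, so a real training loss usually needs the two halves separately. The sketch below shows one way to split them and compute an MSE on the first half only. It runs on a random stand-in tensor rather than a WiDiT model. The helper name split_prediction is hypothetical, and the channel order (prediction first, sigma second) is an assumption to verify against the model code.

```python
import torch
import torch.nn.functional as F


def split_prediction(y: torch.Tensor, in_channels: int):
    """Split a (N, 2*C, ...) output into two (N, C, ...) halves along dim 1.

    Assumption: the first C channels hold the mean/noise prediction and the
    last C channels hold the sigma-related values.
    """
    if y.shape[1] != 2 * in_channels:
        raise ValueError(f"expected {2 * in_channels} channels, got {y.shape[1]}")
    pred, sigma = torch.split(y, in_channels, dim=1)
    return pred, sigma


# Stand-in for a model output with learn_sigma=True and C=3.
y = torch.randn(8, 6, 128, 96, requires_grad=True)
noise = torch.randn(8, 3, 128, 96)  # stand-in regression target

pred, sigma = split_prediction(y, in_channels=3)
loss = F.mse_loss(pred, noise)  # the sigma half is left out of this loss
loss.backward()
```

Because only the first half enters the loss here, the sigma channels receive zero gradient. A full diffusion objective would normally add a separate term for them.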

Tips & Gotchas

  • Patch size equality in unpatchify: the unpatchify path currently requires the same patch size along every axis (e.g., patch_size=2 or (2,2,2)). Mixed per-axis patch sizes are not yet supported for output reconstruction.

  • Token grid divisibility: make sure every spatial dimension is divisible by patch_size. Window attention pads internally to complete windows and crops back, but patch embedding is stride-based, so it does not get the same automatic padding.

  • Timestep optional: pass timestep=None to run the model without diffusion conditioning (the AdaLN defaults then reduce to a vanilla transformer residual path).

  • Conditioning toggle: if you don’t have a conditioning image, set use_conditioning=False and call model(x, timestep).
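
The divisibility requirement can be checked before a batch reaches the model. The following is a minimal sketch in plain Python. The helper names padding_needed and check_divisible are hypothetical and not part of widit; they only compute how much each spatial axis would need to grow to reach the next multiple of patch_size.

```python
def _per_axis(patch_size, ndim):
    """Expand an int patch size to one value per spatial axis."""
    if isinstance(patch_size, int):
        return (patch_size,) * ndim
    if len(patch_size) != ndim:
        raise ValueError("patch_size must have one entry per spatial axis")
    return tuple(patch_size)


def padding_needed(spatial_shape, patch_size):
    """Return the extra size each axis needs to become a multiple of patch_size."""
    patches = _per_axis(patch_size, len(spatial_shape))
    return tuple((-size) % p for size, p in zip(spatial_shape, patches))


def check_divisible(spatial_shape, patch_size):
    """Raise ValueError if any spatial axis is not divisible by patch_size."""
    pad = padding_needed(spatial_shape, patch_size)
    if any(pad):
        raise ValueError(
            f"spatial shape {tuple(spatial_shape)} is not divisible by "
            f"patch_size={patch_size}; pad by {pad} first"
        )


check_divisible((128, 96), 2)            # passes silently
print(padding_needed((130, 97), 4))      # (2, 3)
```

For a tensor x of shape (N, C, H, W), the spatial shape is x.shape[2:]. The returned amounts can be passed to whatever padding strategy suits the data before calling the model.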

Reference Shapes

2D

  • Input: (N, C, H, W)

  • Output: (N, 2*C, H, W) if learn_sigma=True, else (N, C, H, W)

3D

  • Input: (N, C, D, H, W)

  • Output: (N, 2*C, D, H, W) if learn_sigma=True, else (N, C, D, H, W)
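
The shape rules above can be written as a small helper for assertions in tests or data pipelines. This is a sketch in plain Python. The function name expected_output_shape is hypothetical, and it encodes only the rules listed in this section.

```python
def expected_output_shape(input_shape, learn_sigma=True):
    """Return the output shape for a 2D (N, C, H, W) or 3D (N, C, D, H, W) input.

    The channel count doubles when learn_sigma is True; the batch and
    spatial dimensions are unchanged.
    """
    if len(input_shape) not in (4, 5):
        raise ValueError("expected a 4D (2D data) or 5D (3D data) input shape")
    n, c, *spatial = input_shape
    out_channels = 2 * c if learn_sigma else c
    return (n, out_channels, *spatial)


print(expected_output_shape((8, 3, 128, 96)))                       # (8, 6, 128, 96)
print(expected_output_shape((2, 1, 32, 64, 64), learn_sigma=False)) # (2, 1, 32, 64, 64)
```

A check such as tuple(y.shape) == expected_output_shape(tuple(x.shape)) after the forward pass catches a mismatched learn_sigma setting early.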