Quickstart

WiDiT is a SwinIR-style DiT backbone that unifies 2D images and 3D volumes with N-D windowed attention, optional Swin shifts, and AdaLN-Zero conditioning.
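AdaLN-Zero is the conditioning scheme introduced with DiT. A small network maps the conditioning embedding (for example, the timestep) to per-feature shift, scale, and gate values. The normalized activations are modulated as x * (1 + scale) + shift, and the residual branch is multiplied by the gate. Because that network's final projection is zero-initialized, every block starts out as the identity. The pure-Python sketch below shows this arithmetic on a single feature vector. The helper names are hypothetical, and how WiDiT produces or lays out these parameters internally is an assumption, not taken from its source.

```python
import math

# Hypothetical helpers illustrating AdaLN-Zero arithmetic; not WiDiT's actual code.

def layer_norm(x, eps=1e-6):
    # Normalize a 1-D feature vector to zero mean and unit variance (no affine params).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def modulate(x, shift, scale):
    # AdaLN modulation applied per feature: x * (1 + scale) + shift.
    return [v * (1.0 + s) + b for v, b, s in zip(x, shift, scale)]

def adaln_zero_residual(x, sublayer, shift, scale, gate):
    # Residual branch scaled by a per-feature gate; with gate == 0 the block is the identity.
    h = sublayer(modulate(layer_norm(x), shift, scale))
    return [xi + g * hi for xi, g, hi in zip(x, gate, h)]

if __name__ == "__main__":
    x = [1.0, 2.0, 3.0, 4.0]
    zeros = [0.0] * len(x)
    # At initialization shift, scale, and gate are all zero, so the output equals the input.
    print(adaln_zero_residual(x, lambda h: [2 * v for v in h], zeros, zeros, zeros))
```

In a full model, shift, scale, and gate would be produced separately for the attention and MLP sub-blocks of every transformer block.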

  • Single model class: widit.models.WiDiT

  • Also includes a configurable widit.models.Unet for 2D/3D U-Net baselines

  • Optional timestep conditioning (pass timestep=None if unused)

  • Shared blocks for 2D/3D via N-D window partitioning

  • Presets for quick experiments in both 2D and 3D
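The N-D window partitioning behind the shared 2D/3D blocks can be illustrated on plain coordinates. The sketch below groups the tokens of a 2D or 3D grid into non-overlapping windows and optionally applies a Swin-style cyclic shift first. It is a hypothetical, dependency-free illustration, not WiDiT's tensor implementation: it assumes each grid size is a multiple of the window size, and it only shows the grouping. Real Swin attention additionally masks tokens that wrap around the border after a shift so they do not attend to each other.

```python
from itertools import product

def window_partition(grid_shape, window_size, shift=None):
    """Group token coordinates of an N-D grid into non-overlapping windows.

    grid_shape:  tokens per spatial dim, e.g. (H/p, W/p) or (D/p, H/p, W/p)
    window_size: window extent per dim (same length as grid_shape)
    shift:       optional cyclic shift per dim (Swin-style); None means no shift
    Returns a list of windows, each a list of original token coordinates.
    """
    n = len(grid_shape)
    if len(window_size) != n:
        raise ValueError("window_size must have one entry per spatial dim")
    for g, w in zip(grid_shape, window_size):
        if g % w:
            raise ValueError(f"grid size {g} is not divisible by window size {w}")
    shift = tuple(shift) if shift is not None else (0,) * n
    counts = [g // w for g, w in zip(grid_shape, window_size)]
    windows = []
    for win_idx in product(*(range(c) for c in counts)):
        members = []
        for offset in product(*(range(w) for w in window_size)):
            # Position in the shifted grid mapped back to original coords (cyclic roll by -shift).
            coord = tuple(
                (i * w + o + s) % g
                for i, w, o, s, g in zip(win_idx, window_size, offset, shift, grid_shape)
            )
            members.append(coord)
        windows.append(members)
    return windows

if __name__ == "__main__":
    # The same function handles 2D and 3D grids; only the tuple lengths change.
    print(window_partition((4, 4), (2, 2), shift=(1, 1))[0])
    print(len(window_partition((4, 4, 2), (2, 2, 2))))
```

Writing the partition in terms of tuples of arbitrary length is what lets one code path serve both images and volumes.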

Installation

Install using pip:

pip install widit

Or install directly from GitHub:

pip install git+https://github.com/rbturnbull/widit.git

WiDiT depends on PyTorch (torch).
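To confirm that both WiDiT and its PyTorch dependency are available in the current environment, a small standard-library check such as the following can help. It is not part of WiDiT; the helper name check_install is hypothetical, and it assumes the import names torch and widit match the distribution names used by pip.

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version

def check_install(packages=("torch", "widit")):
    """Map each package to its installed version, "unknown", or None if not importable."""
    report = {}
    for name in packages:
        if importlib.util.find_spec(name) is None:
            report[name] = None
            continue
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            # Importable but without distribution metadata (e.g. a source checkout on sys.path).
            report[name] = "unknown"
    return report

if __name__ == "__main__":
    for pkg, ver in check_install().items():
        print(f"{pkg}: {ver or 'NOT INSTALLED'}")
```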

Warning

WiDiT is currently in the alpha testing phase; more updates are coming soon.

Quick Start (2D)

import torch
from widit.models import WiDiT

# Example: 2D RGB input & conditioning (e.g., low-res guidance)
N, C, H, W = 2, 3, 128, 96
x      = torch.randn(N, C, H, W)
cond   = torch.randn_like(x)
t      = torch.randint(0, 1000, (N,), dtype=torch.long)  # optional

model = WiDiT(
    spatial_dim=2,
    patch_size=2,             # must divide H and W
    in_channels=C,
    hidden_size=256,          # must be divisible by num_heads and even
    depth=6,
    num_heads=8,
    window_size=8,            # can be int or (wh, ww)
    mlp_ratio=4.0,
    learn_sigma=True,         # output channels = 2*C if True
    use_conditioning=True,    # expect a conditioning image
)

# Call signature: forward(input, timestep=None, *, conditioned=None)
# (`conditioned` is keyword-only)
y = model(x, t, conditioned=cond)     # (N, 2*C, H, W) if learn_sigma=True
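The comments above state two shape constraints (patch_size must divide every spatial size; hidden_size must be even and divisible by num_heads) and how learn_sigma affects the output channels. A hypothetical pre-flight helper that checks these before building a model might look like this. It uses only the rules stated in this README and does not call WiDiT; any further checks WiDiT performs internally (for example on window_size) are not modeled.

```python
def check_config(input_shape, *, patch_size, hidden_size, num_heads, learn_sigma):
    """Check the constraints noted in the quick-start comments.

    input_shape is (N, C, *spatial), e.g. (N, C, H, W) or (N, C, D, H, W).
    Returns (token_grid, expected_output_shape).
    """
    n_batch, channels, *spatial = input_shape
    for size in spatial:
        if size % patch_size:
            raise ValueError(f"patch_size={patch_size} must divide spatial size {size}")
    if hidden_size % num_heads or hidden_size % 2:
        raise ValueError("hidden_size must be even and divisible by num_heads")
    # Each patch of patch_size along every spatial dim becomes one token.
    token_grid = tuple(size // patch_size for size in spatial)
    # Per the quick-start comments, learn_sigma=True doubles the output channels.
    out_channels = 2 * channels if learn_sigma else channels
    return token_grid, (n_batch, out_channels, *spatial)

if __name__ == "__main__":
    # The 2D example above: N, C, H, W = 2, 3, 128, 96
    print(check_config((2, 3, 128, 96), patch_size=2, hidden_size=256,
                       num_heads=8, learn_sigma=True))
```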

Quick Start (3D)

import torch
from widit.models import WiDiT

# Example: 3D single-channel volumes
N, C, D, H, W = 1, 1, 64, 64, 48
x    = torch.randn(N, C, D, H, W)
cond = torch.randn_like(x)

model = WiDiT(
    spatial_dim=3,
    patch_size=2,             # must divide D/H/W
    in_channels=C,
    hidden_size=256,
    depth=4,
    num_heads=8,
    window_size=(4, 4, 4),    # can be int or (wd, wh, ww)
    mlp_ratio=4.0,
    learn_sigma=False,        # output channels = C if False
    use_conditioning=True,
)

y = model(x, timestep=None, conditioned=cond)  # (N, C, D, H, W)
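Since window_size accepts either an int or a per-dimension tuple, it helps to see how the window layout works out on the token grid. The sketch below normalizes an int to a tuple and counts windows per dimension. The helper names are hypothetical, and the padded-grid column reflects an assumption: Swin-style implementations commonly pad a grid up to a multiple of the window size, but whether WiDiT pads or requires exact divisibility is not stated here.

```python
def to_ntuple(value, n):
    """Accept an int or a length-n sequence and return an n-tuple of ints."""
    if isinstance(value, int):
        return (value,) * n
    out = tuple(int(v) for v in value)
    if len(out) != n:
        raise ValueError(f"expected {n} values, got {len(out)}")
    return out

def window_layout(spatial_shape, patch_size, window_size):
    """Return (token_grid, windows_per_dim, padded_grid) for an N-D input.

    padded_grid assumes the grid is padded up to a multiple of the window size.
    """
    n = len(spatial_shape)
    patch = to_ntuple(patch_size, n)
    window = to_ntuple(window_size, n)
    grid = tuple(s // p for s, p in zip(spatial_shape, patch))
    # Integer ceiling division: number of windows needed to cover each dim.
    counts = tuple(-(-g // w) for g, w in zip(grid, window))
    padded = tuple(c * w for c, w in zip(counts, window))
    return grid, counts, padded

if __name__ == "__main__":
    # The 3D example above: D, H, W = 64, 64, 48 with patch_size=2 and window (4, 4, 4).
    print(window_layout((64, 64, 48), 2, (4, 4, 4)))
```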

Unconditioned Path (no conditioning image)

import torch
from widit.models import WiDiT

N, C, H, W = 2, 3, 128, 96
x = torch.randn(N, C, H, W)
t = torch.randint(0, 1000, (N,))

model = WiDiT(
    spatial_dim=2,
    in_channels=C,
    hidden_size=256,
    depth=4,
    num_heads=8,
    patch_size=2,
    window_size=8,
    learn_sigma=True,
    use_conditioning=False,       # <-- no conditioning image expected
)

# Do NOT pass `conditioned` when use_conditioning=False
y = model(x, t)  # (N, 2*C, H, W)
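When one training script drives both conditioned and unconditioned models, a small wrapper can enforce the rule above: pass conditioned only when use_conditioning=True. The helper below is hypothetical and duck-typed; it takes use_conditioning as an explicit argument rather than assuming the model exposes it as an attribute.

```python
from typing import Any, Callable, Optional

def call_widit(model: Callable[..., Any], x: Any, timestep: Any = None,
               cond: Optional[Any] = None, *, use_conditioning: bool) -> Any:
    """Call a WiDiT-style model, passing `conditioned` only when it is expected."""
    if use_conditioning:
        if cond is None:
            raise ValueError("use_conditioning=True requires a conditioning input")
        return model(x, timestep, conditioned=cond)
    if cond is not None:
        raise ValueError("use_conditioning=False: do not pass a conditioning input")
    return model(x, timestep)

if __name__ == "__main__":
    # A stand-in model that just echoes its arguments.
    fake_model = lambda x, t=None, **kw: (x, t, kw)
    print(call_widit(fake_model, "x", "t", use_conditioning=False))
    print(call_widit(fake_model, "x", None, "cond", use_conditioning=True))
```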

Presets

Presets provide ready-made configurations for common model sizes (2D & 3D), all using patch_size=2 and Swin-style window attention:

from widit.models import PRESETS
import torch

# 2D: B, M, L, XL
model_2d = PRESETS["WiDiT2D-L"](in_channels=3, learn_sigma=True)

# 3D: B, M, L, XL
model_3d = PRESETS["WiDiT3D-M"](in_channels=1, learn_sigma=False)

# Example inputs
x2d = torch.randn(1, 3, 64, 48)
c2d = torch.randn_like(x2d)
t2d = torch.randint(0, 1000, (1,))

x3d = torch.randn(1, 1, 32, 32, 24)
c3d = torch.randn_like(x3d)
t3d = torch.randint(0, 1000, (1,))

# Run
y2d = model_2d(x2d, t2d, conditioned=c2d)
y3d = model_3d(x3d, timestep=None, conditioned=c3d)
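A preset is essentially a constructor with most hyperparameters already bound, so callers supply only arguments such as in_channels and learn_sigma. The sketch below shows one common way to build such a registry with functools.partial. It uses a stand-in dataclass instead of WiDiT, the "Sketch…" names are hypothetical, and the hidden_size/depth/num_heads values are placeholders, not WiDiT's actual preset sizes.

```python
from dataclasses import dataclass
from functools import partial

@dataclass
class ModelConfig:
    # Stand-in for a model constructor; it only records its arguments.
    spatial_dim: int
    patch_size: int
    hidden_size: int
    depth: int
    num_heads: int
    in_channels: int = 3
    learn_sigma: bool = True

# Placeholder (hidden_size, depth, num_heads) values for illustration only.
SIZES = {"B": (256, 4, 8), "L": (512, 8, 8)}

PRESETS = {
    f"Sketch{dim}D-{name}": partial(
        ModelConfig, spatial_dim=dim, patch_size=2,
        hidden_size=h, depth=d, num_heads=heads,
    )
    for dim in (2, 3)
    for name, (h, d, heads) in SIZES.items()
}

if __name__ == "__main__":
    print(sorted(PRESETS))
    print(PRESETS["Sketch3D-L"](in_channels=1, learn_sigma=False))
```

functools.partial binds its values when each entry is created, which avoids the late-binding pitfall a lambda inside the comprehension would have. Keyword arguments bound by partial can still be overridden at call time.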

Loading Models

Use load_model to load a saved WiDiT or Unet checkpoint. It infers the correct class from the stored config:

from widit import load_model

model = load_model("path/to/checkpoint.pt")
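This README does not document how load_model stores or reads its config, but the general pattern of inferring a class from a stored config is a name-to-constructor registry. The sketch below illustrates that pattern with stub classes. The checkpoint layout ({"config": {"class": ..., **init_kwargs}, "state_dict": ...}), the registry, and the helper name build_from_config are all hypothetical and should not be read as WiDiT's checkpoint format.

```python
from typing import Any, Dict, Type

class WiDiTStub:
    # Stand-in for widit.models.WiDiT; records its init kwargs.
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs

class UnetStub:
    # Stand-in for widit.models.Unet; records its init kwargs.
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs

# Hypothetical registry mapping a stored class name to its constructor.
MODEL_REGISTRY: Dict[str, Type] = {"WiDiT": WiDiTStub, "Unet": UnetStub}

def build_from_config(checkpoint: Dict[str, Any]) -> Any:
    """Rebuild a model from a checkpoint dict of the assumed form
    {"config": {"class": <name>, **init_kwargs}, "state_dict": {...}}."""
    config = dict(checkpoint["config"])  # copy so the checkpoint itself is not mutated
    class_name = config.pop("class")
    try:
        cls = MODEL_REGISTRY[class_name]
    except KeyError:
        raise ValueError(f"unknown model class {class_name!r}") from None
    # A real loader would also restore weights, e.g. via load_state_dict on a torch module.
    return cls(**config)

if __name__ == "__main__":
    ckpt = {"config": {"class": "Unet", "spatial_dim": 3}, "state_dict": {}}
    model = build_from_config(ckpt)
    print(type(model).__name__, model.kwargs)
```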