Quickstart
WiDiT is a SwinIR-style DiT backbone that unifies 2D images and 3D volumes with N-D windowed attention, optional Swin shifts, and AdaLN-Zero conditioning.
- Single model class: widit.models.WiDiT
- Also includes a configurable widit.models.Unet for 2D/3D U-Net baselines
- Optional timestep conditioning (pass timestep=None if unused)
- Shared blocks for 2D/3D via N-D window partitioning
- Presets for quick experiments in both 2D and 3D
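The N-D window partitioning that lets one set of blocks serve both 2D and 3D can be sketched in plain torch. This is an illustrative reimplementation of the general technique, not WiDiT's actual code:

```python
import torch

def window_partition(x, window_size):
    """Split an (N, *spatial, C) tensor into non-overlapping windows.

    Works for any number of spatial dims; each spatial extent must be
    divisible by the matching window size. Illustrative sketch only --
    not WiDiT's internal implementation.
    """
    n, *spatial, c = x.shape
    assert len(spatial) == len(window_size)
    # Reshape each spatial dim into a (blocks, window) pair...
    shape = [n]
    for s, w in zip(spatial, window_size):
        assert s % w == 0, "window size must divide the spatial extent"
        shape += [s // w, w]
    shape.append(c)
    x = x.view(*shape)
    # ...then move all block dims ahead of all window dims.
    d = len(spatial)
    block_dims = [1 + 2 * i for i in range(d)]
    win_dims = [2 + 2 * i for i in range(d)]
    x = x.permute(0, *block_dims, *win_dims, 2 * d + 1)
    return x.reshape(-1, *window_size, c)

# 2D: an (1, 8, 8, 16) feature map with 4x4 windows -> 4 windows
windows = window_partition(torch.randn(1, 8, 8, 16), (4, 4))
print(windows.shape)  # torch.Size([4, 4, 4, 16])
```

The same function handles 3D unchanged: pass a 3-tuple window and a (N, D, H, W, C) tensor, and the loop builds the extra reshape/permute dims automatically.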
Installation
Install using pip:
pip install widit
Or install the latest development version from GitHub:
pip install git+https://github.com/rbturnbull/widit.git
WiDiT depends on torch.
Warning
WiDiT is currently in its alpha testing phase. More updates are coming soon.
Quick Start (2D)
import torch
from widit.models import WiDiT
# Example: 2D RGB input & conditioning (e.g., low-res guidance)
N, C, H, W = 2, 3, 128, 96
x = torch.randn(N, C, H, W)
cond = torch.randn_like(x)
t = torch.randint(0, 1000, (N,), dtype=torch.long) # optional
model = WiDiT(
spatial_dim=2,
patch_size=2, # must divide H and W
in_channels=C,
hidden_size=256, # must be divisible by num_heads and even
depth=6,
num_heads=8,
window_size=8, # can be int or (wh, ww)
mlp_ratio=4.0,
learn_sigma=True, # output channels = 2*C if True
use_conditioning=True, # expect a conditioning image
)
# Call signature:
# forward(input, timestep=None, *, conditioned=None)
y = model(x, t, conditioned=cond) # (N, 2*C, H, W) if learn_sigma=True
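With learn_sigma=True the output doubles the channel count to carry both the prediction and a learned variance. Assuming WiDiT follows the usual DiT convention of stacking them along the channel axis (an assumption; check your checkpoint's layout), they separate with a single chunk:

```python
import torch

# Dummy stand-in for a (N, 2*C, H, W) model output with C = 3.
y = torch.randn(2, 6, 128, 96)

# Split the channel dim in half: first C channels as the prediction,
# last C as the learned (log-)variance. DiT convention; assumed here.
eps, sigma = y.chunk(2, dim=1)
print(eps.shape, sigma.shape)  # both torch.Size([2, 3, 128, 96])
```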
Quick Start (3D)
import torch
from widit.models import WiDiT
# Example: 3D single-channel volumes
N, C, D, H, W = 1, 1, 64, 64, 48
x = torch.randn(N, C, D, H, W)
cond = torch.randn_like(x)
model = WiDiT(
spatial_dim=3,
patch_size=2, # must divide D/H/W
in_channels=C,
hidden_size=256,
depth=4,
num_heads=8,
window_size=(4, 4, 4), # can be int or (wd, wh, ww)
mlp_ratio=4.0,
learn_sigma=False, # output channels = C if False
use_conditioning=True,
)
y = model(x, timestep=None, conditioned=cond) # (N, C, D, H, W)
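The shape comments above matter: patch_size must divide every spatial extent, and (assuming windows tile the post-patch token grid exactly, as in Swin without padding — WiDiT may pad internally) window_size should divide the patched grid. A quick pre-check for the 3D example:

```python
def check_shapes(spatial, patch_size, window_size):
    """Verify patch/window divisibility for an N-D input.

    Assumes windows must tile the patched grid exactly; treat this
    as a conservative pre-check, not a statement of WiDiT's rules.
    """
    for s, w in zip(spatial, window_size):
        if s % patch_size != 0:
            raise ValueError(f"patch_size {patch_size} does not divide extent {s}")
        if (s // patch_size) % w != 0:
            raise ValueError(f"window {w} does not tile patched extent {s // patch_size}")
    return [s // patch_size for s in spatial]

# The 3D example: a (64, 64, 48) volume, patch_size=2, window (4, 4, 4)
grid = check_shapes((64, 64, 48), 2, (4, 4, 4))
print(grid)  # [32, 32, 24] -- the token grid the windows must tile
```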
Unconditioned Image Path (no second image)
import torch
from widit.models import WiDiT
N, C, H, W = 2, 3, 128, 96
x = torch.randn(N, C, H, W)
t = torch.randint(0, 1000, (N,))
model = WiDiT(
spatial_dim=2,
in_channels=C,
hidden_size=256,
depth=4,
num_heads=8,
patch_size=2,
window_size=8,
learn_sigma=True,
use_conditioning=False, # <-- no conditioning image expected
)
# Do NOT pass `conditioned` when use_conditioning=False
y = model(x, t) # (N, 2*C, H, W)
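Integer timesteps like t above are typically mapped to a dense vector before feeding the AdaLN-Zero layers. A standard DDPM/DiT-style sinusoidal embedding looks like this (a generic sketch; WiDiT's own embedding module may differ):

```python
import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of integer timesteps, DDPM/DiT style.

    Generic sketch -- not necessarily WiDiT's internal module.
    """
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/max_period.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0, 500, 999]), 256)
print(emb.shape)  # torch.Size([3, 256])
```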
Presets
Presets provide ready-made configurations for common model sizes (2D & 3D), all
using patch_size=2 and Swin-style window attention:
from widit.models import PRESETS
import torch
# 2D: B, M, L, XL
model_2d = PRESETS["WiDiT2D-L"](in_channels=3, learn_sigma=True)
# 3D: B, M, L, XL
model_3d = PRESETS["WiDiT3D-M"](in_channels=1, learn_sigma=False)
# Example inputs
x2d = torch.randn(1, 3, 64, 48)
c2d = torch.randn_like(x2d)
t2d = torch.randint(0, 1000, (1,))
x3d = torch.randn(1, 1, 32, 32, 24)
c3d = torch.randn_like(x3d)
t3d = torch.randint(0, 1000, (1,))
# Run
y2d = model_2d(x2d, t2d, conditioned=c2d)
y3d = model_3d(x3d, timestep=None, conditioned=c3d)
Loading Models
Use load_model to load a saved WiDiT or Unet checkpoint. It infers the
correct class from the stored config:
from widit import load_model
model = load_model("path/to/checkpoint.pt")