import marimo as mo
CIFAR Demonstration
This notebook demonstrates how to use the hierarchicalsoftmax
module to train a neural network on the CIFAR dataset.
First, choose the hyperparameters.
cifar_radio = mo.ui.radio(options=["10","100"], value=mo.cli_args().get("cifar") or "100", label="CIFAR Dataset")
batch_size_input = mo.ui.number(value=mo.cli_args().get("batch") or 32, label="Batch Size")
epochs_input = mo.ui.number(value=mo.cli_args().get("batch") or 10, label="Epochs")
mo.vstack([cifar_radio, epochs_input, batch_size_input])
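The defaults above can also be overridden from the command line: marimo forwards everything after a `--` separator to `mo.cli_args()`, which is where the `cifar`, `batch` and `epochs` keys come from. A hypothetical invocation is shown below as a comment (the notebook filename is assumed).
# Hypothetical invocation from a shell (filename assumed; args after "--" reach mo.cli_args()):
#   marimo run cifar_demo.py -- --cifar 100 --batch 64 --epochs 20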
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
assert cifar_radio.value in ["10","100"]
batch_size = batch_size_input.value
epochs = epochs_input.value
cifar_dataset = datasets.CIFAR10 if cifar_radio.value == "10" else datasets.CIFAR100
# Use the same data augmentation strategies as in https://arxiv.org/pdf/1605.07146v4
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4, padding_mode="reflect"),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
train_data = cifar_dataset(root=".", train=True, download=True, transform=transform)
test_data = cifar_dataset(root=".", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
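Before plotting, an optional sanity check (not part of the original recipe, illustrative variable names) confirms what one batch looks like.
# Each training batch should be a (batch_size, 3, 32, 32) float tensor of augmented
# images together with a vector of integer class indices.
check_images, check_labels = next(iter(train_loader))
print(check_images.shape, check_images.dtype)  # e.g. torch.Size([32, 3, 32, 32]) torch.float32
print(check_labels[:8])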
Plot the first 10 images
import plotly.graph_objects as go
from plotly.subplots import make_subplots
num_images = 10
# Create a row of subplots
cifar_fig = make_subplots(
rows=1, cols=num_images,
subplot_titles=[train_data.classes[train_data[i][1]] for i in range(num_images)],
horizontal_spacing=0,
)
for i in range(num_images):
img, label = train_data[i]
img = img.permute(1, 2, 0).numpy() # (C, H, W) -> (H, W, C) and convert to numpy
cifar_fig.add_trace(
go.Image(z=(img * 255).astype('uint8')),
row=1, col=i+1
)
# Update layout: remove axes and tighten spacing
thumbnail_size = 105
cifar_fig.update_layout(
height=thumbnail_size, # adjust height as needed
width=thumbnail_size * num_images,  # thumbnail_size pixels per image
showlegend=False,
margin=dict(l=0, r=0, t=30, b=0)
)
# Hide axes
for i in range(1, num_images + 1):
cifar_fig.update_xaxes(visible=False, row=1, col=i)
cifar_fig.update_yaxes(visible=False, row=1, col=i)
cifar_fig
Non-hierarchical model
First, we create a basic non-hierarchical model as a baseline.
import torch
from torch import nn
import torch.nn.functional as F
from torchmetrics import Accuracy
import lightning as L
class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_planes)
self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != out_planes:
self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = self.conv2(F.relu(self.bn2(out)))
out += self.shortcut(x)
return out
class WideResNetBody(nn.Module):
def __init__(self, depth=16, width_factor=8):
super().__init__()
assert (depth - 4) % 6 == 0, "Depth should be 6n+4"
n = (depth - 4) // 6
k = width_factor
self.in_planes = 16
# Initial conv
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
# 3 groups
self.layer1 = self._make_layer(16*k, n, stride=1)
self.layer2 = self._make_layer(32*k, n, stride=2)
self.layer3 = self._make_layer(64*k, n, stride=2)
self.bn = nn.BatchNorm2d(64*k)
def _make_layer(self, out_planes, blocks, stride):
strides = [stride] + [1]*(blocks-1)
layers = []
for s in strides:
layers.append(BasicBlock(self.in_planes, out_planes, s))
self.in_planes = out_planes
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv1(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = F.relu(self.bn(out))
out = F.avg_pool2d(out, 8)
out = out.view(out.size(0), -1)
return out
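With the default depth=16 and width_factor=8, the last group ends at 64*k = 512 channels and the 8x8 average pool collapses the 32x32 input to a single spatial position, so the body emits a flat 512-dimensional feature vector. A quick shape check (an optional sketch, not in the original notebook):
# Probe the default WideResNet-16-8 body with a dummy batch
with torch.no_grad():
    probe = torch.randn(2, 3, 32, 32)
    print(WideResNetBody()(probe).shape)  # expected: torch.Size([2, 512])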
class BasicImageClassifier(L.LightningModule):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
WideResNetBody(),
nn.LazyLinear(out_features=len(train_data.classes))
)
self.loss_fn = nn.CrossEntropyLoss()
self.metrics = [
Accuracy(task="multiclass", num_classes=len(train_data.classes))
]
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.loss_fn(logits, y)
self.log('train_loss', loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.loss_fn(logits, y)
self.log('val_loss', loss, prog_bar=True)
for metric in self.metrics:
metric = metric.to(logits.device)
result = metric(logits, y)
if isinstance(result, dict):
for name, value in result.items():
self.log(f"val_{name}", value, on_step=False, on_epoch=True, prog_bar=True)
else:
self.log(f"val_{metric.__class__.__name__}", result, on_step=False, on_epoch=True, prog_bar=True)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
basic_model = BasicImageClassifier()
basic_model
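Note that the `nn.LazyLinear` head does not know its input size until the first forward pass. An optional dry run (illustrative only) materializes it and confirms the output shape:
# One forward pass initializes the LazyLinear layer from the 512-dim features
with torch.no_grad():
    print(basic_model(torch.randn(2, 3, 32, 32)).shape)  # (2, len(train_data.classes))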
Train the basic model
from lightning.pytorch.loggers import CSVLogger
basic_logger = CSVLogger(save_dir="lightning_logs", name="basic_model")
basic_trainer = L.Trainer(max_epochs=epochs, accelerator="auto", enable_checkpointing=False, logger=basic_logger)
basic_trainer.fit(basic_model, train_dataloaders=train_loader, val_dataloaders=test_loader)
import pandas as pd
from pathlib import Path
import plotly.io as pio
pio.templates.default = "plotly_white"
basic_metrics_df = pd.read_csv(Path(basic_logger.log_dir) / "metrics.csv")
basic_metrics_df = basic_metrics_df.dropna(subset=["val_MulticlassAccuracy"])
basic_fig = go.Figure()
basic_fig.add_trace(go.Scatter(x=basic_metrics_df["epoch"], y=basic_metrics_df["val_MulticlassAccuracy"], mode='lines', name='class'))
basic_fig.update_layout(
xaxis_title="Epochs",
yaxis_title="Accuracy",
)
basic_fig.show()
Hierarchical model
Let’s now create a hierarchical model. First we need to create a tree structure for the CIFAR dataset.
from hierarchicalsoftmax import (
SoftmaxNode,
HierarchicalSoftmaxLazyLinear,
HierarchicalSoftmaxLoss,
)
from hierarchicalsoftmax.metrics import RankAccuracyTorchMetric
if len(train_data.classes) == 10:
# CIFAR-10
superclasses = {
"animals": ["bird", "cat", "deer", "dog", "frog", "horse"],
"vehicles": ["airplane", "automobile", "ship", "truck"],
}
else:
# CIFAR-100
superclasses = {
"aquatic mammals": ["beaver", "dolphin", "otter", "seal", "whale"],
"fish": ["aquarium_fish", "flatfish", "ray", "shark", "trout"],
"flowers": ["orchid", "poppy", "rose", "sunflower", "tulip"],
"food containers": ["bottle", "bowl", "can", "cup", "plate"],
"fruit and vegetables": ["apple", "mushroom", "orange", "pear", "sweet_pepper"],
"household electrical devices": ["clock", "keyboard", "lamp", "telephone", "television"],
"household furniture": ["bed", "chair", "couch", "table", "wardrobe"],
"insects": ["bee", "beetle", "butterfly", "caterpillar", "cockroach"],
"large carnivores": ["bear", "leopard", "lion", "tiger", "wolf"],
"large man-made outdoor things": ["bridge", "castle", "house", "road", "skyscraper"],
"large natural outdoor scenes": ["cloud", "forest", "mountain", "plain", "sea"],
"large omnivores and herbivores": ["camel", "cattle", "chimpanzee", "elephant", "kangaroo"],
"medium-sized mammals": ["fox", "porcupine", "possum", "raccoon", "skunk"],
"non-insect invertebrates": ["crab", "lobster", "snail", "spider", "worm"],
"people": ["baby", "boy", "girl", "man", "woman"],
"reptiles": ["crocodile", "dinosaur", "lizard", "snake", "turtle"],
"small mammals": ["hamster", "mouse", "rabbit", "shrew", "squirrel"],
"trees": ["maple_tree", "oak_tree", "palm_tree", "pine_tree", "willow_tree"],
"vehicles 1": ["bicycle", "bus", "motorcycle", "pickup_truck", "train"],
"vehicles 2": ["lawn_mower", "rocket", "streetcar", "tank", "tractor"],
}
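An optional check (an illustrative addition, not from the original notebook) that the hand-written superclass lists cover every torchvision class name exactly once, which guards against typos such as a missing underscore:
# Flatten the superclass membership lists and compare against the dataset's class names
flat_classes = [name for members in superclasses.values() for name in members]
assert sorted(flat_classes) == sorted(train_data.classes)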
root = SoftmaxNode("root")
for superclass, classes in superclasses.items():
superclass_node = SoftmaxNode(superclass, parent=root)
for class_name in classes:
SoftmaxNode(class_name, parent=superclass_node)
# Now that the tree is built, we can set the indexes
# This makes the tree read-only
root.set_indexes()
name_to_node_id = {node.name: root.node_to_id[node] for node in root.leaves}
index_to_node_id = {
i: name_to_node_id[name] for i, name in enumerate(train_data.classes)
}
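Another optional, illustrative check: the index-to-node-id mapping should be complete and one-to-one before we use it to relabel the datasets.
# One node id per CIFAR class, with no duplicates
assert len(index_to_node_id) == len(train_data.classes)
assert len(set(index_to_node_id.values())) == len(index_to_node_id)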
# Render the hierarchy
mo.Html(root.svg())
Create DataLoaders with hierarchical labels
class HierarchicalDataset(torch.utils.data.Dataset):
def __init__(self, dataset, index_to_node_id):
self.dataset = dataset
self.index_to_node_id = index_to_node_id
def __getitem__(self, idx):
image, label = self.dataset[idx]
return image, self.index_to_node_id[label]
def __len__(self):
return len(self.dataset)
hierarchical_train_loader = DataLoader(HierarchicalDataset(train_data, index_to_node_id), batch_size=batch_size, shuffle=True)
hierarchical_test_loader = DataLoader(HierarchicalDataset(test_data, index_to_node_id), batch_size=batch_size, shuffle=False)
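Optionally, peek at one batch to confirm that the wrapper now yields node ids assigned by `root.set_indexes()` rather than raw torchvision class indices (illustrative variable names):
wrapped_images, wrapped_labels = next(iter(hierarchical_train_loader))
print(wrapped_labels[:8])
valid_ids = set(index_to_node_id.values())
assert all(int(label) in valid_ids for label in wrapped_labels)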
Create the Hierarchical Image Classifier model
class HierarchicalImageClassifier(BasicImageClassifier):
# Just overriding the init - keep the rest of the code
def __init__(self, root: SoftmaxNode):
super().__init__()
self.model = nn.Sequential(
WideResNetBody(),
HierarchicalSoftmaxLazyLinear(root=root)
)
self.loss_fn = HierarchicalSoftmaxLoss(root)
self.metrics = [
RankAccuracyTorchMetric(
root,
{1: "superclass_accuracy", 2: "class_accuracy"},
),
]
self.root = root
hierarchical_model = HierarchicalImageClassifier(root)
hierarchical_model
hierarchical_logger = CSVLogger(save_dir="lightning_logs", name="hierarchical_model")
hierarchical_trainer = L.Trainer(max_epochs=epochs, accelerator="auto", enable_checkpointing=False, logger=hierarchical_logger)
hierarchical_trainer.fit(hierarchical_model, train_dataloaders=hierarchical_train_loader, val_dataloaders=hierarchical_test_loader)
Plot the validation results at both the superclass and the class levels
hierarchical_df = pd.read_csv(Path(hierarchical_logger.log_dir) / "metrics.csv")
hierarchical_df = hierarchical_df.dropna(subset=["val_class_accuracy"])
hierarchical_fig = go.Figure()
hierarchical_fig.add_trace(go.Scatter(x=hierarchical_df["epoch"], y=hierarchical_df["val_superclass_accuracy"], mode='lines', name='superclass'))
hierarchical_fig.add_trace(go.Scatter(x=hierarchical_df["epoch"], y=hierarchical_df["val_class_accuracy"], mode='lines', name='class'))
hierarchical_fig.update_layout(
xaxis_title="Epochs",
yaxis_title="Accuracy",
)
hierarchical_fig
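Make predictions with the hierarchical model
Finally, a minimal inference sketch. It assumes the `greedy_predictions` helper shown in the hierarchicalsoftmax README; check the package documentation for the exact import path and signature.
# NB: the import path and keyword arguments below are assumed from the package README
from hierarchicalsoftmax.inference import greedy_predictions
hierarchical_model.eval()
eval_images, _ = next(iter(hierarchical_test_loader))
with torch.no_grad():
    eval_outputs = hierarchical_model(eval_images)
# Greedily descend the tree to get one predicted node per image
predicted_nodes = greedy_predictions(prediction_tensor=eval_outputs, root=root)
print([node.name for node in predicted_nodes[:8]])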