TreeDict

The TreeDict class is a convenience wrapper for associating keys with nodes in a hierarchical tree. Each key is mapped to a node in the tree and assigned to a specific partition. This is useful in machine learning classification tasks where data is split into partitions for validation (for example, cross-validation folds), and in hierarchical softmax applications.

TreeDict extends Python's UserDict class (from the collections module) and provides additional methods to:

  • Add and retrieve nodes associated with keys

  • Track partition membership

  • Render or visualize the tree structure

  • Serialize and deserialize the tree and key mapping

  • Truncate the tree to a specific depth

  • Output tree summaries in human-readable and visual form

TreeDict objects use a classification tree built from SoftmaxNode objects. For example:

from hierarchicalsoftmax import TreeDict, SoftmaxNode

my_tree = SoftmaxNode("root")
a = SoftmaxNode("a", parent=my_tree)
aa = SoftmaxNode("aa", parent=a)
ab = SoftmaxNode("ab", parent=a)
b = SoftmaxNode("b", parent=my_tree)
ba = SoftmaxNode("ba", parent=b)
bb = SoftmaxNode("bb", parent=b)

tree = TreeDict(my_tree)

Add keys to the TreeDict created above using the add method:

tree.add("item_aa_1", aa, partition=0)
tree.add("item_aa_2", aa, partition=0)

Now you can retrieve the node ID and the partition associated with a key:

node_detail = tree["item_aa_1"]
print(node_detail.partition)
print(node_detail.node_id)

You can also get the actual node object:

node = tree.node("item_aa_1")

Classes

class hierarchicalsoftmax.treedict.TreeDict(classification_tree: SoftmaxNode | None = None)

Bases: UserDict

add(key: str, node: SoftmaxNode, partition: int) → NodeDetail

Associate a key with a node and a partition.

Parameters:
  • key (str) – The unique identifier for the item.

  • node (SoftmaxNode) – The node in the classification tree to associate with the key.

  • partition (int) – The partition index for the key.

Raises:

AlreadyExists – If the key already exists with a different node.

Returns:

The metadata object for the added key.

Return type:

NodeDetail
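The add() contract above (a conflict only when an existing key points at a different node) can be illustrated with a small, self-contained sketch. MiniTreeDict, the local AlreadyExists class, and the choice to overwrite the entry when a key is re-added with the same node are hypothetical stand-ins for illustration, not the library's implementation; real code would use hierarchicalsoftmax's TreeDict and SoftmaxNode.

```python
from collections import UserDict
from dataclasses import dataclass
from typing import Optional


class AlreadyExists(Exception):
    """Local stand-in for the library's AlreadyExists exception (hypothetical)."""


@dataclass
class NodeDetail:
    partition: int
    node: object = None
    node_id: Optional[int] = None


class MiniTreeDict(UserDict):
    """Hypothetical stand-in that mimics the documented add() behaviour."""

    def add(self, key: str, node: object, partition: int) -> NodeDetail:
        existing = self.data.get(key)
        # Only a *different* node for an existing key is treated as a conflict.
        if existing is not None and existing.node is not node:
            raise AlreadyExists(f"{key!r} is already mapped to a different node")
        # Assumption: re-adding with the same node simply replaces the entry.
        detail = NodeDetail(partition=partition, node=node)
        self.data[key] = detail
        return detail


# Placeholder objects standing in for SoftmaxNode instances.
aa, ab = object(), object()
tree = MiniTreeDict()
tree.add("item_aa_1", aa, partition=0)
tree.add("item_aa_1", aa, partition=0)  # same node: no error
try:
    tree.add("item_aa_1", ab, partition=0)
except AlreadyExists as err:
    print("conflict:", err)
```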

add_counts()

Count the number of keys assigned to each node, and store the count in each node.

add_partition_counts()

Count the number of keys in each partition per node and store it in the node.
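The two counting methods amount to tallying keys per node and per (node, partition) pair. The sketch below does this with collections.Counter over plain dictionaries. The node names are hypothetical, and counting only keys assigned directly to a node (not to its descendants) is an assumption, since the descriptions above do not say whether counts are aggregated up the tree.

```python
from __future__ import annotations

from collections import Counter


def count_keys(key_to_node: dict[str, str]) -> Counter:
    """Number of keys assigned directly to each node."""
    return Counter(key_to_node.values())


def count_partitions(
    key_to_node: dict[str, str], key_to_partition: dict[str, int]
) -> dict[str, Counter]:
    """For each node, the number of its keys in each partition."""
    counts: dict[str, Counter] = {}
    for key, node in key_to_node.items():
        counts.setdefault(node, Counter())[key_to_partition[key]] += 1
    return counts


key_to_node = {"item_aa_1": "aa", "item_aa_2": "aa", "item_ba_1": "ba"}
key_to_partition = {"item_aa_1": 0, "item_aa_2": 1, "item_ba_1": 0}
print(count_keys(key_to_node))
print(count_partitions(key_to_node, key_to_partition))
```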

keys(partition: int | None = None)

Return keys in the TreeDict, optionally filtering by partition.

Parameters:

partition (int | None) – The partition to filter keys by. If None, return all keys.

Returns:

An iterator over the keys.

Return type:

Iterator[str]

keys_in_partition(partition: int)

Yield all keys that belong to a given partition.

Parameters:

partition (int) – The partition to filter by.

Yields:

str – Keys in the specified partition.
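Partition filtering can be pictured as a scan over the stored metadata. In this sketch, Detail is a hypothetical, minimal stand-in for NodeDetail, and the two functions mirror keys() and keys_in_partition() as free functions over a plain dictionary; the library's own methods may be implemented differently.

```python
from dataclasses import dataclass
from typing import Dict, Iterator, Optional


@dataclass
class Detail:
    """Minimal stand-in for NodeDetail: only the partition matters here."""
    partition: int


def keys_in_partition(data: Dict[str, Detail], partition: int) -> Iterator[str]:
    """Yield the keys whose metadata records the given partition."""
    for key, detail in data.items():
        if detail.partition == partition:
            yield key


def keys(data: Dict[str, Detail], partition: Optional[int] = None) -> Iterator[str]:
    """All keys when partition is None, otherwise only that partition's keys."""
    if partition is None:
        return iter(data)
    return keys_in_partition(data, partition)


data = {"item_aa_1": Detail(0), "item_aa_2": Detail(1), "item_ba_1": Detail(0)}
print(list(keys(data)))
print(list(keys(data, partition=0)))
```

Comparing against None rather than testing truthiness matters here, because partition 0 is a valid value and would otherwise be mistaken for "no filter".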

keys_to_file(file: Path) → None

Write all keys to a text file, one per line.

Parameters:

file (Path) – Path to the output text file.

classmethod load(path: Path)

Load a TreeDict from a pickle file.

Parameters:

path (Path) – The path to the serialized TreeDict.

Returns:

The loaded TreeDict instance.

Return type:

TreeDict

node(key: str)

Retrieve the node associated with a key.

Parameters:

key (str) – The key for which to retrieve the node.

Returns:

The node corresponding to the key.

Return type:

SoftmaxNode

pickle_tree(output: Path)

Save only the classification tree (not the key-to-node mapping) to a pickle file.

Parameters:

output (Path) – Path to the output file.

render(count: bool = False, partition_counts: bool = False, **kwargs)

Render the tree as text, optionally showing key counts or partition counts.

Parameters:
  • count (bool) – If True, show the number of keys at each node.

  • partition_counts (bool) – If True, show partition-wise key counts at each node.

  • **kwargs – Additional arguments passed to the underlying tree render method.

Returns:

The rendered tree.

Return type:

anytree.RenderTree or str
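A text rendering with counts is essentially a depth-first walk that prints each node indented by its depth. The sketch below uses plain indentation rather than anytree's RenderTree box-drawing output, and the children_of mapping and counts dictionary are hypothetical inputs for illustration.

```python
from __future__ import annotations


def render_lines(
    node: str,
    children_of: dict[str, list[str]],
    counts: dict[str, int] | None = None,
    depth: int = 0,
) -> list[str]:
    """Depth-first, one line per node, optionally annotated with a key count."""
    label = node if counts is None else f"{node} ({counts.get(node, 0)})"
    lines = ["    " * depth + label]
    for child in children_of.get(node, []):
        lines.extend(render_lines(child, children_of, counts, depth + 1))
    return lines


children_of = {"root": ["a", "b"], "a": ["aa", "ab"], "b": ["ba", "bb"]}
counts = {"aa": 2, "ba": 1}
print("\n".join(render_lines("root", children_of, counts)))
```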

save(path: Path)

Save the TreeDict to a pickle file.

Parameters:

path (Path) – The file path to save the TreeDict.
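save() and load() both work with pickle files. The round trip below shows that pattern on a hypothetical UserDict subclass (PickledDict) using the standard pickle module; it is a sketch of the idea, not the library's code. As with any pickle file, only load TreeDicts from sources you trust, because unpickling can execute arbitrary code.

```python
import pickle
import tempfile
from collections import UserDict
from pathlib import Path


class PickledDict(UserDict):
    """Hypothetical stand-in illustrating the save()/load() round trip."""

    def save(self, path: Path) -> None:
        with Path(path).open("wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path: Path) -> "PickledDict":
        with Path(path).open("rb") as f:
            obj = pickle.load(f)
        # Guard against a file that holds some other pickled object.
        if not isinstance(obj, cls):
            raise TypeError(f"{path} does not contain a {cls.__name__}")
        return obj


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "tree.pkl"
    original = PickledDict({"item_aa_1": 0, "item_aa_2": 0})
    original.save(target)
    restored = PickledDict.load(target)
    print(restored == original)
```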

set_indexes()

Ensure the tree has assigned node indexes, and record the node_id for each key.
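One way to picture set_indexes() is a traversal that numbers every node and then copies each key's node index into its metadata. The pre-order numbering, the Node class, and the dictionary-based metadata below are assumptions for illustration; the library may order its indexes differently.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    """Hypothetical stand-in for SoftmaxNode."""
    name: str
    children: list[Node] = field(default_factory=list)
    index: int | None = None


def assign_indexes(root: Node) -> None:
    """Number nodes in pre-order (root first, children left to right)."""
    stack = [root]
    next_index = 0
    while stack:
        node = stack.pop()
        node.index = next_index
        next_index += 1
        # Reverse so the leftmost child is popped (and numbered) first.
        stack.extend(reversed(node.children))


def set_indexes(root: Node, details: dict[str, dict]) -> None:
    """Index the tree, then record each key's node index as its node_id."""
    assign_indexes(root)
    for detail in details.values():
        detail["node_id"] = detail["node"].index


aa, ab, b = Node("aa"), Node("ab"), Node("b")
a = Node("a", [aa, ab])
root = Node("root", [a, b])
details = {"item_aa_1": {"node": aa, "partition": 0}}
set_indexes(root, details)
print(details["item_aa_1"]["node_id"])
```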

sunburst(**kwargs) → go.Figure

Generate a Plotly sunburst plot based on the TreeDict.

Node values are based on the number of keys mapped to each node.

Parameters:

**kwargs – Additional keyword arguments passed to Plotly layout.

Returns:

A sunburst plot.

Return type:

plotly.graph_objects.Figure
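Plotly's sunburst trace takes parallel ids, labels, parents and values sequences (for example via plotly.graph_objects.Sunburst(ids=..., labels=..., parents=..., values=...)). The stdlib-only sketch below prepares those columns from a hypothetical parent_of mapping and per-node key counts; it stops short of calling Plotly, and it is not a description of how the library builds its figure.

```python
from __future__ import annotations

from collections import Counter


def sunburst_columns(
    parent_of: dict[str, str | None], key_to_node: dict[str, str]
) -> tuple[list[str], list[str], list[int]]:
    """Return (ids, parents, values) columns for a Plotly-style sunburst."""
    counts = Counter(key_to_node.values())
    ids = list(parent_of)
    # Plotly marks the root sector with an empty-string parent.
    parents = ["" if parent_of[n] is None else parent_of[n] for n in ids]
    # Keys mapped directly to each node; unmapped nodes get 0.
    values = [counts.get(n, 0) for n in ids]
    return ids, parents, values


parent_of = {"root": None, "a": "root", "aa": "a", "b": "root"}
key_to_node = {"k1": "aa", "k2": "aa", "k3": "b"}
ids, parents, values = sunburst_columns(parent_of, key_to_node)
print(ids, parents, values)
```

With Plotly's default branchvalues="remainder", a parent's value is added on top of its children's values, which fits counts of keys assigned directly to each node.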

truncate(max_depth: int) → TreeDict

Truncate the classification tree to a specified maximum depth and return a new TreeDict.

Keys deeper than the depth limit will be reassigned to the ancestor node at that depth.

Parameters:

max_depth (int) – The maximum number of ancestor levels to keep.

Returns:

A new truncated TreeDict.

Return type:

TreeDict
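Truncation can be sketched as walking each key's node up its ancestor chain until it sits no deeper than the limit. This sketch assumes the root is at depth 0 and uses a hypothetical parent_of mapping; the exact depth convention for max_depth in the library (for instance, whether the root counts as a level) is not specified above, so treat the off-by-one details as an assumption.

```python
from __future__ import annotations


def path_from_root(node: str, parent_of: dict[str, str | None]) -> list[str]:
    """Ancestors of node, root first, ending with node itself."""
    path = []
    current: str | None = node
    while current is not None:
        path.append(current)
        current = parent_of[current]
    return path[::-1]


def truncate_mapping(
    key_to_node: dict[str, str], parent_of: dict[str, str | None], max_depth: int
) -> dict[str, str]:
    """Reassign every key deeper than max_depth to its ancestor at max_depth."""
    truncated = {}
    for key, node in key_to_node.items():
        path = path_from_root(node, parent_of)
        # path[d] is the ancestor at depth d (root = depth 0).
        truncated[key] = path[min(max_depth, len(path) - 1)]
    return truncated


parent_of = {"root": None, "a": "root", "aa": "a", "ab": "a", "b": "root"}
key_to_node = {"item_aa_1": "aa", "item_b_1": "b", "item_a_1": "a"}
print(truncate_mapping(key_to_node, parent_of, max_depth=1))
```

Like truncate(), the sketch builds a new mapping and leaves the original untouched.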

class hierarchicalsoftmax.treedict.NodeDetail(partition: int, node: SoftmaxNode = None, node_id: int = None)

Stores metadata for a key in the TreeDict.

partition

The partition ID this key belongs to.

Type:

int

node

The node in the classification tree associated with the key.

Type:

SoftmaxNode

node_id

The index of the node in the tree (used during pickling).

Type:

int


Command Line Interface

The CLI is implemented as a Typer app and is installed as the treedict command.

$ treedict --help

This CLI provides several subcommands:

keys

Print the list of keys in a TreeDict. Optionally filter by partition.

$ treedict keys data/tree.pkl
$ treedict keys data/tree.pkl --partition 0

render

Render the tree structure to the console or to a file. You may include node counts or per-partition counts.

$ treedict render data/tree.pkl --count
$ treedict render data/tree.pkl --partition-counts --output out.txt

count

Print the total number of keys in the TreeDict.

$ treedict count data/tree.pkl

sunburst

Generate a sunburst plot of the tree using Plotly. You can save it to a file or display it interactively.

$ treedict sunburst data/tree.pkl --output tree.html
$ treedict sunburst data/tree.pkl --show

truncate

Truncate the tree to a specified maximum depth and save the new TreeDict to a file.

$ treedict truncate data/tree.pkl 3 out/tree-truncated.pkl

layer-size

Print the size of the neural network output layer required for classifying against the current tree.

$ treedict layer-size data/tree.pkl

pickle-tree

Serialize only the classification tree (excluding keys) to a pickle file.

$ treedict pickle-tree data/tree.pkl out/tree-only.pkl