TreeDict
The TreeDict class is a convenience wrapper for associating keys with nodes in a hierarchical tree.
Each key is mapped to a node in the tree and assigned to a specific partition. This is useful in tasks
such as machine learning classification where validation data is grouped by partition, or in hierarchical
softmax applications.
TreeDict extends the regular Python dictionary class and provides additional methods to:
Add and retrieve nodes associated with keys
Track partition membership
Render or visualize the tree structure
Serialize and deserialize the tree and key mapping
Truncate the tree to a specific depth
Output tree summaries in human-readable and visual form
TreeDict objects use a classification tree based on SoftmaxNode.
from hierarchicalsoftmax import TreeDict, SoftmaxNode
my_tree = SoftmaxNode("root")
a = SoftmaxNode("a", parent=my_tree)
aa = SoftmaxNode("aa", parent=a)
ab = SoftmaxNode("ab", parent=a)
b = SoftmaxNode("b", parent=my_tree)
ba = SoftmaxNode("ba", parent=b)
bb = SoftmaxNode("bb", parent=b)
tree = TreeDict(my_tree)
Add keys to the TreeDict using the add method:
# Reuse the tree and the aa node created above; calling TreeDict() again would discard that tree.
tree.add("item_aa_1", aa, partition=0)
tree.add("item_aa_2", aa, partition=0)
Now you can retrieve the node ID and the partition associated with a key:
node_detail = tree["item_aa_1"]
print(node_detail.partition)
print(node_detail.node_id)
You can also get the actual node object:
node = tree.node("item_aa_1")
Classes
- class hierarchicalsoftmax.treedict.TreeDict(classification_tree: SoftmaxNode | None = None)
Bases:
UserDict
- add(key: str, node: SoftmaxNode, partition: int) → NodeDetail
Associate a key with a node and a partition.
- Parameters:
key (str) – The unique identifier for the item.
node (SoftmaxNode) – The node in the classification tree to associate with the key.
partition (int) – The partition index for the key.
- Raises:
AlreadyExists – If the key already exists with a different node.
- Returns:
The metadata object for the added key.
- Return type:
NodeDetail
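The duplicate-key rule described for add can be illustrated with a small stand-in. This is a minimal sketch, not the library's implementation: MiniTreeDict and Detail are hypothetical names, nodes are represented by plain strings, and treating a repeated add with the same node as an overwrite is an assumption.

```python
from dataclasses import dataclass


class AlreadyExists(Exception):
    """Stand-in for the exception named in the documentation above."""


@dataclass
class Detail:
    # Simplified stand-in for NodeDetail: only the fields this sketch needs.
    partition: int
    node: str


class MiniTreeDict(dict):
    """Hypothetical, simplified mapping of key -> Detail."""

    def add(self, key: str, node: str, partition: int) -> Detail:
        existing = self.get(key)
        if existing is not None and existing.node != node:
            # Same key, different node: refuse to silently reassign it.
            raise AlreadyExists(f"{key!r} is already mapped to {existing.node!r}")
        detail = Detail(partition=partition, node=node)
        self[key] = detail  # assumption: same key and same node simply overwrites
        return detail


mini = MiniTreeDict()
mini.add("item_aa_1", "aa", partition=0)

try:
    mini.add("item_aa_1", "ab", partition=0)  # conflicting node for an existing key
    conflict_raised = False
except AlreadyExists:
    conflict_raised = True
```

After running the sketch, conflict_raised is True and "item_aa_1" is still mapped to "aa".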
- add_counts()
Count the number of keys assigned to each node, and store the count in each node.
- add_partition_counts()
Count the number of keys in each partition per node and store it in the node.
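The two counting methods can be pictured as simple tallies over the key mapping. The sketch below uses plain dictionaries and collections.Counter instead of SoftmaxNode objects; the assignments data is hypothetical, and counting only directly assigned keys (not those of descendants) is an assumption.

```python
from collections import Counter

# Hypothetical data: key -> (node name, partition)
assignments = {
    "item_aa_1": ("aa", 0),
    "item_aa_2": ("aa", 0),
    "item_ab_1": ("ab", 1),
}

# Keys per node, in the spirit of add_counts()
node_counts = Counter(node for node, _ in assignments.values())

# Keys per partition for each node, in the spirit of add_partition_counts()
partition_counts = {}
for node, partition in assignments.values():
    partition_counts.setdefault(node, Counter())[partition] += 1
```

Here node_counts["aa"] is 2 and partition_counts["ab"] holds a single key in partition 1. The library stores these tallies on the nodes themselves rather than in separate dictionaries.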
- keys(partition: int | None = None)
Return keys in the TreeDict, optionally filtering by partition.
- Parameters:
partition (int | None) – The partition to filter keys by. If None, return all keys.
- Returns:
An iterator over the keys.
- Return type:
Iterator[str]
- keys_in_partition(partition: int)
Yield all keys that belong to a given partition.
- Parameters:
partition (int) – The partition to filter by.
- Yields:
str – Keys in the specified partition.
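Partition filtering is typically used to hold out one partition for validation. The sketch below shows the idea with a plain dictionary; the partitions data and the standalone keys_in_partition function are illustrative stand-ins, not the library's code.

```python
from typing import Dict, Iterator

# Hypothetical data: key -> partition index
partitions = {"item_aa_1": 0, "item_aa_2": 0, "item_ba_1": 1}


def keys_in_partition(mapping: Dict[str, int], partition: int) -> Iterator[str]:
    """Lazily yield the keys whose partition matches."""
    for key, key_partition in mapping.items():
        if key_partition == partition:
            yield key


validation_partition = 1
validation_keys = list(keys_in_partition(partitions, validation_partition))
training_keys = [k for k, p in partitions.items() if p != validation_partition]
```

With a real TreeDict, the equivalent call would be tree.keys_in_partition(1) or tree.keys(partition=1), as documented above.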
- keys_to_file(file: Path) → None
Write all keys to a text file, one per line.
- Parameters:
file (Path) – Path to the output text file.
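The resulting file format is one key per line, which makes it easy to read back with standard tools. A minimal sketch of that format using only the standard library follows; whether the library writes a trailing newline is an assumption.

```python
import tempfile
from pathlib import Path

keys = ["item_aa_1", "item_aa_2"]  # hypothetical keys

with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "keys.txt"
    # One key per line (trailing newline is an assumption).
    out.write_text("\n".join(keys) + "\n")
    # Reading the file back recovers the original list of keys.
    read_back = out.read_text().splitlines()
```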
- classmethod load(path: Path)
Load a TreeDict from a pickle file.
- Parameters:
path (Path) – The path to the serialized TreeDict.
- Returns:
The loaded TreeDict instance.
- Return type:
TreeDict
- node(key: str)
Retrieve the node associated with a key.
- Parameters:
key (str) – The key for which to retrieve the node.
- Returns:
The node corresponding to the key.
- Return type:
SoftmaxNode
- pickle_tree(output: Path)
Save only the classification tree (not the key-to-node mapping) to a pickle file.
- Parameters:
output (Path) – Path to the output file.
- render(count: bool = False, partition_counts: bool = False, **kwargs)
Render the tree as text, optionally showing key counts or partition counts.
- Parameters:
count (bool) – If True, show the number of keys at each node.
partition_counts (bool) – If True, show partition-wise key counts at each node.
**kwargs – Additional arguments passed to the underlying tree render method.
- Returns:
The rendered tree.
- Return type:
anytree.RenderTree or str
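To give a rough idea of what a rendering with count=True conveys, the sketch below prints an indented outline with key counts in parentheses. The tree and counts dictionaries are hypothetical, and the plain-indentation output format is illustrative only; the library delegates to its underlying tree render method, whose output style may differ.

```python
# Hypothetical tree: parent name -> list of child names
tree = {"root": ["a", "b"], "a": ["aa", "ab"], "b": ["ba", "bb"]}
# Hypothetical key counts per node
counts = {"aa": 2}


def render_lines(tree, counts, node="root", depth=0):
    """Return one indented line per node, depth-first, with counts where present."""
    label = "    " * depth + node
    if node in counts:
        label += f" ({counts[node]})"
    lines = [label]
    for child in tree.get(node, []):
        lines.extend(render_lines(tree, counts, child, depth + 1))
    return lines


lines = render_lines(tree, counts)
print("\n".join(lines))
```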
- save(path: Path)
Save the TreeDict to a pickle file.
- Parameters:
path (Path) – The file path to save the TreeDict.
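Saving and loading amount to a pickle round trip. The sketch below shows that round trip with a plain dictionary standing in for the TreeDict; the mapping contents are hypothetical, and the real object also carries the classification tree.

```python
import pickle
import tempfile
from pathlib import Path

# Plain-dict stand-in for the key mapping: key -> partition and node index.
mapping = {
    "item_aa_1": {"partition": 0, "node_id": 2},
    "item_aa_2": {"partition": 0, "node_id": 2},
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "tree.pkl"
    with open(path, "wb") as f:
        pickle.dump(mapping, f)       # counterpart of save(path)
    with open(path, "rb") as f:
        restored = pickle.load(f)     # counterpart of load(path)
```

With the documented API, the same round trip is tree.save(path) followed by TreeDict.load(path).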
- set_indexes()
Ensure the tree has assigned node indexes, and record the node_id for each key.
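Recording a node_id per key means a key can be linked back to its node by position instead of by object reference. The sketch below illustrates this with a hypothetical Node class and a pre-order numbering; the ordering the library actually uses is an assumption.

```python
class Node:
    """Hypothetical minimal tree node, standing in for SoftmaxNode."""

    def __init__(self, name, parent=None):
        self.name = name
        self.children = []
        self.index = None
        if parent is not None:
            parent.children.append(self)


def set_indexes(root):
    """Number nodes in pre-order (root first) and return them as a list."""
    ordered = []
    stack = [root]
    while stack:
        node = stack.pop()
        node.index = len(ordered)
        ordered.append(node)
        # Reverse so the first child is popped (and numbered) first.
        stack.extend(reversed(node.children))
    return ordered


root = Node("root")
a = Node("a", parent=root)
aa = Node("aa", parent=a)
ab = Node("ab", parent=a)
b = Node("b", parent=root)

ordered = set_indexes(root)

key_to_node = {"item_aa_1": aa}
# Record the index for each key, so the node can be looked up again later.
key_to_node_id = {key: node.index for key, node in key_to_node.items()}
recovered = ordered[key_to_node_id["item_aa_1"]]
```

In this sketch "item_aa_1" gets node_id 2, and looking that index up in the ordered node list returns the aa node again.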
- sunburst(**kwargs) → go.Figure
Generate a Plotly sunburst plot based on the TreeDict.
Node values are based on the number of keys mapped to each node.
- Parameters:
**kwargs – Additional keyword arguments passed to Plotly layout.
- Returns:
A sunburst plot.
- Return type:
plotly.graph_objects.Figure
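A Plotly sunburst trace is described by three parallel lists: labels, parents, and values. The sketch below derives those lists from a hypothetical parent map and key counts without importing Plotly; how the library builds its figure internally is an assumption.

```python
# Hypothetical tree: node name -> parent name ("" marks the root)
parents_of = {
    "root": "",
    "a": "root",
    "b": "root",
    "aa": "a",
    "ab": "a",
    "ba": "b",
    "bb": "b",
}
# Hypothetical number of keys mapped directly to each node
key_counts = {"aa": 2, "ba": 1}

labels = list(parents_of)
parents = [parents_of[label] for label in labels]
values = [key_counts.get(label, 0) for label in labels]
```

These lists could then be passed to plotly.graph_objects.Sunburst(labels=labels, parents=parents, values=values) to draw the chart.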
- truncate(max_depth: int) → TreeDict
Truncate the classification tree to a specified maximum depth and return a new TreeDict.
Keys deeper than the depth limit will be reassigned to the ancestor node at that depth.
- Parameters:
max_depth (int) – The maximum number of ancestor levels to keep.
- Returns:
A new truncated TreeDict.
- Return type:
TreeDict
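The reassignment rule can be sketched as walking from a node back to the root and picking the ancestor at the depth limit. The example below uses a hypothetical parent map with string node names, and assumes the root sits at depth 0; the library's exact depth convention may differ.

```python
# Hypothetical tree: child name -> parent name (the root has no entry)
parent = {"a": "root", "b": "root", "aa": "a", "ab": "a", "ba": "b", "bb": "b"}


def path_from_root(node):
    """Return the list of node names from the root down to `node`."""
    path = [node]
    while path[-1] in parent:
        path.append(parent[path[-1]])
    return path[::-1]


def truncate_node(node, max_depth):
    """Return `node`, or its ancestor at `max_depth` if `node` is deeper."""
    path = path_from_root(node)  # path[0] is the root (depth 0, an assumption)
    return path[min(max_depth, len(path) - 1)]


# Hypothetical data: key -> (node name, partition)
assignments = {"item_aa_1": ("aa", 0), "item_b_1": ("b", 1)}

# Keys keep their partition; only the node may change.
truncated = {
    key: (truncate_node(node, 1), partition)
    for key, (node, partition) in assignments.items()
}
```

In this sketch "item_aa_1" moves from aa up to a, while "item_b_1" stays on b because it is already within the depth limit.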
- class hierarchicalsoftmax.treedict.NodeDetail(partition: int, node: SoftmaxNode = None, node_id: int = None)
Stores metadata for a key in the TreeDict.
- partition
The partition ID this key belongs to.
- Type:
int
- node
The node in the classification tree associated with the key.
- Type:
SoftmaxNode
- node_id
The index of the node in the tree (used during pickling).
- Type:
int
- node: SoftmaxNode
- node_id: int
- partition: int
Command Line Interface
The CLI is provided through the Typer app and installed as the command treedict.
$ treedict --help
This CLI provides several subcommands:
keys
Print the list of keys in a TreeDict. Optionally filter by partition.
$ treedict keys data/tree.pkl
$ treedict keys data/tree.pkl --partition 0
render
Render the tree structure to the console or to a file. You may include node counts or per-partition counts.
$ treedict render data/tree.pkl --count
$ treedict render data/tree.pkl --partition-counts --output out.txt
count
Print the total number of keys in the TreeDict.
$ treedict count data/tree.pkl
sunburst
Generate a sunburst plot of the tree using Plotly. You may save to a file or display it interactively.
$ treedict sunburst data/tree.pkl --output tree.html
$ treedict sunburst data/tree.pkl --show
truncate
Truncate the tree to a specified maximum depth and save the new TreeDict to a file.
$ treedict truncate data/tree.pkl 3 out/tree-truncated.pkl
layer-size
Print the size of the neural network output layer required for classifying against the current tree.
$ treedict layer-size data/tree.pkl
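As a rough guide to what this number represents, a hierarchical softmax typically needs one output unit per child of every internal node. The sketch below computes that sum for a hypothetical tree; this is an assumption about how the size is derived, and the library may treat some cases (such as nodes with a single child) differently.

```python
# Hypothetical tree: parent name -> list of child names (leaves have no entry)
children = {"root": ["a", "b"], "a": ["aa", "ab"], "b": ["ba", "bb"]}

# One output unit per child of each internal node (assumed convention).
layer_size = sum(len(kids) for kids in children.values())

# For comparison: a flat softmax needs one unit per leaf.
all_children = {kid for kids in children.values() for kid in kids}
leaves = sorted(all_children - set(children))
flat_size = len(leaves)
```

For this tree the sketch gives a layer size of 6, compared with 4 units for a flat softmax over the leaves.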
pickle-tree
Serialize only the classification tree (excluding keys) to a pickle file.
$ treedict pickle-tree data/tree.pkl out/tree-only.pkl