Coverage for hierarchicalsoftmax/nodes.py: 100.00%
155 statements
coverage.py v7.8.0, created at 2025-07-02 01:49 +0000
1from __future__ import annotations
2from anytree.exporter import DotExporter
3from typing import Union
4from pathlib import Path
5import torch
6from graphviz import Source
7from anytree import Node, RenderTree, PreOrderIter, PostOrderIter, LevelOrderIter, LevelOrderGroupIter, ZigZagGroupIter
8from typing import List, Optional
9from rich.console import Console
# Shared rich console for the module; SoftmaxNode.render prints through it when print=True.
console: Console = Console()
class ReadOnlyError(RuntimeError):
    """Signals an attempt to modify a SoftmaxNode tree that has been made read only.

    Raised by the attach/detach hooks once ``set_indexes`` has run on the tree,
    or when a node was created with ``readonly=True``.
    """
class IndexNotSetError(RuntimeError):
    """Signals that ``set_indexes`` has not yet been called on the SoftmaxNode root."""
class AlreadyIndexedError(RuntimeError):
    """Signals that ``set_indexes`` was invoked a second time on the same node."""
class SoftmaxNode(Node):
    """
    Creates a hierarchical tree to perform a softmax at each level.

    Every node with more than one child owns a contiguous slice of one flat
    softmax layer. Calling ``set_indexes`` on the root (directly, or indirectly via
    ``get_node_ids`` / ``layer_size``) assigns those slices. Afterwards every indexed
    node is read only, and children can no longer be attached or detached.
    """

    def __init__(
        self,
        *args,
        alpha: float = 1.0,
        weight=None,
        label_smoothing: float = 0.0,
        gamma: Optional[float] = None,
        readonly: bool = False,
        **kwargs,
    ):
        """
        Args:
            *args: Forwarded to ``anytree.Node`` (typically the node name).
            alpha (float): Stored as ``self.alpha``. Defaults to 1.0.
            weight: Stored as ``self.weight``. Defaults to None.
            label_smoothing (float): Stored as ``self.label_smoothing``. Defaults to 0.0.
            gamma (float, optional): Stored as ``self.gamma``; used for focal loss. Defaults to None.
            readonly (bool): When True, attaching/detaching this node to/from a parent, or
                children to/from this node, raises ``ReadOnlyError``. Defaults to False.
            **kwargs: Forwarded to ``anytree.Node`` (e.g. ``parent`` or ``children``).
        """
        # Index bookkeeping, populated by set_indexes().
        self.softmax_start_index = None
        self.softmax_end_index = None
        self.children_softmax_end_index = None
        self.node_to_id = None
        self.node_list = None

        self.alpha = alpha
        self.weight = weight
        self.label_smoothing = label_smoothing
        self.readonly = readonly
        self.gamma = gamma  # for Focal Loss

        # These must exist before super().__init__: passing parent/children there fires the
        # attach hooks below, which read `readonly` and write into `children_dict`.
        self.children_dict = dict()
        super().__init__(*args, **kwargs)

    def __str__(self):
        # str() guards against non-string names (e.g. ints), for which returning the raw
        # name from __str__ would raise TypeError.
        return str(self.name)

    def __repr__(self):
        return str(self)

    def set_indexes(self, index_in_parent: Optional[int] = None, current_index: int = 0) -> int:
        """
        Sets all the indexes for this node and its descendants so that each node can be referenced by the root.

        This should be called without arguments only on the root of a hierarchy tree.
        After calling this function the tree from the root down will be read only.

        Args:
            index_in_parent (int, optional): The index of this node in the parent's list of children.
                Defaults to None which is appropriate for the root of a tree.
            current_index (int, optional): The next free position in the flat softmax layer.
                Defaults to 0 which is appropriate for the root of a tree.

        Returns:
            int: The next free position in the softmax layer after this node's subtree.

        Raises:
            AlreadyIndexedError: If this node (having children) was already indexed.
        """
        if self.softmax_start_index is not None:
            raise AlreadyIndexedError(f"Node {self} already has been indexed. It cannot be indexed again.")

        self.index_in_parent = index_in_parent
        self.index_in_parent_tensor = (
            torch.as_tensor([index_in_parent], dtype=torch.long) if index_in_parent is not None else None
        )

        self.index_in_softmax_layer = self.index_in_parent
        if self.parent:
            # If the parent has just one child, then this node is skipped in the softmax layer because it isn't needed
            if len(self.parent.children) == 1:
                self.index_in_softmax_layer = None
            else:
                # Offset the sibling position by the start of the parent's slice.
                self.index_in_softmax_layer += self.parent.softmax_start_index

        if self.children:
            # Reserve one slot per child, but none when there is only a single child.
            self.softmax_start_index = current_index
            current_index += len(self.children) if len(self.children) > 1 else 0
            self.softmax_end_index = current_index

            for child_index, child in enumerate(self.children):
                current_index = child.set_indexes(child_index, current_index)

        # End of the softmax positions used by this node's whole subtree.
        self.children_softmax_end_index = current_index

        # If this is the root, then traverse the tree and make an index of all descendants
        if self.parent is None:
            descendants = self.descendants  # computed once; anytree rebuilds this tuple per access
            self.node_list = [None] * len(descendants)
            self.node_to_id = dict()
            # Nodes without a softmax slot (only children) get ids after the softmax layer.
            non_softmax_index = self.children_softmax_end_index
            for node in descendants:
                if node.index_in_softmax_layer is None:
                    node_id = non_softmax_index
                    non_softmax_index += 1
                else:
                    node_id = node.index_in_softmax_layer

                self.node_to_id[node] = node_id
                self.node_list[node_id] = node

            self.node_list_softmax = (
                self.node_list[:self.children_softmax_end_index]
                if self.children_softmax_end_index < len(self.node_list)
                else self.node_list
            )
            self.leaf_list_softmax = [node for node in self.node_list_softmax if not node.children]
            self.node_indexes_in_softmax_layer = torch.as_tensor(
                [node.index_in_softmax_layer for node in self.node_list_softmax]
            )
            self.leaf_indexes = [leaf.best_index_in_softmax_layer() for leaf in self.leaves]
            try:
                self.leaf_indexes = torch.as_tensor(self.leaf_indexes, dtype=torch.long)
            except TypeError:
                # Deliberate best effort: if a leaf has no softmax index anywhere up its
                # ancestry (a None entry), conversion is expected to fail and leaf_indexes
                # stays a plain Python list.
                pass

        # Every node visited by the recursion becomes read only.
        self.readonly = True
        return current_index

    def best_index_in_softmax_layer(self) -> int | None:
        """
        Returns this node's index in the softmax layer. If it has none (it is an only child),
        returns the nearest ancestor's index instead, or None if no ancestor has one.
        """
        if self.index_in_softmax_layer is not None:
            return self.index_in_softmax_layer

        if self.parent:
            return self.parent.best_index_in_softmax_layer()

        return None

    def set_indexes_if_unset(self) -> None:
        """
        Calls set_indexes on the root of this node's tree if it has not been called yet.

        May be called on any node; it always operates on ``self.root``.
        """
        if self.root.softmax_start_index is None:
            self.root.set_indexes()

    def render(
        self,
        attr: Optional[str] = None,
        print: bool = False,
        filepath: Union[str, Path, None] = None,
        **kwargs,
    ) -> Union[RenderTree, str]:
        """
        Renders this node and all its descendants in a tree format.

        Args:
            attr (str, optional): An attribute to print for this rendering of the tree. If None, then the name of each node is used.
            print (bool): Whether or not the tree should be printed. Defaults to False.
            filepath (str, Path, optional): A path to save the tree to. A ``.txt`` suffix writes the
                text rendering (UTF-8), ``.dot`` writes a DOT file, and any other suffix renders an
                image with graphviz (requires graphviz to be installed). Parent directories are created.
            **kwargs: Forwarded to ``anytree.RenderTree``.

        Returns:
            RenderTree: The tree rendered by anytree, or a ``str`` when ``attr`` is given
            (the result of ``RenderTree.by_attr``).
        """
        rendered = RenderTree(self, **kwargs)
        if attr:
            rendered = rendered.by_attr(attr)
        if print:
            console.print(rendered)

        if filepath:
            filepath = Path(filepath)
            filepath.parent.mkdir(exist_ok=True, parents=True)

            if filepath.suffix == ".txt":
                # The tree drawing uses Unicode box characters; do not rely on the locale encoding.
                filepath.write_text(str(rendered), encoding="utf-8")
            elif filepath.suffix == ".dot":
                DotExporter(self).to_dotfile(str(filepath))
            else:
                DotExporter(self).to_picture(str(filepath))

        return rendered

    def graphviz(
        self,
        options=None,
        horizontal: bool = True,
    ) -> Source:
        """
        Renders this node and all its descendants in a tree format using graphviz.

        Args:
            options (iterable of str, optional): Extra DOT statements passed to ``DotExporter``.
                The caller's sequence is not modified.
            horizontal (bool): If True, lays the graph out left-to-right. Defaults to True.

        Returns:
            graphviz.Source: The DOT source for the tree.
        """
        # Copy so a caller-supplied list is never mutated; otherwise repeated calls with
        # the same list would keep appending duplicate rankdir statements to it.
        options = list(options or [])
        if horizontal:
            options.append('rankdir="LR";')

        dot_string = "\n".join(DotExporter(self, options=options))

        return Source(dot_string)

    def svg(
        self,
        options=None,
        horizontal: bool = True,
    ) -> str:
        """
        Renders this node and all its descendants with graphviz and returns the SVG markup as a string.

        Args:
            options (iterable of str, optional): Extra DOT statements, as for ``graphviz``.
            horizontal (bool): If True, lays the graph out left-to-right. Defaults to True.
        """
        source = self.graphviz(options=options, horizontal=horizontal)
        return source.pipe(format="svg").decode("utf-8")

    def _pre_attach(self, parent: Node):
        """Refuses to attach when either this node or the new parent is read only."""
        if self.readonly or parent.readonly:
            raise ReadOnlyError(f"Cannot attach node '{self}' to '{parent}' because the tree is read only.")

    def _pre_detach(self, parent: Node):
        """Refuses to detach when either this node or its parent is read only."""
        if self.readonly or parent.readonly:
            raise ReadOnlyError(f"Cannot detach node '{self}' from '{parent}' because the tree is read only.")

    def _post_attach(self, parent: Node):
        """Method call after attaching to `parent`."""
        parent.children_dict[self.name] = self

    def _post_detach(self, parent: Node):
        """Method call after detaching from `parent`."""
        # Only remove the entry if it still refers to this node. A same-named sibling may
        # have overwritten it; deleting blindly would drop the sibling's entry or raise
        # KeyError after the node has already been detached.
        if parent.children_dict.get(self.name) is self:
            del parent.children_dict[self.name]

    def get_child_by_name(self, name: str) -> Optional[SoftmaxNode]:
        """
        Returns the child node that has the same name as what is given.

        Args:
            name (str): The name of the child node requested.

        Returns:
            SoftmaxNode: The child node that has the same name as what is given. If no child node exists with this name then `None` is returned.
        """
        return self.children_dict.get(name, None)

    def get_node_ids(self, nodes: List) -> List[int]:
        """
        Gets the index values for descendant nodes.

        This should only be used for root nodes.
        If `set_indexes` has not yet been called on this object then it is performed as part of this function call.

        Args:
            nodes (List): A list of descendant nodes.

        Returns:
            List[int]: A list of indexes for the descendant nodes requested.

        Raises:
            KeyError: If a requested node is not a descendant of this root.
        """
        if self.node_to_id is None:
            self.set_indexes()

        return [self.node_to_id[node] for node in nodes]

    def get_node_ids_tensor(self, nodes: List) -> torch.Tensor:
        """
        Gets the index values for descendant nodes.

        This should only be used for root nodes.
        If `set_indexes` has not yet been called on this object then it is performed as part of this function call.

        Args:
            nodes (List): A list of descendant nodes.

        Returns:
            torch.Tensor: A int64 tensor which contains the indexes for the descendant nodes requested.
        """
        return torch.as_tensor(self.get_node_ids(nodes), dtype=torch.long)

    @property
    def layer_size(self) -> int:
        """
        The end of this node's subtree in the softmax layer, indexing the tree first if needed.

        On the root this is the total size of the softmax layer.
        """
        self.root.set_indexes_if_unset()

        return self.children_softmax_end_index

    def render_equal(self, string_representation: str, **kwargs) -> bool:
        """
        Checks if the string representation of this node and its descendants matches the given string.

        Lines are compared after stripping surrounding whitespace, and leading/trailing blank
        lines of the whole text are ignored.

        Args:
            string_representation (str): The string representation to compare to.
            **kwargs: Forwarded to ``render``.

        Returns:
            bool: True if every line matches.
        """
        lines1 = str(self.render(**kwargs)).strip().split("\n")
        lines2 = str(string_representation).strip().split("\n")

        if len(lines1) != len(lines2):
            return False

        return all(line1.strip() == line2.strip() for line1, line2 in zip(lines1, lines2))

    @staticmethod
    def _depth_to_maxlevel(depth: Optional[int], kwargs: dict) -> dict:
        """Translates ``depth`` (levels below this node) into anytree's ``maxlevel`` keyword."""
        if depth is not None:
            kwargs["maxlevel"] = depth + 1
        return kwargs

    def pre_order_iter(self, depth=None, **kwargs) -> PreOrderIter:
        """
        Returns a pre-order iterator.

        ``depth`` limits how many levels below this node are visited (``maxlevel = depth + 1``).

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.preorderiter.PreOrderIter
        """
        return PreOrderIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def post_order_iter(self, depth=None, **kwargs) -> PostOrderIter:
        """
        Returns a post-order iterator.

        ``depth`` limits how many levels below this node are visited (``maxlevel = depth + 1``).

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.postorderiter.PostOrderIter
        """
        return PostOrderIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def level_order_iter(self, depth=None, **kwargs) -> LevelOrderIter:
        """
        Returns a level-order iterator.

        ``depth`` limits how many levels below this node are visited (``maxlevel = depth + 1``).

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.levelorderiter.LevelOrderIter
        """
        return LevelOrderIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def level_order_group_iter(self, depth=None, **kwargs) -> LevelOrderGroupIter:
        """
        Returns a level-order iterator with grouping starting at this node.

        ``depth`` limits how many levels below this node are visited (``maxlevel = depth + 1``).

        https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.levelordergroupiter.LevelOrderGroupIter
        """
        return LevelOrderGroupIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def zig_zag_group_iter(self, depth=None, **kwargs) -> ZigZagGroupIter:
        """
        Returns a zig-zag iterator with grouping starting at this node.

        ``depth`` limits how many levels below this node are visited (``maxlevel = depth + 1``).

        https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.zigzaggroupiter.ZigZagGroupIter
        """
        return ZigZagGroupIter(self, **self._depth_to_maxlevel(depth, kwargs))