Coverage for hierarchicalsoftmax/nodes.py : 100.00%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
2from anytree.exporter import DotExporter
3from typing import Union
4from pathlib import Path
5import torch
6from anytree import Node, RenderTree, PreOrderIter, PostOrderIter, LevelOrderIter, LevelOrderGroupIter, ZigZagGroupIter
7from typing import List, Optional
8from rich.console import Console
# Shared rich console; SoftmaxNode.render(print=True) writes the rendered tree through it.
console: Console = Console()
class ReadOnlyError(RuntimeError):
    """Signals an attempt to modify a SoftmaxNode tree after it has been marked read only."""
class IndexNotSetError(RuntimeError):
    """Signals that ``set_indexes`` has not yet been run on the root of a SoftmaxNode tree."""
class AlreadyIndexedError(RuntimeError):
    """Signals a repeated call to ``set_indexes`` on a node that has already been indexed."""
class SoftmaxNode(Node):
    """
    Creates a hierarchical tree to perform a softmax at each level.

    Every node with children owns a contiguous slice of one flat "softmax layer"
    (``softmax_start_index`` up to ``softmax_end_index``), with one entry per child.
    These indexes are assigned by calling :meth:`set_indexes` on the root. After that,
    the indexed nodes are read only.
    """
    def __init__(
        self,
        *args,
        alpha: float = 1.0,
        weight=None,
        label_smoothing: float = 0.0,
        gamma: Optional[float] = None,
        readonly: bool = False,
        **kwargs,
    ):
        """
        Args:
            *args: Positional arguments forwarded to ``anytree.Node`` (typically the node name).
            alpha (float): Stored on the node as ``alpha``. Defaults to 1.0.
            weight: Stored on the node as ``weight``. Defaults to None.
            label_smoothing (float): Stored on the node as ``label_smoothing``. Defaults to 0.0.
            gamma (float, optional): Stored on the node as ``gamma`` (the Focal Loss parameter). Defaults to None.
            readonly (bool): If True, attaching/detaching children to or from this node raises ReadOnlyError.
                Defaults to False.
            **kwargs: Keyword arguments forwarded to ``anytree.Node`` (e.g. ``parent``, ``children``).
        """
        # Positions in the root's flat softmax layer; filled in by set_indexes().
        self.softmax_start_index = None
        self.softmax_end_index = None
        self.children_softmax_end_index = None
        # Root-only lookup tables; filled in by set_indexes() on the root.
        self.node_to_id = None
        self.node_list = None
        self.alpha = alpha
        self.weight = weight
        self.label_smoothing = label_smoothing
        self.readonly = readonly
        self.gamma = gamma  # for Focal Loss
        self.children_dict = dict()
        # The attributes above must exist before Node.__init__ runs: passing `parent` or
        # `children` triggers the attach hooks below, which read `readonly` and `children_dict`.
        super().__init__(*args, **kwargs)

    def __str__(self):
        # str() so that non-string node names do not make __str__ raise TypeError.
        return str(self.name)

    def __repr__(self):
        return str(self)

    def set_indexes(self, index_in_parent: Optional[int] = None, current_index: int = 0) -> int:
        """
        Sets all the indexes for this node and its descendants so that each node can be referenced by the root.

        This should be called without arguments only on the root of a hierarchy tree.
        After calling this function the tree from the root down will be read only.

        Args:
            index_in_parent (int, optional): The index of this node in the parent's list of children.
                Defaults to None which is appropriate for the root of a tree.
            current_index (int, optional): The next free position in the softmax layer.
                Defaults to 0 which is appropriate for the root of a tree.

        Returns:
            int: The next free position in the softmax layer after this node's subtree has been indexed.

        Raises:
            AlreadyIndexedError: If this node (with children) has already been indexed.
        """
        if self.softmax_start_index is not None:
            raise AlreadyIndexedError(f"Node {self} already has been indexed. It cannot be indexed again.")

        self.index_in_parent = index_in_parent
        self.index_in_parent_tensor = (
            torch.as_tensor([index_in_parent], dtype=torch.long) if index_in_parent is not None else None
        )

        # This node's slot in the flat layer = its parent's slice start + its position among siblings.
        self.index_in_softmax_layer = self.index_in_parent
        if self.parent:
            self.index_in_softmax_layer += self.parent.softmax_start_index

        if self.children:
            # Reserve one slot per child first, then recurse so each child's own slice
            # is placed after all of its siblings' slots.
            self.softmax_start_index = current_index
            current_index += len(self.children)
            self.softmax_end_index = current_index

            for child_index, child in enumerate(self.children):
                current_index = child.set_indexes(child_index, current_index)

            self.children_softmax_end_index = current_index

        # Only the root (indexed with current_index=0) gets a slice starting at 0, so this
        # identifies the root. Build lookup tables over every descendant.
        if self.softmax_start_index == 0:
            descendants = self.descendants
            self.node_list = [None] * len(descendants)
            self.node_to_id = dict()
            self.softmax_index_to_node = dict()
            for node in descendants:
                index = node.index_in_softmax_layer
                self.node_to_id[node] = index
                self.node_list[index] = node
                self.softmax_index_to_node[index] = node

            self.leaf_indexes_in_softmax_layer = torch.as_tensor(
                [leaf.index_in_softmax_layer for leaf in self.leaves]
            )

        self.readonly = True
        return current_index

    def set_indexes_if_unset(self) -> None:
        """
        Calls set_indexes on the root of this tree if it has not been called yet.
        """
        if self.root.softmax_start_index is None:
            self.root.set_indexes()

    def render(
        self,
        attr: Optional[str] = None,
        print: bool = False,
        filepath: Union[str, Path, None] = None,
        **kwargs,
    ) -> RenderTree:
        """
        Renders this node and all its descendants in a tree format.

        Args:
            attr (str, optional): An attribute to print for this rendering of the tree.
                If None, then the name of each node is used.
            print (bool): Whether or not the tree should be printed to the console. Defaults to False.
            filepath (str, Path, optional): A path to save the tree to using graphviz (via ``DotExporter``).
                A ``.dot`` suffix writes the DOT source. Any other suffix renders a picture, which
                requires graphviz to be installed. Missing parent directories are created.
            **kwargs: Passed through to ``anytree.RenderTree``.

        Returns:
            RenderTree: The tree rendered by anytree, or the result of ``RenderTree.by_attr(attr)``
            when ``attr`` is given.
        """
        # NOTE: the `print` parameter shadows the builtin here; output goes through the rich console.
        rendered = RenderTree(self, **kwargs)
        if attr:
            rendered = rendered.by_attr(attr)
        if print:
            console.print(rendered)

        if filepath:
            filepath = Path(filepath)
            filepath.parent.mkdir(exist_ok=True, parents=True)

            exporter = DotExporter(self)
            if filepath.suffix == ".dot":
                exporter.to_dotfile(str(filepath))
            else:
                exporter.to_picture(str(filepath))

        return rendered

    def _pre_attach(self, parent: Node):
        """Refuses to attach when either this node or the new parent is read only."""
        if self.readonly or parent.readonly:
            raise ReadOnlyError(f"Cannot attach node '{self}' to '{parent}': the tree is read only.")

    def _pre_detach(self, parent: Node):
        """Refuses to detach when either this node or its parent is read only."""
        if self.readonly or parent.readonly:
            raise ReadOnlyError(f"Cannot detach node '{self}' from '{parent}': the tree is read only.")

    def _post_attach(self, parent: Node):
        """Method call after attaching to `parent`. Registers this node for lookup by name."""
        parent.children_dict[self.name] = self

    def _post_detach(self, parent: Node):
        """Method call after detaching from `parent`. Removes this node's name lookup entry."""
        # A later sibling with the same name overwrites the entry in _post_attach, so only
        # delete the entry when it still refers to this node. Otherwise we would remove the
        # sibling's entry, or raise KeyError if the entry is already gone.
        if parent.children_dict.get(self.name) is self:
            del parent.children_dict[self.name]

    def get_child_by_name(self, name: str) -> Optional[SoftmaxNode]:
        """
        Returns the child node that has the same name as what is given.

        Args:
            name (str): The name of the child node requested.

        Returns:
            SoftmaxNode, optional: The child node that has the same name as what is given.
            If no child node exists with this name then `None` is returned.
        """
        return self.children_dict.get(name, None)

    def get_node_ids(self, nodes: List) -> List[int]:
        """
        Gets the index values for descendant nodes.

        This should only be used for root nodes.
        If `set_indexes` has not yet been called on this object then it is performed as part of this function call.

        Args:
            nodes (List): A list of descendant nodes.

        Returns:
            List[int]: A list of indexes for the descendant nodes requested.
        """
        if self.node_to_id is None:
            self.set_indexes()

        return [self.node_to_id[node] for node in nodes]

    def get_node_ids_tensor(self, nodes: List) -> torch.Tensor:
        """
        Gets the index values for descendant nodes.

        This should only be used for root nodes.
        If `set_indexes` has not yet been called on this object then it is performed as part of this function call.

        Args:
            nodes (List): A list of descendant nodes.

        Returns:
            torch.Tensor: A 64-bit integer tensor which contains the indexes for the descendant nodes requested.
        """
        return torch.as_tensor(self.get_node_ids(nodes), dtype=torch.long)

    @property
    def layer_size(self) -> int:
        """
        The end of this node's subtree in the softmax layer, indexing the tree first if needed.

        On the root this equals the total number of entries in the softmax layer.
        For a node without children it is None.
        """
        self.set_indexes_if_unset()

        return self.children_softmax_end_index

    def render_equal(self, string_representation: str, **kwargs) -> bool:
        """
        Checks if the string representation of this node and its descendants matches the given string.

        Lines are compared after stripping surrounding whitespace from the whole text and from each line.

        Args:
            string_representation (str): The string representation to compare to.
            **kwargs: Passed through to :meth:`render`.

        Returns:
            bool: True if the rendering matches line by line.
        """
        lines1 = str(self.render(**kwargs)).strip().split("\n")
        lines2 = str(string_representation).strip().split("\n")

        if len(lines1) != len(lines2):
            return False

        return all(line1.strip() == line2.strip() for line1, line2 in zip(lines1, lines2))

    @staticmethod
    def _depth_to_kwargs(depth: Optional[int], kwargs: dict) -> dict:
        """Converts an optional ``depth`` into anytree's ``maxlevel`` keyword (``depth + 1``) inside ``kwargs``."""
        if depth is not None:
            kwargs["maxlevel"] = depth + 1
        return kwargs

    def pre_order_iter(self, depth=None, **kwargs) -> PreOrderIter:
        """
        Returns a pre-order iterator.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.preorderiter.PreOrderIter
        """
        return PreOrderIter(self, **self._depth_to_kwargs(depth, kwargs))

    def post_order_iter(self, depth=None, **kwargs) -> PostOrderIter:
        """
        Returns a post-order iterator.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.postorderiter.PostOrderIter
        """
        return PostOrderIter(self, **self._depth_to_kwargs(depth, kwargs))

    def level_order_iter(self, depth=None, **kwargs) -> LevelOrderIter:
        """
        Returns a level-order iterator.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.levelorderiter.LevelOrderIter
        """
        return LevelOrderIter(self, **self._depth_to_kwargs(depth, kwargs))

    def level_order_group_iter(self, depth=None, **kwargs) -> LevelOrderGroupIter:
        """
        Returns a level-order iterator with grouping starting at this node.

        https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.levelordergroupiter.LevelOrderGroupIter
        """
        return LevelOrderGroupIter(self, **self._depth_to_kwargs(depth, kwargs))

    def zig_zag_group_iter(self, depth=None, **kwargs) -> ZigZagGroupIter:
        """
        Returns a zig-zag iterator with grouping starting at this node.

        https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.zigzaggroupiter.ZigZagGroupIter
        """
        return ZigZagGroupIter(self, **self._depth_to_kwargs(depth, kwargs))