Coverage for hierarchicalsoftmax/nodes.py: 100.00%

155 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-02 01:49 +0000

1from __future__ import annotations 

2from anytree.exporter import DotExporter 

3from typing import Union 

4from pathlib import Path 

5import torch 

6from graphviz import Source 

7from anytree import Node, RenderTree, PreOrderIter, PostOrderIter, LevelOrderIter, LevelOrderGroupIter, ZigZagGroupIter 

8from typing import List, Optional 

9from rich.console import Console 

# Shared rich console; SoftmaxNode.render() prints the rendered tree through it when print=True.
console = Console()

11 

12 

class ReadOnlyError(RuntimeError):
    """
    Signals an attempt to attach or detach nodes in a SoftmaxNode tree that has been made read only
    (for example, after ``set_indexes`` has been run on it).
    """
    pass

17 

18 

class IndexNotSetError(RuntimeError):
    """
    Signals that ``set_indexes`` has not yet been called on the root of a SoftmaxNode tree.
    """
    pass

23 

24 

class AlreadyIndexedError(RuntimeError):
    """
    Signals a second call to ``set_indexes`` on a node whose indexes have already been assigned.
    """
    pass

29 

30 

class SoftmaxNode(Node):
    """
    Creates a hierarchical tree to perform a softmax at each level.

    Once ``set_indexes`` has been called on the root, every node that has a choice
    among two or more children owns a contiguous slice of one flat softmax layer
    (one position per child). A node whose parent has exactly one child gets no
    position in that layer. Indexing also makes every node in the tree read only.
    """
    def __init__(
        self,
        *args,
        alpha:float=1.0,
        weight=None,
        label_smoothing:float=0.0,
        gamma:Optional[float]=None,
        readonly:bool=False,
        **kwargs,
    ):
        """
        Args:
            *args: Positional arguments forwarded to ``anytree.Node`` (the node name).
            alpha (float): Stored as ``self.alpha``. Defaults to 1.0.
            weight: Stored as ``self.weight``. Defaults to None.
            label_smoothing (float): Stored as ``self.label_smoothing``. Defaults to 0.0.
            gamma (float, optional): Stored as ``self.gamma``; used for focal loss. Defaults to None.
            readonly (bool): If True, attaching or detaching this node raises ReadOnlyError.
            **kwargs: Keyword arguments forwarded to ``anytree.Node`` (e.g. ``parent``, ``children``).
        """
        # Every attribute is assigned *before* Node.__init__: passing ``parent`` or
        # ``children`` triggers the attach hooks below, which read ``readonly`` and
        # write into ``children_dict``.
        self.softmax_start_index = None
        self.softmax_end_index = None
        self.children_softmax_end_index = None
        self.node_to_id = None
        self.node_list = None
        self.alpha = alpha
        self.weight = weight
        self.label_smoothing = label_smoothing
        self.readonly = readonly
        self.gamma = gamma # for Focal Loss
        self.children_dict = dict()
        super().__init__(*args, **kwargs)

    def __str__(self):
        # str() guards against non-string names, which would otherwise make str(node) raise TypeError.
        return str(self.name)

    def __repr__(self):
        return str(self)

    def set_indexes(self, index_in_parent:Optional[int]=None, current_index:int=0) -> int:
        """
        Sets all the indexes for this node and its descendants so that each node can be referenced by the root.

        This should be called without arguments only on the root of a hierarchy tree.
        After calling this function the tree from the root down will be read only.

        Args:
            index_in_parent (int, optional): The index of this node in the parent's list of children.
                Defaults to None which is appropriate for the root of a tree.
            current_index (int, optional): The next free position in the softmax layer.
                Defaults to 0 which is appropriate for the root of a tree.

        Returns:
            int: The next free softmax position after this node and all its descendants.
                For the root this equals the size of the softmax layer.

        Raises:
            AlreadyIndexedError: If this node already has a softmax start index.
        """
        if self.softmax_start_index is not None:
            raise AlreadyIndexedError(f"Node {self} already has been indexed. It cannot be indexed again.")

        self.index_in_parent = index_in_parent
        self.index_in_parent_tensor = (
            torch.as_tensor([index_in_parent], dtype=torch.long) if index_in_parent is not None else None
        )

        self.index_in_softmax_layer = self.index_in_parent
        if self.parent:
            # If the parent has just one child, then this node is skipped in the softmax layer because it isn't needed
            if len(self.parent.children) == 1:
                self.index_in_softmax_layer = None
            else:
                self.index_in_softmax_layer += self.parent.softmax_start_index

        if self.children:
            # Reserve one slot per child, but only when there is more than one child to choose between.
            self.softmax_start_index = current_index
            current_index += len(self.children) if len(self.children) > 1 else 0
            self.softmax_end_index = current_index

            for child_index, child in enumerate(self.children):
                current_index = child.set_indexes(child_index, current_index)

            self.children_softmax_end_index = current_index

        # If this is the root, then traverse the tree and make an index of all descendants.
        # current_index equals children_softmax_end_index when the root has children and
        # stays 0 for a childless root (where children_softmax_end_index is never set).
        if self.parent is None:
            self._build_root_index(current_index)

        self.readonly = True
        return current_index

    def _build_root_index(self, softmax_end:int) -> None:
        """
        Builds the root-level lookup tables (node ids, node lists, index tensors) for every descendant.

        Descendants with a softmax position use it as their id. The remaining descendants
        are numbered consecutively from ``softmax_end`` in pre-order.
        """
        descendants = self.descendants
        self.node_list = [None] * len(descendants)
        self.node_to_id = dict()
        non_softmax_index = softmax_end
        for node in descendants:
            if node.index_in_softmax_layer is None:
                node_id = non_softmax_index
                non_softmax_index += 1
            else:
                node_id = node.index_in_softmax_layer

            self.node_to_id[node] = node_id
            self.node_list[node_id] = node

        # Positions [0, softmax_end) are exactly the nodes that have a softmax position.
        self.node_list_softmax = self.node_list[:softmax_end] if softmax_end < len(self.node_list) else self.node_list
        self.leaf_list_softmax = [node for node in self.node_list_softmax if not node.children]
        self.node_indexes_in_softmax_layer = torch.as_tensor(
            [node.index_in_softmax_layer for node in self.node_list_softmax], dtype=torch.long
        )
        self.leaf_indexes = [leaf.best_index_in_softmax_layer() for leaf in self.leaves]
        try:
            self.leaf_indexes = torch.as_tensor(self.leaf_indexes, dtype=torch.long)
        except TypeError:
            # Presumably raised when some leaf has no softmax index on its whole path to the root
            # (best_index_in_softmax_layer returned None); keep the plain list in that case.
            pass

    def best_index_in_softmax_layer(self) -> int|None:
        """
        Returns this node's softmax index, or that of its nearest ancestor that has one.

        Returns:
            int | None: The softmax index, or None if neither this node nor any ancestor has one.
        """
        if self.index_in_softmax_layer is not None:
            return self.index_in_softmax_layer

        if self.parent:
            return self.parent.best_index_in_softmax_layer()

        return None

    def set_indexes_if_unset(self) -> None:
        """
        Calls set_indexes on the root of this node's tree if it has not been called yet.
        """
        if self.root.softmax_start_index is None:
            self.root.set_indexes()

    def render(self, attr:Optional[str]=None, print:bool=False, filepath:Union[str, Path, None] = None, **kwargs) -> Union[RenderTree, str]:
        """
        Renders this node and all its descendants in a tree format.

        Args:
            attr (str, optional): An attribute to print for this rendering of the tree. If None, then the name of each node is used.
            print (bool): Whether or not the tree should be printed. Defaults to False.
            filepath (str, Path, optional): A path to save the tree to. A ``.txt`` suffix writes the text rendering,
                ``.dot`` writes a graphviz DOT file and any other suffix renders a picture (requires graphviz).
            **kwargs: Forwarded to ``anytree.RenderTree``.

        Returns:
            RenderTree | str: The tree rendered by anytree, or the rendered string when ``attr`` is given.
        """
        rendered = RenderTree(self, **kwargs)
        if attr:
            rendered = rendered.by_attr(attr)
        if print:
            console.print(rendered)

        if filepath:
            filepath = Path(filepath)
            filepath.parent.mkdir(exist_ok=True, parents=True)

            if filepath.suffix == ".txt":
                filepath.write_text(str(rendered))
            elif filepath.suffix == ".dot":
                DotExporter(self).to_dotfile(str(filepath))
            else:
                DotExporter(self).to_picture(str(filepath))

        return rendered

    def graphviz(
        self,
        options=None,
        horizontal:bool=True,
    ) -> Source:
        """
        Renders this node and all its descendants in a tree format using graphviz.

        Args:
            options (iterable of str, optional): Extra DOT statements passed to ``DotExporter``.
                The caller's object is never modified.
            horizontal (bool): If True, lays the tree out left-to-right. Defaults to True.

        Returns:
            Source: A graphviz source object for the DOT description of the tree.
        """
        # Copy so the caller's list is not mutated; appending in place would accumulate
        # duplicate rankdir entries across repeated calls with the same list.
        options = list(options) if options else []
        if horizontal:
            options.append('rankdir="LR";')

        dot_string = "\n".join(DotExporter(self, options=options))

        return Source(dot_string)

    def svg(
        self,
        options=None,
        horizontal:bool=True,
    ) -> str:
        """
        Renders this node and all its descendants using graphviz and returns the SVG markup as a string.
        """
        source = self.graphviz(options=options, horizontal=horizontal)
        return source.pipe(format="svg").decode("utf-8")

    def _pre_attach(self, parent:Node):
        # anytree hook: block attaching when either side is read only.
        if self.readonly or parent.readonly:
            raise ReadOnlyError()

    def _pre_detach(self, parent:Node):
        # anytree hook: block detaching when either side is read only.
        if self.readonly or parent.readonly:
            raise ReadOnlyError()

    def _post_attach(self, parent:Node):
        """Method call after attaching to `parent`."""
        parent.children_dict[self.name] = self

    def _post_detach(self, parent:Node):
        """Method call after detaching from `parent`."""
        # Only remove the entry if it still refers to this node. A same-named sibling may
        # own it, or this node may have been renamed since attaching; neither should fail or
        # remove another child's entry.
        if parent.children_dict.get(self.name) is self:
            del parent.children_dict[self.name]

    def get_child_by_name(self, name:str) -> Optional[SoftmaxNode]:
        """
        Returns the child node that has the same name as what is given.

        Args:
            name (str): The name of the child node requested.

        Returns:
            SoftmaxNode | None: The child node that has the same name as what is given. If no child node exists with this name then `None` is returned.
        """
        return self.children_dict.get(name)

    def get_node_ids(self, nodes:List) -> List[int]:
        """
        Gets the index values for descendant nodes.

        This should only be used for root nodes.
        If `set_indexes` has not yet been called on this object then it is performed as part of this function call.

        Args:
            nodes (List): A list of descendant nodes.

        Returns:
            List[int]: A list of indexes for the descendant nodes requested.
        """
        if self.node_to_id is None:
            self.set_indexes()

        return [self.node_to_id[node] for node in nodes]

    def get_node_ids_tensor(self, nodes:List) -> torch.Tensor:
        """
        Gets the index values for descendant nodes as a tensor.

        This should only be used for root nodes.
        If `set_indexes` has not yet been called on this object then it is performed as part of this function call.

        Args:
            nodes (List): A list of descendant nodes.

        Returns:
            torch.Tensor: A long tensor which contains the indexes for the descendant nodes requested.
        """
        return torch.as_tensor(self.get_node_ids(nodes), dtype=torch.long)

    @property
    def layer_size(self) -> Optional[int]:
        """
        The end of the softmax positions used by this node's subtree (``children_softmax_end_index``).

        Indexes the whole tree first if needed. For the root this is the size of the softmax layer;
        for a leaf it is None.
        """
        self.root.set_indexes_if_unset()

        return self.children_softmax_end_index

    def render_equal(self, string_representation:str, **kwargs) -> bool:
        """
        Checks if the string representation of this node and its descendants matches the given string.

        Lines are compared after stripping surrounding whitespace from each line and from the whole text.

        Args:
            string_representation (str): The string representation to compare to.
            **kwargs: Forwarded to ``render``.

        Returns:
            bool: True if every line matches.
        """
        lines1 = str(self.render(**kwargs)).strip().split("\n")
        lines2 = str(string_representation).strip().split("\n")

        if len(lines1) != len(lines2):
            return False

        return all(line1.strip() == line2.strip() for line1, line2 in zip(lines1, lines2))

    @staticmethod
    def _depth_to_maxlevel(depth:Optional[int], kwargs:dict) -> dict:
        """Translates a ``depth`` limit (0 = this node only) into anytree's ``maxlevel`` keyword (1 = this node only)."""
        if depth is not None:
            kwargs["maxlevel"] = depth + 1
        return kwargs

    def pre_order_iter(self, depth=None, **kwargs) -> PreOrderIter:
        """
        Returns a pre-order iterator.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.preorderiter.PreOrderIter
        """
        return PreOrderIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def post_order_iter(self, depth=None, **kwargs) -> PostOrderIter:
        """
        Returns a post-order iterator.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.postorderiter.PostOrderIter
        """
        return PostOrderIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def level_order_iter(self, depth=None, **kwargs) -> LevelOrderIter:
        """
        Returns a level-order iterator.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.levelorderiter.LevelOrderIter
        """
        return LevelOrderIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def level_order_group_iter(self, depth=None, **kwargs) -> LevelOrderGroupIter:
        """
        Returns a level-order iterator with grouping starting at this node.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.levelordergroupiter.LevelOrderGroupIter
        """
        return LevelOrderGroupIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def zig_zag_group_iter(self, depth=None, **kwargs) -> ZigZagGroupIter:
        """
        Returns a zig-zag iterator with grouping starting at this node.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.zigzaggroupiter.ZigZagGroupIter
        """
        return ZigZagGroupIter(self, **self._depth_to_maxlevel(depth, kwargs))

336