Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from __future__ import annotations 

2from anytree.exporter import DotExporter 

3from typing import Union 

4from pathlib import Path 

5import torch 

6from anytree import Node, RenderTree, PreOrderIter, PostOrderIter, LevelOrderIter, LevelOrderGroupIter, ZigZagGroupIter 

7from typing import List, Optional 

8from rich.console import Console 

# Module-level rich console; used by SoftmaxNode.render to print rendered trees.
console = Console()

10 

11 

class ReadOnlyError(RuntimeError):
    """
    Signals an attempt to modify a SoftmaxNode tree that has been marked read only.

    Nodes raise this from their attach/detach hooks when either the node or the parent involved is read only.
    """

16 

17 

class IndexNotSetError(RuntimeError):
    """
    Signals that `set_indexes` has not been called on the root of a SoftmaxNode tree.
    """

22 

23 

class AlreadyIndexedError(RuntimeError):
    """
    Signals that `set_indexes` was called on a node which has already been indexed.
    """

28 

29 

class SoftmaxNode(Node):
    """
    Creates a hierarchical tree to perform a softmax at each level.

    Once `set_indexes` has been called on the root, the children of every internal node occupy a
    contiguous range `[softmax_start_index, softmax_end_index)` of a single flat layer, and the root
    holds lookup tables (`node_list`, `node_to_id`) between descendant nodes and their positions in that layer.
    """

    def __init__(
        self,
        *args,
        alpha: float = 1.0,
        weight=None,
        label_smoothing: float = 0.0,
        gamma: Optional[float] = None,
        readonly: bool = False,
        **kwargs,
    ):
        """
        Args:
            *args: Positional arguments passed on to `anytree.Node` (e.g. the name of the node).
            alpha (float): Stored on the node; not used within this class. Defaults to 1.0.
            weight: Stored on the node; not used within this class. Defaults to None.
            label_smoothing (float): Stored on the node; not used within this class. Defaults to 0.0.
            gamma (float, optional): Stored on the node for Focal Loss; not used within this class. Defaults to None.
            readonly (bool): If True, the node refuses to be attached or detached. Defaults to False.
            **kwargs: Keyword arguments passed on to `anytree.Node` (e.g. `parent`).
        """
        # All attributes are set before calling the parent constructor because anytree can attach this node
        # to a parent during `Node.__init__`, which triggers the attach hooks that read `readonly` and `children_dict`.
        self.softmax_start_index = None
        self.softmax_end_index = None
        self.children_softmax_end_index = None
        self.node_to_id = None
        self.node_list = None
        self.alpha = alpha
        self.weight = weight
        self.label_smoothing = label_smoothing
        self.readonly = readonly
        self.gamma = gamma  # for Focal Loss
        self.children_dict = dict()
        super().__init__(*args, **kwargs)

    def __str__(self):
        # anytree accepts any object as a name but `__str__` must return a string.
        return str(self.name)

    def __repr__(self):
        return str(self)

    def set_indexes(self, index_in_parent: Optional[int] = None, current_index: int = 0) -> int:
        """
        Sets all the indexes for this node and its descendants so that each node can be referenced by the root.

        This should be called without arguments only on the root of a hierarchy tree.
        After calling this function the tree from the root down will be read only.

        Args:
            index_in_parent (int, optional): The index of this node in the parent's list of children.
                Defaults to None which is appropriate for the root of a tree.
            current_index (int, optional): The first index in the flat layer that is available for the children of this node.
                Defaults to 0 which is appropriate for the root of a tree.

        Raises:
            AlreadyIndexedError: If this node has already been given a `softmax_start_index`.

        Returns:
            int: The next unused index after this node and all its descendants have been indexed.
        """
        if self.softmax_start_index is not None:
            raise AlreadyIndexedError(f"Node {self} already has been indexed. It cannot be indexed again.")

        self.index_in_parent = index_in_parent
        self.index_in_parent_tensor = (
            torch.as_tensor([index_in_parent], dtype=torch.long) if index_in_parent is not None else None
        )

        # The position of this node in the flat layer is its offset within the parent's block of children.
        self.index_in_softmax_layer = self.index_in_parent
        if self.parent:
            self.index_in_softmax_layer += self.parent.softmax_start_index

        if self.children:
            # Reserve a contiguous block for the direct children, then let each child index its own subtree after it.
            self.softmax_start_index = current_index
            current_index += len(self.children)
            self.softmax_end_index = current_index

            for child_index, child in enumerate(self.children):
                current_index = child.set_indexes(child_index, current_index)

            self.children_softmax_end_index = current_index

        # If this is the root, then traverse the tree and make an index of all children.
        # Only a node indexed with the default `current_index` of 0 (and with children) has a start index of 0.
        if self.softmax_start_index == 0:
            descendants = self.descendants  # each access walks the whole tree, so only do it once
            self.node_list = [None] * len(descendants)
            self.node_to_id = dict()
            self.softmax_index_to_node = dict()  # created empty and not populated here
            for node in descendants:
                self.node_to_id[node] = node.index_in_softmax_layer
                self.node_list[node.index_in_softmax_layer] = node

            self.leaf_indexes_in_softmax_layer = torch.as_tensor(
                [leaf.index_in_softmax_layer for leaf in self.leaves],
                dtype=torch.long,
            )

        self.readonly = True
        return current_index

    def set_indexes_if_unset(self) -> None:
        """
        Calls set_indexes on the root of this tree if it has not been called yet.
        """
        if self.root.softmax_start_index is None:
            self.root.set_indexes()

    def render(
        self,
        attr: Optional[str] = None,
        print: bool = False,
        filepath: Union[str, Path, None] = None,
        **kwargs,
    ) -> Union[RenderTree, str]:
        """
        Renders this node and all its descendants in a tree format.

        Args:
            attr (str, optional): An attribute to print for this rendering of the tree. If None, then the name of each node is used.
            print (bool): Whether or not the tree should be printed. Defaults to False.
            filepath: (str, Path, optional): A path to save the tree to using graphviz. Requires graphviz to be installed.
                If the path ends with `.dot` then a dot file is written, otherwise a picture is written.
            **kwargs: Keyword arguments passed on to `anytree.RenderTree`.

        Returns:
            RenderTree: The tree rendered by anytree. If `attr` is given then this is the result of `RenderTree.by_attr`.
        """
        rendered = RenderTree(self, **kwargs)
        if attr:
            rendered = rendered.by_attr(attr)
        if print:
            console.print(rendered)

        if filepath:
            filepath = Path(filepath)
            filepath.parent.mkdir(exist_ok=True, parents=True)

            rendered_tree_graph = DotExporter(self)

            if filepath.suffix == ".dot":
                rendered_tree_graph.to_dotfile(str(filepath))
            else:
                rendered_tree_graph.to_picture(str(filepath))

        return rendered

    def _pre_attach(self, parent: Node):
        """Method call before attaching to `parent`. Refuses if either node is read only."""
        if self.readonly or parent.readonly:
            raise ReadOnlyError(f"Cannot attach node '{self}' to '{parent}' because the tree is read only.")

    def _pre_detach(self, parent: Node):
        """Method call before detaching from `parent`. Refuses if either node is read only."""
        if self.readonly or parent.readonly:
            raise ReadOnlyError(f"Cannot detach node '{self}' from '{parent}' because the tree is read only.")

    def _post_attach(self, parent: Node):
        """Method call after attaching to `parent`."""
        parent.children_dict[self.name] = self

    def _post_detach(self, parent: Node):
        """Method call after detaching from `parent`."""
        registry = parent.children_dict

        # Only remove the entry if it really refers to this node: a sibling with the same name
        # may have replaced it, and removing that entry would make the sibling unreachable by name.
        if registry.get(self.name) is self:
            del registry[self.name]
            return

        # The node may have been renamed since it was attached, so clear any entry still pointing to it.
        stale_keys = [key for key, child in registry.items() if child is self]
        for key in stale_keys:
            del registry[key]

    def get_child_by_name(self, name: str) -> Optional[SoftmaxNode]:
        """
        Returns the child node that has the same name as what is given.

        Args:
            name (str): The name of the child node requested.

        Returns:
            SoftmaxNode: The child node that has the same name as what is given. If no child node exists with this name then `None` is returned.
        """
        return self.children_dict.get(name)

    def get_node_ids(self, nodes: List) -> List[int]:
        """
        Gets the index values for descendant nodes.

        This should only be used for root nodes.
        If `set_indexes` has not yet been called on this object then it is performed as part of this function call.

        Args:
            nodes (List): A list of descendant nodes.

        Returns:
            List[int]: A list of indexes for the descendant nodes requested.
        """
        if self.node_to_id is None:
            self.set_indexes()

        return [self.node_to_id[node] for node in nodes]

    def get_node_ids_tensor(self, nodes: List) -> torch.Tensor:
        """
        Gets the index values for descendant nodes.

        This should only be used for root nodes.
        If `set_indexes` has not yet been called on this object then it is performed as part of this function call.

        Args:
            nodes (List): A list of descendant nodes.

        Returns:
            torch.Tensor: A tensor which contains the indexes for the descendant nodes requested.
        """
        return torch.as_tensor(self.get_node_ids(nodes), dtype=torch.long)

    @property
    def layer_size(self) -> int:
        """
        The index just past the last index used by the descendants of this node.

        Indexes are set on the root of the tree first if that has not been done yet.
        """
        self.root.set_indexes_if_unset()

        return self.children_softmax_end_index

    def render_equal(self, string_representation: str, **kwargs) -> bool:
        """
        Checks if the string representation of this node and its descendants matches the given string.

        The comparison is done line by line and ignores leading and trailing whitespace on each line.

        Args:
            string_representation (str): The string representation to compare to.
            **kwargs: Keyword arguments passed on to `render`.

        Returns:
            bool: True if both representations have the same lines, otherwise False.
        """
        my_lines = [line.strip() for line in str(self.render(**kwargs)).strip().split("\n")]
        other_lines = [line.strip() for line in str(string_representation).strip().split("\n")]

        return my_lines == other_lines

    @staticmethod
    def _depth_to_maxlevel(depth: Optional[int], kwargs: dict) -> dict:
        """Adds a `maxlevel` of `depth + 1` to the iterator keyword arguments if `depth` is given."""
        if depth is not None:
            kwargs["maxlevel"] = depth + 1
        return kwargs

    def pre_order_iter(self, depth=None, **kwargs) -> PreOrderIter:
        """
        Returns a pre-order iterator.

        If `depth` is given then it is passed to anytree as `maxlevel=depth + 1`.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.preorderiter.PreOrderIter
        """
        return PreOrderIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def post_order_iter(self, depth=None, **kwargs) -> PostOrderIter:
        """
        Returns a post-order iterator.

        If `depth` is given then it is passed to anytree as `maxlevel=depth + 1`.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.postorderiter.PostOrderIter
        """
        return PostOrderIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def level_order_iter(self, depth=None, **kwargs) -> LevelOrderIter:
        """
        Returns a level-order iterator.

        If `depth` is given then it is passed to anytree as `maxlevel=depth + 1`.

        See https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.levelorderiter.LevelOrderIter
        """
        return LevelOrderIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def level_order_group_iter(self, depth=None, **kwargs) -> LevelOrderGroupIter:
        """
        Returns a level-order iterator with grouping starting at this node.

        If `depth` is given then it is passed to anytree as `maxlevel=depth + 1`.

        https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.levelordergroupiter.LevelOrderGroupIter
        """
        return LevelOrderGroupIter(self, **self._depth_to_maxlevel(depth, kwargs))

    def zig_zag_group_iter(self, depth=None, **kwargs) -> ZigZagGroupIter:
        """
        Returns a zig-zag iterator with grouping starting at this node.

        If `depth` is given then it is passed to anytree as `maxlevel=depth + 1`.

        https://anytree.readthedocs.io/en/latest/api/anytree.iterators.html#anytree.iterators.zigzaggroupiter.ZigZagGroupIter
        """
        return ZigZagGroupIter(self, **self._depth_to_maxlevel(depth, kwargs))

280