Coverage for hierarchicalsoftmax/inference.py: 100.00%

106 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-02 01:49 +0000

1from typing import List, Optional 

2import torch 

3from pathlib import Path 

4from anytree import PreOrderIter 

5from rich.progress import track 

6 

7from . import nodes 

8from .dotexporter import ThresholdDotExporter 

9 

10 

class ShapeError(RuntimeError):
    """
    Error signalling that a tensor does not have the shape that was expected.

    In this module it is raised when the final dimension of a prediction tensor
    differs from the ``layer_size`` of the root node it is being decoded against.
    """

15 

def node_probabilities(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, progress_bar:bool=False) -> torch.Tensor:
    """
    Converts the prediction scores for a batch of samples into the probability of the nodes in the tree.

    The tree is walked in pre-order. For every node with more than one child, a softmax is taken over the
    slice of scores that belongs to its children (``softmax_start_index:softmax_end_index``). That softmax
    is multiplied by the probability of the node itself. Each child's value is therefore the probability of
    reaching that child from the root.

    Args:
        prediction_tensor (torch.Tensor): The raw scores from the softmax layer, indexed as (samples, root.layer_size).
            A tuple containing exactly one tensor is also accepted and unwrapped, as in `greedy_predictions`.
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        progress_bar (bool, optional): If True, the iteration over the nodes is wrapped in rich's `track`. Defaults to False.

    Returns:
        torch.Tensor: A tensor with the same shape and device as `prediction_tensor` (default floating-point dtype),
        holding each node's probability at the node's index in the softmax layer. Positions that are never
        written (the slices of single-child nodes) stay 0.

    Raises:
        nodes.IndexNotSetError: If `set_indexes` has not been called on the root.
        ShapeError: If the final dimension of `prediction_tensor` does not match `root.layer_size`.
    """
    # Accept the same single-element tuple wrapper that the greedy prediction functions accept
    if isinstance(prediction_tensor, tuple) and len(prediction_tensor) == 1:
        prediction_tensor = prediction_tensor[0]

    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The predictions tensor given to {__name__} has final dimensions of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node which expects prediciton tensors to have a final dimension of {root.layer_size}."
        )

    probabilities = torch.zeros(size=prediction_tensor.shape, device=prediction_tensor.device)

    tree_iter = PreOrderIter(root)
    for node in track(tree_iter) if progress_bar else tree_iter:
        if node.is_leaf:
            continue
        elif node == root:
            my_probability = 1.0
        elif node.index_in_softmax_layer is not None:
            # Column vector so that it broadcasts across this node's children below
            my_probability = probabilities[:, node.index_in_softmax_layer]
            my_probability = my_probability[:, None]
        # NOTE(review): otherwise the node has no slot of its own in the softmax layer (presumably because it is
        # the only child of its parent; confirm against `set_indexes`). If so, pre-order iteration visits it
        # directly after its parent, so `my_probability` still holds the parent's probability here.

        if len(node.children) == 1:
            # If this has just one child, then skip it: there is no choice to normalise over
            continue

        softmax_probabilities = torch.softmax(
            prediction_tensor[:, node.softmax_start_index:node.softmax_end_index],
            dim=1,
        )

        probabilities[:, node.softmax_start_index:node.softmax_end_index] = softmax_probabilities * my_probability

    return probabilities

52 

53 

def leaf_probabilities(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, progress_bar:bool=False) -> torch.Tensor:
    """
    Takes the prediction scores for a number of samples and converts them to the probabilities of the leaf nodes.

    The probabilities of all nodes are computed with `node_probabilities`. The columns listed in
    `root.leaf_indexes` are then selected, in that order.

    Args:
        prediction_tensor (torch.Tensor): The raw scores from the softmax layer, indexed as (samples, root.layer_size).
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        progress_bar (bool, optional): Forwarded to `node_probabilities`. Defaults to False.

    Returns:
        torch.Tensor: A tensor of shape (samples, len(root.leaf_indexes)). Presumably one column per leaf,
        in the order of `root.leaf_indexes`; confirm against `SoftmaxNode`.
    """
    probabilities = node_probabilities(prediction_tensor, root=root, progress_bar=progress_bar)
    # Move the index tensor to the probabilities' device so `index_select` works off the CPU as well
    leaf_indexes = root.leaf_indexes.to(probabilities.device)
    return torch.index_select(probabilities, 1, leaf_indexes)

60 

61 

def greedy_predictions(
    prediction_tensor:torch.Tensor,
    root:nodes.SoftmaxNode,
    max_depth:Optional[int]=None,
    threshold:Optional[float]=None,
    progress_bar:bool=False,
) -> List[nodes.SoftmaxNode]:
    """
    Takes the prediction scores for a number of samples and converts it to a list of predictions of nodes in the tree.

    Predictions use the `greedy` method, which means that it chooses the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): The output from the softmax layer.
            Shape (samples, root.layer_size).
            Works with raw scores or probabilities, since only the largest value within each node's slice is used.
            A tuple containing exactly one tensor is also accepted and unwrapped.
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set (and non-zero), predictions go at most this many levels below the root.
        threshold (float, optional): If set, descent stops at the current node when the largest score among
            its children is below this threshold. The threshold is compared directly against the values in
            `prediction_tensor`, so it is designed for use with probabilities. Nodes with a single child are
            always followed without being checked.
        progress_bar (bool, optional): If True, the iteration over samples is wrapped in rich's `track`. Defaults to False.

    Returns:
        List[nodes.SoftmaxNode]: One predicted node for each sample. This is the root itself if descent stops immediately.

    Raises:
        nodes.IndexNotSetError: If `set_indexes` has not been called on the root.
        ShapeError: If the final dimension of `prediction_tensor` does not match `root.layer_size`.
    """
    prediction_nodes = []

    if isinstance(prediction_tensor, tuple) and len(prediction_tensor) == 1:
        prediction_tensor = prediction_tensor[0]

    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The predictions tensor given to {__name__} has final dimensions of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node which expects prediciton tensors to have a final dimension of {root.layer_size}."
        )

    for predictions in track(prediction_tensor) if progress_bar else prediction_tensor:
        node = root
        depth = 1
        while node.children:
            if len(node.children) == 1:
                # If this has just one child, then we don't check the prediction: there is no choice to make
                prediction_child_index = 0
            else:
                # This would be better if we could use torch.argmax but it doesn't work with MPS in the production version of pytorch
                # See https://github.com/pytorch/pytorch/issues/98191
                # https://github.com/pytorch/pytorch/pull/104374
                start, end = node.softmax_start_index, node.softmax_end_index
                prediction_child_index = int(torch.max(predictions[start:end], dim=0).indices)

                # Stop if the prediction is below the threshold.
                # `is not None` so that an explicit threshold of 0.0 is honoured.
                if threshold is not None and predictions[start + prediction_child_index] < threshold:
                    break

            node = node.children[prediction_child_index]

            # Stop if we have reached the maximum depth
            if max_depth and depth >= max_depth:
                break

            depth += 1

        prediction_nodes.append(node)

    return prediction_nodes

129 

130 

def greedy_lineage_probabilities(
    prediction_tensor:torch.Tensor,
    root:nodes.SoftmaxNode,
    max_depth:Optional[int]=None,
    threshold:Optional[float]=None,
    progress_bar:bool=False,
) -> List[List[tuple]]:
    """
    Follows the greedy path down the tree for each sample and records the cumulative probability of every node on it.

    At each node with more than one child, a softmax is taken over that node's slice of the scores. The most
    probable child is chosen, and the running probability is multiplied by that child's softmax probability.
    Nodes with a single child are followed without changing the running probability.

    Args:
        prediction_tensor (torch.Tensor): The raw scores (logits) from the softmax layer.
            Shape (samples, root.layer_size).
            A softmax is applied here, so values that are already probabilities would be passed through a second softmax.
            A tuple containing exactly one tensor is also accepted and unwrapped.
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set (and non-zero), the lineage goes at most this many levels below the root.
        threshold (float, optional): If set, descent stops before entering a child whose cumulative probability
            (the product of the chosen softmax probabilities from the root) is below this threshold.
            That child is not added to the lineage.
        progress_bar (bool, optional): If True, the iteration over samples is wrapped in rich's `track`. Defaults to False.

    Returns:
        List[List[Tuple[nodes.SoftmaxNode, float]]]: For each sample, the `(node, cumulative_probability)` pairs
        from the first level below the root down to the deepest node reached. The root itself is never
        included, so the list is empty if descent stops at the root.

    Raises:
        nodes.IndexNotSetError: If `set_indexes` has not been called on the root.
        ShapeError: If the final dimension of `prediction_tensor` does not match `root.layer_size`.
    """
    prediction_lineages = []

    if isinstance(prediction_tensor, tuple) and len(prediction_tensor) == 1:
        prediction_tensor = prediction_tensor[0]

    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The predictions tensor given to {__name__} has final dimensions of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node which expects prediciton tensors to have a final dimension of {root.layer_size}."
        )

    for predictions in track(prediction_tensor) if progress_bar else prediction_tensor:
        node = root
        depth = 1
        probability = 1.0  # Start with the root node having a probability of 1.0
        my_lineage = []
        while node.children:
            if len(node.children) == 1:
                # If this has just one child, then we don't check the prediction and the probability is unchanged
                prediction_child_index = 0
            else:
                # This would be better if we could use torch.argmax but it doesn't work with MPS in the production version of pytorch
                # See https://github.com/pytorch/pytorch/issues/98191
                # https://github.com/pytorch/pytorch/pull/104374
                child_probabilities = torch.softmax(
                    predictions[node.softmax_start_index:node.softmax_end_index],
                    dim=0,
                )
                prediction_child_index = int(torch.max(child_probabilities, dim=0).indices)
                probability *= child_probabilities[prediction_child_index].item()

                # Stop if the cumulative probability is below the threshold.
                # `is not None` so that an explicit threshold of 0.0 is treated as set.
                if threshold is not None and probability < threshold:
                    break

            node = node.children[prediction_child_index]

            my_lineage.append(
                (node, probability)
            )

            # Stop if we have reached the maximum depth
            if max_depth and depth >= max_depth:
                break

            depth += 1

        prediction_lineages.append(my_lineage)

    return prediction_lineages

209 

210 

def greedy_prediction_node_ids(
    prediction_tensor:torch.Tensor,
    root:nodes.SoftmaxNode,
    max_depth:Optional[int]=None,
    threshold:Optional[float]=None,
    progress_bar:bool=False,
) -> List[int]:
    """
    Makes greedy predictions for a batch of samples and reports them as node IDs rather than node objects.

    This is a thin wrapper around `greedy_predictions`. Every argument is forwarded unchanged, and the
    resulting list of nodes is converted with `root.get_node_ids`.

    Args:
        prediction_tensor (torch.Tensor): The predictions coming from the softmax layer. Shape (samples, root.layer_size).
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): Forwarded to `greedy_predictions`; limits how many levels below the root are predicted.
        threshold (float, optional): Forwarded to `greedy_predictions`; stops descent where the best child score is below it.
        progress_bar (bool, optional): Forwarded to `greedy_predictions` to show a progress bar. Defaults to False.

    Returns:
        List[int]: The node IDs of the predicted nodes, as returned by `root.get_node_ids`.
    """
    return root.get_node_ids(
        greedy_predictions(
            prediction_tensor=prediction_tensor,
            root=root,
            max_depth=max_depth,
            threshold=threshold,
            progress_bar=progress_bar,
        )
    )

238 

239 

def greedy_prediction_node_ids_tensor(
    prediction_tensor:torch.Tensor,
    root:nodes.SoftmaxNode,
    max_depth:Optional[int]=None,
    threshold:Optional[float]=None,
    progress_bar:bool=False,
) -> torch.Tensor:
    """
    Makes greedy predictions for a batch of samples and returns the predicted node IDs as an integer tensor.

    Args:
        prediction_tensor (torch.Tensor): The predictions coming from the softmax layer. Shape (samples, root.layer_size).
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): Forwarded to `greedy_prediction_node_ids`. Defaults to None (no limit).
        threshold (float, optional): Forwarded to `greedy_prediction_node_ids`. Defaults to None (no threshold).
        progress_bar (bool, optional): Forwarded to `greedy_prediction_node_ids`. Defaults to False.

    Returns:
        torch.Tensor: The node IDs as an int64 tensor, built with `torch.as_tensor` on its default device.
    """
    node_ids = greedy_prediction_node_ids(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
        threshold=threshold,
        progress_bar=progress_bar,
    )
    # `dtype=int` is mapped by torch to int64
    return torch.as_tensor(node_ids, dtype=int)

243 

244 

def render_probabilities(
    root:nodes.SoftmaxNode,
    filepaths:Optional[List[Path]]=None,
    prediction_color="red",
    non_prediction_color="gray",
    prediction_tensor:Optional[torch.Tensor]=None,
    probabilities:Optional[torch.Tensor]=None,
    predictions:Optional[List[nodes.SoftmaxNode]]=None,
    horizontal:bool=True,
    threshold:float=0.005,
) -> List[ThresholdDotExporter]:
    """
    Renders the probabilities of each node in the tree as a graphviz graph, one graph per sample.

    See https://anytree.readthedocs.io/en/latest/_modules/anytree/exporter/dotexporter.html for more information.

    Args:
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        filepaths (List[Path], optional): Paths to locations where the files can be saved.
            Can have extension .dot or another format which can be interpreted by GraphViz, such as .png or .svg.
            Graphs and paths are paired with `zip`, so extra graphs or extra paths are ignored.
            Defaults to None so that files are not saved.
        prediction_color (str, optional): The color for the greedy prediction nodes and edges. Defaults to "red".
        non_prediction_color (str, optional): The color for the edges which weren't predicted. Defaults to "gray".
        prediction_tensor (torch.Tensor, optional): The output activations from the softmax layer. Shape (samples, root.layer_size).
            Required unless both `probabilities` and `predictions` are given.
        probabilities (torch.Tensor, optional): Precomputed node probabilities, one row per sample.
            If None, they are computed from `prediction_tensor` with `node_probabilities`.
        predictions (List[SoftmaxNode], optional): Precomputed predicted node for each sample.
            If None, they are computed from `prediction_tensor` with `greedy_predictions`.
        horizontal (bool, optional): Passed to `ThresholdDotExporter`. Defaults to True.
        threshold (float, optional): Passed to `ThresholdDotExporter`. Presumably nodes whose probability is
            below this are left out of the graph; confirm in `dotexporter`. Defaults to 0.005.

    Returns:
        List[ThresholdDotExporter]: The list of rendered graphs, one per (probabilities, prediction) pair.
    """
    if probabilities is None:
        assert prediction_tensor is not None, "Either `prediction_tensor` or `probabilities` must be given."
        probabilities = node_probabilities(prediction_tensor, root=root)

    if predictions is None:
        assert prediction_tensor is not None, "Either `prediction_tensor` or `predictions` must be given."
        predictions = greedy_predictions(prediction_tensor, root=root)

    graphs = []
    for my_probabilities, my_prediction in zip(probabilities, predictions):
        # The greedy path is the predicted node together with every ancestor up to the root
        greedy_nodes = my_prediction.ancestors + (my_prediction,)
        graphs.append(ThresholdDotExporter(
            root,
            probabilities=my_probabilities,
            greedy_nodes=greedy_nodes,
            horizontal=horizontal,
            prediction_color=prediction_color,
            non_prediction_color=non_prediction_color,
            threshold=threshold,
        ))

    if filepaths:
        for graph, filepath in zip(graphs, filepaths):
            filepath = Path(filepath)
            filepath.parent.mkdir(exist_ok=True, parents=True)
            if filepath.suffix == ".dot":
                graph.to_dotfile(str(filepath))
            else:
                graph.to_picture(str(filepath))

    return graphs

304