Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from typing import List, Optional 

2import torch 

3from pathlib import Path 

4from anytree import PreOrderIter 

5 

6from . import nodes 

7from .dotexporter import ThresholdDotExporter 

8 

9 

class ShapeError(RuntimeError):
    """Signals that a tensor's shape does not match the shape the caller expected."""

14 

def node_probabilities(prediction_tensor: torch.Tensor, root: nodes.SoftmaxNode) -> torch.Tensor:
    """
    Converts the scores from the hierarchical softmax layer into a probability for every node.

    For every non-leaf node, a softmax is taken over the slice
    ``softmax_start_index:softmax_end_index`` of the final dimension (the scores of that
    node's children). The result is multiplied by the probability of the node itself.
    Nodes are visited in pre-order, so a node's own probability has always been written
    before its children's probabilities are computed. The root is given probability 1.

    Args:
        prediction_tensor (torch.Tensor): The output from the softmax layer. The final
            dimension must have size ``root.layer_size``. Any leading dimensions
            (e.g. a samples/batch dimension) are preserved.
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.

    Returns:
        torch.Tensor: A tensor with the same shape as ``prediction_tensor``. The entry at a
        node's ``index_in_softmax_layer`` in the final dimension is that node's probability.

    Raises:
        IndexNotSetError: If `set_indexes` has not been called on ``root``.
        ShapeError: If the final dimension of ``prediction_tensor`` is not ``root.layer_size``.
    """
    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The predictions tensor given to node_probabilities has final dimensions of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node which expects prediction tensors to have a final dimension of {root.layer_size}."
        )

    probabilities = torch.zeros_like(prediction_tensor)

    for node in PreOrderIter(root):
        # Leaves have no children, so there is no softmax group to evaluate for them.
        if node.is_leaf:
            continue

        if node is root:
            node_probability = 1.0
        else:
            # This node's probability was written when its parent was processed (pre-order).
            # Clone it so the later in-place writes into `probabilities` do not modify the
            # tensor that autograd saves for the multiplication below.
            node_probability = probabilities[..., node.index_in_softmax_layer].unsqueeze(-1).clone()

        start, end = node.softmax_start_index, node.softmax_end_index
        child_probabilities = torch.softmax(prediction_tensor[..., start:end], dim=-1)
        probabilities[..., start:end] = child_probabilities * node_probability

    return probabilities

45 

46 

def leaf_probabilities(prediction_tensor: torch.Tensor, root: nodes.SoftmaxNode) -> torch.Tensor:
    """
    Returns the probabilities of the leaf nodes of the tree.

    The probability of every node is computed with `node_probabilities`. Only the entries
    at ``root.leaf_indexes_in_softmax_layer`` are then kept, in that order.

    Args:
        prediction_tensor (torch.Tensor): The output from the softmax layer.
            Shape (samples, root.layer_size).
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
            ``root.leaf_indexes_in_softmax_layer`` is passed to `torch.index_select`, so it
            must be an integer index tensor.

    Returns:
        torch.Tensor: The node probabilities restricted to the leaf indexes along the final dimension.
    """
    probabilities = node_probabilities(prediction_tensor, root=root)
    # Select along the last dimension. This is identical to dim 1 for (samples, layer_size)
    # inputs and also works when there are extra leading dimensions.
    return torch.index_select(probabilities, -1, root.leaf_indexes_in_softmax_layer)

52 

53 

def greedy_predictions(
    prediction_tensor: torch.Tensor,
    root: nodes.SoftmaxNode,
    max_depth: Optional[int] = None,
    threshold: Optional[float] = None,
) -> List[nodes.SoftmaxNode]:
    """
    Takes the prediction scores for a number of samples and converts them to a list of predicted nodes in the tree.

    Predictions use the `greedy` method, which means that at each level of the tree the
    child with the greatest prediction score is chosen.

    Args:
        prediction_tensor (torch.Tensor): The output from the softmax layer.
            Shape (samples, root.layer_size).
            Works with raw scores or probabilities.
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set, predictions are at most this number of levels below the root.
            ``None`` or ``0`` means no limit.
        threshold (float, optional): If set, descent stops at a node when the greatest value among
            its children is below this threshold, so that node becomes the prediction for the sample.
            Designed for use with probabilities.

    Returns:
        List[nodes.SoftmaxNode]: A list of nodes predicted for each sample.

    Raises:
        IndexNotSetError: If `set_indexes` has not been called on ``root``.
        ShapeError: If the final dimension of ``prediction_tensor`` is not ``root.layer_size``.
    """
    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The predictions tensor given to greedy_predictions has final dimensions of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node which expects prediction tensors to have a final dimension of {root.layer_size}."
        )

    prediction_nodes = []
    for predictions in prediction_tensor:
        node = root
        depth = 1
        while node.children:
            start, end = node.softmax_start_index, node.softmax_end_index

            # This would be better if we could use torch.argmax but it doesn't work with MPS in the production version of pytorch
            # See https://github.com/pytorch/pytorch/issues/98191
            # https://github.com/pytorch/pytorch/pull/104374
            best = torch.max(predictions[start:end], dim=0)

            # Stop if the best child's score is below the threshold.
            # `is not None` so that an explicit threshold of 0.0 is still applied.
            if threshold is not None and best.values < threshold:
                break

            node = node.children[int(best.indices)]

            # Stop if we have reached the maximum depth (a falsy max_depth means no limit).
            if max_depth and depth >= max_depth:
                break

            depth += 1

        prediction_nodes.append(node)

    return prediction_nodes

108 

109 

def greedy_prediction_node_ids(
    prediction_tensor: torch.Tensor,
    root: nodes.SoftmaxNode,
    max_depth: Optional[int] = None,
    threshold: Optional[float] = None,
) -> List[int]:
    """
    Takes the prediction scores for a number of samples and converts them to a list of predicted node IDs.

    Predictions use the `greedy` method (see `greedy_predictions`), which means that at each
    level of the tree the child with the greatest prediction score is chosen.

    Args:
        prediction_tensor (torch.Tensor): The predictions coming from the softmax layer. Shape (samples, root.layer_size)
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set, predictions are at most this number of levels below the root.
        threshold (float, optional): Forwarded to `greedy_predictions`. If set, descent stops at a node
            when the greatest value among its children is below this threshold.

    Returns:
        List[int]: The IDs that ``root.get_node_ids`` gives for the node predicted for each sample.
    """
    prediction_nodes = greedy_predictions(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
        threshold=threshold,
    )
    return root.get_node_ids(prediction_nodes)

126 

127 

def greedy_prediction_node_ids_tensor(
    prediction_tensor: torch.Tensor,
    root: nodes.SoftmaxNode,
    max_depth: Optional[int] = None,
    threshold: Optional[float] = None,
) -> torch.Tensor:
    """
    Same as `greedy_prediction_node_ids`, but returns the node IDs as a 64-bit integer tensor.

    Args:
        prediction_tensor (torch.Tensor): The predictions coming from the softmax layer. Shape (samples, root.layer_size)
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set, predictions are at most this number of levels below the root.
        threshold (float, optional): Forwarded to `greedy_predictions`. If set, descent stops at a node
            when the greatest value among its children is below this threshold.

    Returns:
        torch.Tensor: A ``torch.long`` tensor with one node ID per sample.
    """
    prediction_nodes = greedy_predictions(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
        threshold=threshold,
    )
    node_ids = root.get_node_ids(prediction_nodes)
    # torch.long is the same dtype that `dtype=int` maps to (torch.int64), stated explicitly.
    return torch.as_tensor(node_ids, dtype=torch.long)

131 

132 

def render_probabilities(
    root: nodes.SoftmaxNode,
    filepaths: Optional[List[Path]] = None,
    prediction_color="red",
    non_prediction_color="gray",
    prediction_tensor: Optional[torch.Tensor] = None,
    probabilities: Optional[torch.Tensor] = None,
    predictions: Optional[List[nodes.SoftmaxNode]] = None,
    horizontal: bool = True,
    threshold: float = 0.005,
) -> List[ThresholdDotExporter]:
    """
    Renders the probabilities of each node in the tree as a graphviz graph, one graph per sample.

    See https://anytree.readthedocs.io/en/latest/_modules/anytree/exporter/dotexporter.html for more information.

    Args:
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        filepaths (List[Path], optional): Paths to locations where the files can be saved.
            Can have extension .dot or another format which can be interpreted by GraphViz such as .png or .svg.
            Parent directories are created as needed. Because the graphs and paths are zipped together,
            any surplus graphs or paths are ignored. Defaults to None so that files are not saved.
        prediction_color (str, optional): The color for the greedy prediction nodes and edges. Defaults to "red".
        non_prediction_color (str, optional): The color for the edges which weren't predicted. Defaults to "gray".
        prediction_tensor (torch.Tensor, optional): The output activations from the softmax layer.
            Shape (samples, root.layer_size). Used to compute ``probabilities`` and/or ``predictions``
            when they are not given.
        probabilities (torch.Tensor, optional): Node probabilities as returned by `node_probabilities`,
            one row per sample. Computed from ``prediction_tensor`` if not given.
        predictions (List[SoftmaxNode], optional): The predicted node for each sample. Computed with
            `greedy_predictions` if not given, from ``prediction_tensor`` when available,
            otherwise from ``probabilities``.
        horizontal (bool, optional): Passed to `ThresholdDotExporter`. Defaults to True.
        threshold (float, optional): Passed to `ThresholdDotExporter`. Defaults to 0.005.

    Returns:
        List[ThresholdDotExporter]: The list of rendered graphs.

    Raises:
        AssertionError: If neither ``prediction_tensor`` nor ``probabilities`` is given.
    """
    if probabilities is None:
        assert prediction_tensor is not None, "Either `prediction_tensor` or `probabilities` must be given."
        probabilities = node_probabilities(prediction_tensor, root=root)

    if predictions is None:
        # `greedy_predictions` works with raw scores or probabilities. Within a group of siblings the
        # probabilities are a softmax of the scores scaled by the parent's probability, so (barring
        # ties) the greatest value is at the same child either way.
        scores = prediction_tensor if prediction_tensor is not None else probabilities
        predictions = greedy_predictions(scores, root=root)

    graphs = []
    for sample_probabilities, sample_prediction in zip(probabilities, predictions):
        # The path from the root down to the predicted node is highlighted.
        greedy_nodes = sample_prediction.ancestors + (sample_prediction,)
        graphs.append(ThresholdDotExporter(
            root,
            probabilities=sample_probabilities,
            greedy_nodes=greedy_nodes,
            horizontal=horizontal,
            prediction_color=prediction_color,
            non_prediction_color=non_prediction_color,
            threshold=threshold,
        ))

    if filepaths:
        for graph, raw_path in zip(graphs, filepaths):
            path = Path(raw_path)
            path.parent.mkdir(exist_ok=True, parents=True)
            if path.suffix == ".dot":
                graph.to_dotfile(str(path))
            else:
                # Any other extension is rendered by graphviz into that image format.
                graph.to_picture(str(path))

    return graphs

192