Coverage for hierarchicalsoftmax/inference.py: 100.00%
106 statements
from typing import List, Optional, Tuple
import torch
from pathlib import Path
from anytree import PreOrderIter
from rich.progress import track

from . import nodes
from .dotexporter import ThresholdDotExporter


class ShapeError(RuntimeError):
    """
    Raised when the shape of a tensor is different from what is expected.
    """

def node_probabilities(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, progress_bar:bool=False) -> torch.Tensor:
    """
    Takes the prediction scores for a number of samples and converts them to probabilities for the nodes in the tree.
    """
    probabilities = torch.zeros(size=prediction_tensor.shape, device=prediction_tensor.device)

    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The prediction tensor given to {__name__} has a final dimension of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node, which expects prediction tensors to have a final dimension of {root.layer_size}."
        )

    for node in track(PreOrderIter(root)) if progress_bar else PreOrderIter(root):
        if node.is_leaf:
            continue
        elif node == root:
            my_probability = 1.0
        elif node.index_in_softmax_layer is not None:
            my_probability = probabilities[:, node.index_in_softmax_layer]
            my_probability = my_probability[:, None]

        if len(node.children) == 1:
            # If this node has just one child, then skip it
            continue

        softmax_probabilities = torch.softmax(
            prediction_tensor[:, node.softmax_start_index:node.softmax_end_index],
            dim=1,
        )

        probabilities[:, node.softmax_start_index:node.softmax_end_index] = softmax_probabilities * my_probability

    return probabilities
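
# Example usage (a minimal sketch, not part of the library; assumes `root` is a SoftmaxNode
# tree on which `set_indexes` has been called, and `logits` is a float tensor of raw scores
# with shape (samples, root.layer_size), e.g. the output of a model):
#
#     probs = node_probabilities(logits, root=root)   # same shape as `logits`
#     # Each group of sibling columns sums to the probability of their parent node.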

def leaf_probabilities(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode) -> torch.Tensor:
    """
    Takes the prediction scores for a number of samples and converts them to probabilities for the leaf nodes of the tree.
    """
    probabilities = node_probabilities(prediction_tensor, root=root)
    return torch.index_select(probabilities, 1, root.leaf_indexes.to(probabilities.device))
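
# Example usage (a sketch; `logits` and `root` as above):
#
#     leaf_probs = leaf_probabilities(logits, root=root)   # shape (samples, number of leaves)
#     best = leaf_probs.argmax(dim=1)                       # per-sample index into the leaf ordering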

def greedy_predictions(
    prediction_tensor:torch.Tensor,
    root:nodes.SoftmaxNode,
    max_depth:Optional[int]=None,
    threshold:Optional[float]=None,
    progress_bar:bool=False,
) -> List[nodes.SoftmaxNode]:
    """
    Takes the prediction scores for a number of samples and converts them to a list of predicted nodes in the tree.

    Predictions use the `greedy` method, which chooses the child with the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): The output from the softmax layer. Shape (samples, root.layer_size).
            Works with raw scores or probabilities.
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set, then predictions are only given up to this number of levels from the root.
        threshold (float, optional): If set, then predictions are only given where the value at the node is greater than this threshold.
            Designed for use with probabilities.
        progress_bar (bool, optional): Whether to show a progress bar. Defaults to False.

    Returns:
        List[nodes.SoftmaxNode]: A list of nodes predicted for each sample.
    """
    prediction_nodes = []

    if isinstance(prediction_tensor, tuple) and len(prediction_tensor) == 1:
        prediction_tensor = prediction_tensor[0]

    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The prediction tensor given to {__name__} has a final dimension of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node, which expects prediction tensors to have a final dimension of {root.layer_size}."
        )

    for predictions in track(prediction_tensor) if progress_bar else prediction_tensor:
        node = root
        depth = 1
        while node.children:
            if len(node.children) == 1:
                # If this node has just one child, then we don't check the prediction
                prediction_child_index = 0
            else:
                # This would be better with torch.argmax, but that doesn't work with MPS in the production version of PyTorch.
                # See https://github.com/pytorch/pytorch/issues/98191
                # and https://github.com/pytorch/pytorch/pull/104374
                prediction_child_index = torch.max(predictions[node.softmax_start_index:node.softmax_end_index], dim=0).indices

            # Stop if the prediction is below the threshold
            if threshold and predictions[node.softmax_start_index + prediction_child_index] < threshold:
                break

            node = node.children[prediction_child_index]

            # Stop if we have reached the maximum depth
            if max_depth and depth >= max_depth:
                break

            depth += 1

        prediction_nodes.append(node)

    return prediction_nodes
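
# Example usage (a sketch; `logits` and `root` as above; the threshold value is illustrative):
#
#     predicted = greedy_predictions(logits, root=root, max_depth=2)
#     print([str(node) for node in predicted])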

def greedy_lineage_probabilities(
    prediction_tensor:torch.Tensor,
    root:nodes.SoftmaxNode,
    max_depth:Optional[int]=None,
    threshold:Optional[float]=None,
    progress_bar:bool=False,
) -> List[List[Tuple[nodes.SoftmaxNode, float]]]:
    """
    Takes the prediction scores for a number of samples and converts them to a predicted lineage of nodes in the tree, with their probabilities.

    Predictions use the `greedy` method, which chooses the child with the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): The output from the softmax layer. Shape (samples, root.layer_size).
            Works with raw scores or probabilities.
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set, then predictions are only given up to this number of levels from the root.
        threshold (float, optional): If set, then predictions stop where the cumulative probability at the node falls below this threshold.
        progress_bar (bool, optional): Whether to show a progress bar. Defaults to False.

    Returns:
        List[List[Tuple[nodes.SoftmaxNode, float]]]: For each sample, a list of (node, probability) pairs along the predicted lineage.
    """
    prediction_lineages = []

    if isinstance(prediction_tensor, tuple) and len(prediction_tensor) == 1:
        prediction_tensor = prediction_tensor[0]

    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The prediction tensor given to {__name__} has a final dimension of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node, which expects prediction tensors to have a final dimension of {root.layer_size}."
        )

    for predictions in track(prediction_tensor) if progress_bar else prediction_tensor:
        node = root
        depth = 1
        probability = 1.0  # Start with the root node having a probability of 1.0
        my_lineage = []
        while node.children:
            if len(node.children) == 1:
                # If this node has just one child, then we don't check the prediction
                prediction_child_index = 0
            else:
                # This would be better with torch.argmax, but that doesn't work with MPS in the production version of PyTorch.
                # See https://github.com/pytorch/pytorch/issues/98191
                # and https://github.com/pytorch/pytorch/pull/104374
                child_probabilities = torch.softmax(
                    predictions[node.softmax_start_index:node.softmax_end_index],
                    dim=0,
                )
                prediction_child_index = torch.max(child_probabilities, dim=0).indices
                probability *= child_probabilities[prediction_child_index].item()

            # Stop if the probability is below the threshold
            if threshold and probability < threshold:
                break

            node = node.children[prediction_child_index]

            my_lineage.append(
                (node, probability)
            )

            # Stop if we have reached the maximum depth
            if max_depth and depth >= max_depth:
                break

            depth += 1

        prediction_lineages.append(my_lineage)

    return prediction_lineages
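
# Example usage (a sketch; `logits` and `root` as above; the threshold value is illustrative):
#
#     lineages = greedy_lineage_probabilities(logits, root=root, threshold=0.5)
#     for node, probability in lineages[0]:   # lineage for the first sample
#         print(node, probability)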

def greedy_prediction_node_ids(
    prediction_tensor:torch.Tensor,
    root:nodes.SoftmaxNode,
    max_depth:Optional[int]=None,
    threshold:Optional[float]=None,
    progress_bar:bool=False,
) -> List[int]:
    """
    Takes the prediction scores for a number of samples and converts them to a list of predicted node IDs in the tree.

    Predictions use the `greedy` method, which chooses the child with the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): The predictions coming from the softmax layer. Shape (samples, root.layer_size).
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set, then predictions are only given up to this number of levels from the root.
        threshold (float, optional): If set, then predictions are only given where the value at the node is greater than this threshold.
        progress_bar (bool, optional): Whether to show a progress bar. Defaults to False.

    Returns:
        List[int]: A list of node IDs predicted for each sample.
    """
    prediction_nodes = greedy_predictions(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
        threshold=threshold,
        progress_bar=progress_bar,
    )
    return root.get_node_ids(prediction_nodes)


def greedy_prediction_node_ids_tensor(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:Optional[int]=None) -> torch.Tensor:
    """ Like `greedy_prediction_node_ids` but returns the node IDs as a tensor. """
    node_ids = greedy_prediction_node_ids(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    return torch.as_tensor(node_ids, dtype=torch.long)
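
# Example usage (a sketch; `logits` and `root` as above):
#
#     node_ids = greedy_prediction_node_ids_tensor(logits, root=root)   # shape (samples,)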

def render_probabilities(
    root:nodes.SoftmaxNode,
    filepaths:Optional[List[Path]]=None,
    prediction_color="red",
    non_prediction_color="gray",
    prediction_tensor:Optional[torch.Tensor]=None,
    probabilities:Optional[torch.Tensor]=None,
    predictions:Optional[List[nodes.SoftmaxNode]]=None,
    horizontal:bool=True,
    threshold:float=0.005,
) -> List[ThresholdDotExporter]:
    """
    Renders the probabilities of each node in the tree as a graphviz graph.

    See https://anytree.readthedocs.io/en/latest/_modules/anytree/exporter/dotexporter.html for more information.

    Args:
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        prediction_tensor (torch.Tensor): The output activations from the softmax layer. Shape (samples, root.layer_size).
        filepaths (List[Path], optional): Paths to locations where the files can be saved.
            Can have the extension .dot or another format which can be interpreted by GraphViz, such as .png or .svg.
            Defaults to None, so that files are not saved.
        prediction_color (str, optional): The color for the greedy prediction nodes and edges. Defaults to "red".
        non_prediction_color (str, optional): The color for the edges which weren't predicted. Defaults to "gray".

    Returns:
        List[ThresholdDotExporter]: The list of rendered graphs.
    """
    if probabilities is None:
        assert prediction_tensor is not None, "Either `prediction_tensor` or `probabilities` must be given."
        probabilities = node_probabilities(prediction_tensor, root=root)

    if predictions is None:
        assert prediction_tensor is not None, "Either `prediction_tensor` or `predictions` must be given."
        predictions = greedy_predictions(prediction_tensor, root=root)

    graphs = []
    for my_probabilities, my_prediction in zip(probabilities, predictions):
        greedy_nodes = my_prediction.ancestors + (my_prediction,)
        graphs.append(ThresholdDotExporter(
            root,
            probabilities=my_probabilities,
            greedy_nodes=greedy_nodes,
            horizontal=horizontal,
            prediction_color=prediction_color,
            non_prediction_color=non_prediction_color,
            threshold=threshold,
        ))

    if filepaths:
        for graph, filepath in zip(graphs, filepaths):
            filepath = Path(filepath)
            filepath.parent.mkdir(exist_ok=True, parents=True)
            if filepath.suffix == ".dot":
                graph.to_dotfile(str(filepath))
            else:
                graph.to_picture(str(filepath))

    return graphs
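
# Example usage (a sketch; `logits` and `root` as above; the output filenames are hypothetical):
#
#     render_probabilities(
#         root=root,
#         prediction_tensor=logits,
#         filepaths=["sample_0.svg", "sample_1.svg"],
#     )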