Coverage for hierarchicalsoftmax/inference.py : 100.00%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from typing import List, Optional
2import torch
3from pathlib import Path
4from anytree import PreOrderIter
6from . import nodes
7from .dotexporter import ThresholdDotExporter
class ShapeError(RuntimeError):
    """
    Signals that a tensor does not have the shape that was expected.

    For example, it is raised when the final dimension of a prediction tensor
    does not match the ``layer_size`` of the root softmax node.
    """
def node_probabilities(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode) -> torch.Tensor:
    """
    Converts the raw output scores of the softmax layer into a probability for every node in the tree.

    For every non-leaf node, a softmax is taken over the slice of scores that belongs to its children,
    and the result is multiplied by the probability of that node (the root has a probability of 1).
    The tree is walked in pre-order, so a node's own probability has always been written before it
    is used to scale the probabilities of its children.

    Args:
        prediction_tensor (torch.Tensor): The output activations from the softmax layer.
            Shape (..., root.layer_size), typically (samples, root.layer_size).
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.

    Returns:
        torch.Tensor: A tensor with the same shape as `prediction_tensor` where the entry at a node's
            `index_in_softmax_layer` is that node's probability.

    Raises:
        nodes.IndexNotSetError: If `set_indexes` has not been called on `root`.
        ShapeError: If the final dimension of `prediction_tensor` is not `root.layer_size`.
    """
    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The predictions tensor given to node_probabilities has final dimensions of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node which expects prediction tensors to have a final dimension of {root.layer_size}."
        )

    probabilities = torch.zeros_like(prediction_tensor)

    for node in PreOrderIter(root):
        if node.is_leaf:
            continue

        if node is root:
            node_probability = 1.0
        else:
            # This node's probability was written by its parent earlier in the pre-order walk.
            # The trailing axis lets it broadcast across this node's children.
            node_probability = probabilities[..., node.index_in_softmax_layer].unsqueeze(-1)

        start, end = node.softmax_start_index, node.softmax_end_index
        child_probabilities = torch.softmax(prediction_tensor[..., start:end], dim=-1)
        probabilities[..., start:end] = child_probabilities * node_probability

    return probabilities
def leaf_probabilities(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode) -> torch.Tensor:
    """
    Returns the probabilities of the leaf nodes of the tree.

    The probability of every node is computed with `node_probabilities`, and then only the
    entries at `root.leaf_indexes_in_softmax_layer` are kept.

    Args:
        prediction_tensor (torch.Tensor): The output activations from the softmax layer.
            Shape (samples, root.layer_size).
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.

    Returns:
        torch.Tensor: The leaf probabilities, selected along the final dimension in the order
            given by `root.leaf_indexes_in_softmax_layer`.

    Raises:
        Any error raised by `node_probabilities` (e.g. when indexes are not set or the shape is wrong).
    """
    probabilities = node_probabilities(prediction_tensor, root=root)
    # Select along the last axis (the softmax layer axis) rather than a hard-coded dim=1,
    # so leading batch dimensions are not assumed. For 2D input this is identical to dim=1.
    return torch.index_select(probabilities, -1, root.leaf_indexes_in_softmax_layer)
def greedy_predictions(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:Optional[int]=None, threshold:Optional[float]=None) -> List[nodes.SoftmaxNode]:
    """
    Takes the prediction scores for a number of samples and converts it to a list of predictions of nodes in the tree.

    Predictions use the `greedy` method which means that it chooses the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): The output from the softmax layer.
            Shape (samples, root.layer_size)
            Works with raw scores or probabilities.
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set, then it only gives predictions at a maximum of this number of levels from the root.
            A value of 0 yields the root for every sample.
        threshold (float, optional): If set, then the descent stops at the current node when the score of the
            best child is below this threshold. Designed for use with probabilities.

    Returns:
        List[nodes.SoftmaxNode]: A list of nodes predicted for each sample.

    Raises:
        nodes.IndexNotSetError: If `set_indexes` has not been called on `root`.
        ShapeError: If the final dimension of `prediction_tensor` is not `root.layer_size`.
    """
    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The predictions tensor given to greedy_predictions has final dimensions of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node which expects prediction tensors to have a final dimension of {root.layer_size}."
        )

    prediction_nodes = []
    for predictions in prediction_tensor:
        node = root
        depth = 0  # number of levels descended below the root so far
        while node.children:
            # `is not None` so that max_depth=0 and threshold=0.0 are honoured rather than ignored.
            if max_depth is not None and depth >= max_depth:
                break

            start, end = node.softmax_start_index, node.softmax_end_index

            # This would be better if we could use torch.argmax but it doesn't work with MPS in the production version of pytorch
            # See https://github.com/pytorch/pytorch/issues/98191
            # https://github.com/pytorch/pytorch/pull/104374
            child_index = int(torch.max(predictions[start:end], dim=0).indices)

            # Stop at the current node if the best child's score is below the threshold
            if threshold is not None and predictions[start + child_index] < threshold:
                break

            node = node.children[child_index]
            depth += 1

        prediction_nodes.append(node)

    return prediction_nodes
def greedy_prediction_node_ids(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:Optional[int]=None, threshold:Optional[float]=None) -> List[int]:
    """
    Takes the prediction scores for a number of samples and converts it to a list of predictions of nodes in the tree.

    Predictions use the `greedy` method which means that it chooses the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): The predictions coming from the softmax layer. Shape (samples, root.layer_size)
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set, then it only gives predictions at a maximum of this number of levels from the root.
        threshold (float, optional): If set, the descent stops at a node when the score of its best child is
            below this threshold. Passed through to `greedy_predictions`. Defaults to None (no threshold).

    Returns:
        List[int]: A list of node IDs predicted for each sample, as returned by `root.get_node_ids`.
    """
    prediction_nodes = greedy_predictions(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
        threshold=threshold,
    )
    return root.get_node_ids(prediction_nodes)
def greedy_prediction_node_ids_tensor(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:Optional[int]=None) -> torch.Tensor:
    """
    Computes the greedy node-ID predictions and returns them as an integer tensor.

    Args:
        prediction_tensor (torch.Tensor): The predictions coming from the softmax layer. Shape (samples, root.layer_size)
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        max_depth (int, optional): If set, then it only gives predictions at a maximum of this number of levels from the root.

    Returns:
        torch.Tensor: The IDs from `greedy_prediction_node_ids` converted to a torch.int64 tensor
            (the dtype that Python's `int` maps to in torch).
    """
    ids = greedy_prediction_node_ids(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
    )
    return torch.as_tensor(ids, dtype=torch.int64)
def render_probabilities(
    root:nodes.SoftmaxNode,
    filepaths:Optional[List[Path]]=None,
    prediction_color="red",
    non_prediction_color="gray",
    prediction_tensor:Optional[torch.Tensor]=None,
    probabilities:Optional[torch.Tensor]=None,
    predictions:Optional[List[nodes.SoftmaxNode]]=None,
    horizontal:bool=True,
    threshold:float=0.005,
) -> List[ThresholdDotExporter]:
    """
    Renders the probabilities of each node in the tree as a graphviz graph.

    See https://anytree.readthedocs.io/en/latest/_modules/anytree/exporter/dotexporter.html for more information.

    Args:
        root (SoftmaxNode): The root softmax node. Needs `set_indexes` to have been called.
        filepaths (List[Path], optional): Paths to locations where the files can be saved.
            Can have extension .dot or another format which can be interpreted by GraphViz such as .png or .svg.
            Defaults to None so that files are not saved.
        prediction_color (str, optional): The color for the greedy prediction nodes and edges. Defaults to "red".
        non_prediction_color (str, optional): The color for the edges which weren't predicted. Defaults to "gray".
        prediction_tensor (torch.Tensor, optional): The output activations from the softmax layer. Shape (samples, root.layer_size).
            Required unless `probabilities` is given.
        probabilities (torch.Tensor, optional): Precomputed node probabilities as returned by `node_probabilities`.
            If not given, they are computed from `prediction_tensor`.
        predictions (List[SoftmaxNode], optional): The predicted node for each sample. If not given, greedy predictions
            are computed from `prediction_tensor`, or from `probabilities` when `prediction_tensor` is not given.
        horizontal (bool, optional): Passed to `ThresholdDotExporter`. Defaults to True.
        threshold (float, optional): Passed to `ThresholdDotExporter`. Defaults to 0.005.

    Returns:
        List[ThresholdDotExporter]: One graph per (probabilities, prediction) pair.

    Raises:
        AssertionError: If neither `prediction_tensor` nor `probabilities` is given.
    """
    if probabilities is None:
        assert prediction_tensor is not None, "Either `prediction_tensor` or `probabilities` must be given."
        probabilities = node_probabilities(prediction_tensor, root=root)

    if predictions is None:
        # greedy_predictions works with raw scores or probabilities. Within each softmax group the
        # probabilities are the softmax scaled by the parent's probability, so the greedy choice
        # matches the raw scores (barring ties). This lets callers pass only `probabilities`.
        scores = prediction_tensor if prediction_tensor is not None else probabilities
        predictions = greedy_predictions(scores, root=root)

    graphs = []
    for my_probabilities, my_prediction in zip(probabilities, predictions):
        # The predicted node together with its ancestors is the greedy path to highlight.
        greedy_nodes = my_prediction.ancestors + (my_prediction,)
        graphs.append(ThresholdDotExporter(
            root,
            probabilities=my_probabilities,
            greedy_nodes=greedy_nodes,
            horizontal=horizontal,
            prediction_color=prediction_color,
            non_prediction_color=non_prediction_color,
            threshold=threshold,
        ))

    if filepaths:
        for graph, filepath in zip(graphs, filepaths):
            filepath = Path(filepath)
            filepath.parent.mkdir(exist_ok=True, parents=True)
            # .dot files are written as text; any other suffix is rendered by GraphViz.
            if filepath.suffix == ".dot":
                graph.to_dotfile(str(filepath))
            else:
                graph.to_picture(str(filepath))

    return graphs