Coverage for hierarchicalsoftmax/metrics.py: 100.00%
160 statements
coverage.py v7.8.0, created at 2025-07-02 01:49 +0000

from sklearn.metrics import f1_score, precision_score, recall_score
import torch
from typing import Callable
from collections.abc import Sequence
from torch.nn import Module
from torch import Tensor
from torchmetrics.metric import Metric, apply_to_collection

from . import inference, nodes
from .inference import ShapeError


def target_max_depth(target_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:int):
    """ Converts each target in the tensor to its ancestor node at `max_depth` of the tree (unchanged if `max_depth` is None or 0). """
    if max_depth:
        max_depth_target_nodes = [root.node_list[target].path[:max_depth+1][-1] for target in target_tensor]
        target_tensor = root.get_node_ids_tensor(max_depth_target_nodes)

    return target_tensor
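
# For example (an illustrative sketch, not part of the module, assuming the SoftmaxNode
# tree-building API shown in the package README): with a chain root -> a -> aa, a target
# of `aa` is mapped to its ancestor `a` when max_depth=1.
#
#     root = nodes.SoftmaxNode("root")
#     a = nodes.SoftmaxNode("a", parent=root)
#     aa = nodes.SoftmaxNode("aa", parent=a)
#     root.set_indexes()
#     target_max_depth(root.get_node_ids_tensor([aa]), root, max_depth=1)   # -> node ids of [a]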


def greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=None):
    """
    Gives the accuracy of predicting the target in a hierarchy tree.

    Predictions use the `greedy` method, which chooses the child with the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        max_depth (int, optional): If given, predictions and targets are truncated to this depth of the tree. Defaults to None.

    Returns:
        float: The accuracy value (i.e. the number that are correct divided by the total number of samples)
    """
    prediction_node_ids = inference.greedy_prediction_node_ids_tensor(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    target_tensor = target_max_depth(target_tensor, root, max_depth)

    return (prediction_node_ids.to(target_tensor.device) == target_tensor).float().mean()
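
# Example usage (an illustrative sketch, not part of the module, assuming the SoftmaxNode
# tree-building API shown in the package README; the tree and tensors below are hypothetical):
#
#     root = nodes.SoftmaxNode("root")
#     a = nodes.SoftmaxNode("a", parent=root)
#     aa = nodes.SoftmaxNode("aa", parent=a)
#     ab = nodes.SoftmaxNode("ab", parent=a)
#     b = nodes.SoftmaxNode("b", parent=root)
#     ba = nodes.SoftmaxNode("ba", parent=b)
#     bb = nodes.SoftmaxNode("bb", parent=b)
#     root.set_indexes()
#
#     targets = root.get_node_ids_tensor([aa, ba])       # shape (samples,)
#     predictions = torch.randn(2, root.layer_size)      # shape (samples, root.layer_size)
#     greedy_accuracy(predictions, targets, root)        # scalar tensor in [0, 1]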


def greedy_accuracy_depth_one(prediction_tensor, target_tensor, root):
    return greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=1)


def greedy_accuracy_depth_two(prediction_tensor, target_tensor, root):
    return greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=2)


def depth_accurate(prediction_tensor, target_tensor, root:nodes.SoftmaxNode, max_depth:int=0, threshold:float|None=None):
    """ Returns a tensor of shape (samples,) with the depth to which each prediction was accurate. """
    depths = []

    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if isinstance(prediction_tensor, tuple) and len(prediction_tensor) == 1:
        prediction_tensor = prediction_tensor[0]

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The predictions tensor given to {__name__} has final dimensions of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node which expects prediction tensors to have a final dimension of {root.layer_size}."
        )

    for predictions, target in zip(prediction_tensor, target_tensor):
        node = root
        depth = 0
        target_node = root.node_list[target]
        target_path = target_node.path
        target_path_length = len(target_path)

        while node.children:
            # This would be better if we could use torch.argmax but it doesn't work with MPS in the production version of pytorch
            # See https://github.com/pytorch/pytorch/issues/98191
            # https://github.com/pytorch/pytorch/pull/104374
            if len(node.children) == 1:
                # If this node has just one child, then we don't check the prediction
                prediction_child_index = 0
            else:
                prediction_child_index = torch.max(predictions[node.softmax_start_index:node.softmax_end_index], dim=0).indices

            node = node.children[prediction_child_index]
            depth += 1

            if depth < target_path_length and node != target_path[depth]:
                depth -= 1
                break

            # Stop if we have reached the maximum depth
            if max_depth and depth >= max_depth:
                break

        depths.append(depth)

    return torch.tensor(depths, dtype=int)
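
# Example (an illustrative sketch; `root`, `targets` and `predictions` are assumed to be
# built as in the greedy_accuracy example above): with that two-level tree, each returned
# value is 2 if the greedy prediction reaches the correct leaf, 1 if only the first level
# is correct, and 0 if it diverges at the root.
#
#     depth_accurate(predictions, targets, root)   # tensor of shape (samples,) with values in {0, 1, 2}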


def greedy_f1_score(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the F1 score of predicting the target, i.e. the harmonic mean of the precision and recall.

    Predictions use the `greedy` method, which chooses the child with the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes.
            Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html for details.
            Defaults to "macro".
        max_depth (int, optional): If given, predictions and targets are truncated to this depth of the tree. Defaults to None.

    Returns:
        float: The F1 score
    """
    prediction_node_ids = inference.greedy_prediction_node_ids_tensor(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    target_tensor = target_max_depth(target_tensor, root, max_depth)

    return f1_score(target_tensor.cpu(), prediction_node_ids.cpu(), average=average)


def greedy_precision(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the precision score of predicting the target.

    Predictions use the `greedy` method, which chooses the child with the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes.
            Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html for details.
            Defaults to "macro".
        max_depth (int, optional): If given, predictions and targets are truncated to this depth of the tree. Defaults to None.

    Returns:
        float: The precision
    """
    prediction_node_ids = inference.greedy_prediction_node_ids_tensor(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    target_tensor = target_max_depth(target_tensor, root, max_depth)

    return precision_score(target_tensor.cpu(), prediction_node_ids.cpu(), average=average)


def greedy_recall(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the recall score of predicting the target.

    Predictions use the `greedy` method, which chooses the child with the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes.
            Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html for details.
            Defaults to "macro".
        max_depth (int, optional): If given, predictions and targets are truncated to this depth of the tree. Defaults to None.

    Returns:
        float: The recall
    """
    prediction_node_ids = inference.greedy_prediction_node_ids_tensor(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    target_tensor = target_max_depth(target_tensor, root, max_depth)

    return recall_score(target_tensor.cpu(), prediction_node_ids.cpu(), average=average)
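
# Example (an illustrative sketch; `root`, `targets` and `predictions` are assumed to be
# built as in the greedy_accuracy example above): the three sklearn-backed scores share
# the same signature, so they can be computed from the same tensors.
#
#     greedy_f1_score(predictions, targets, root, average="macro")
#     greedy_precision(predictions, targets, root, average="macro")
#     greedy_recall(predictions, targets, root, average="macro")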


def greedy_accuracy_parent(prediction_tensor, target_tensor, root, max_depth=None):
    """
    Gives the accuracy of predicting the parent of the target in a hierarchy tree.

    Predictions use the `greedy` method, which chooses the child with the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        max_depth (int, optional): If given, predictions and targets are truncated to this depth of the tree before taking parents. Defaults to None.

    Returns:
        float: The accuracy value (i.e. the number that are correct divided by the total number of samples)
    """
    prediction_nodes = inference.greedy_predictions(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    prediction_parents = [node.parent for node in prediction_nodes]
    prediction_parent_ids = root.get_node_ids_tensor(prediction_parents)

    target_tensor = target_max_depth(target_tensor, root, max_depth)
    target_parents = [root.node_list[target].parent for target in target_tensor]
    target_parent_ids = root.get_node_ids_tensor(target_parents)

    return (prediction_parent_ids.to(target_parent_ids.device) == target_parent_ids).float().mean()
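
# Example (an illustrative sketch; same hypothetical tree and tensors as the
# greedy_accuracy example above): a greedy prediction of leaf `ab` for a target of `aa`
# counts as correct here, because both share the parent `a`.
#
#     greedy_accuracy_parent(predictions, targets, root)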


class GreedyAccuracy():
    name:str = "greedy"

    def __init__(self, root:nodes.SoftmaxNode, name="greedy_accuracy", max_depth=None):
        self.max_depth = max_depth
        self.name = name
        self.root = root

    @property
    def __name__(self):
        """ For using as a FastAI metric. """
        return self.name

    def __call__(self, predictions, targets):
        return greedy_accuracy(predictions, targets, self.root, max_depth=self.max_depth)
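
# Example (an illustrative sketch; `root`, `predictions` and `targets` as in the
# greedy_accuracy example above): the class is a plain callable, so it can be passed
# wherever a `metric(predictions, targets)` function is expected, e.g. as a fastai metric.
#
#     metric = GreedyAccuracy(root, max_depth=1)
#     metric(predictions, targets)   # greedy accuracy restricted to depth 1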


class HierarchicalSoftmaxTorchMetric(Metric):
    def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
        """Overwrite `_apply` function such that we can also move metric states to the correct device.

        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
        are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.

        Overriding because there is a device issue in the parent class.

        Args:
            fn: the function to apply
            exclude_state: list of state variables to exclude from applying the function, that then needs to be handled
                by the metric class itself.
        """
        this = super(Metric, self)._apply(fn)
        fs = str(fn)
        cond = any(f in fs for f in ["Module.type", "Module.half", "Module.float", "Module.double", "Module.bfloat16"])
        if not self._dtype_convert and cond:
            return this

        # Also apply fn to metric states and defaults
        for key, value in this._defaults.items():
            if key in exclude_state:
                continue

            if isinstance(value, Tensor):
                this._defaults[key] = fn(value)
            elif isinstance(value, Sequence):
                this._defaults[key] = [fn(v) for v in value]

            current_val = getattr(this, key)
            if isinstance(current_val, Tensor):
                setattr(this, key, fn(current_val))
            elif isinstance(current_val, Sequence):
                setattr(this, key, [fn(cur_v) for cur_v in current_val])
            else:
                raise TypeError(
                    f"Expected metric state to be either a Tensor or a list of Tensor, but encountered {current_val}"
                )

        # Additional apply to forward cache and computed attributes (may be nested)
        if this._computed is not None:
            this._computed = apply_to_collection(this._computed, Tensor, fn)
        if this._forward_cache is not None:
            this._forward_cache = apply_to_collection(this._forward_cache, Tensor, fn)

        return this


class GreedyAccuracyTorchMetric(HierarchicalSoftmaxTorchMetric):
    def __init__(self, root:nodes.SoftmaxNode, name:str="", max_depth=None):
        super().__init__()
        self.root = root
        self.max_depth = max_depth
        self.name = name or (f"greedy_accuracy_{max_depth}" if max_depth else "greedy_accuracy")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, predictions, targets):
        self.total += targets.size(0)
        self.correct += int(greedy_accuracy(predictions, targets, self.root, max_depth=self.max_depth) * targets.size(0))

    def compute(self):
        return self.correct / self.total
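
# Example (an illustrative sketch using the standard torchmetrics update/compute cycle;
# `root` is as in the greedy_accuracy example above and `batches` is a hypothetical
# iterable of (predictions, targets) pairs):
#
#     metric = GreedyAccuracyTorchMetric(root)
#     for predictions, targets in batches:
#         metric.update(predictions, targets)
#     metric.compute()   # correct / total accumulated over all batches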


class RankAccuracyTorchMetric(HierarchicalSoftmaxTorchMetric):
    def __init__(self, root, ranks: dict[int, str], name: str = "rank_accuracy"):
        super().__init__()
        self.root = root
        self.ranks = ranks
        self.name = name

        # Use `add_state` for metrics to handle distributed reduction and device placement
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        for rank_name in ranks.values():
            self.add_state(rank_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, predictions, targets):
        if isinstance(predictions, tuple) and len(predictions) == 1:
            predictions = predictions[0]

        # Ensure tensors match the device
        predictions = predictions.to(self.device)
        targets = targets.to(self.device)

        self.total += targets.size(0)
        depth_accurate_tensor = depth_accurate(predictions, targets, self.root)

        for depth, rank_name in self.ranks.items():
            accurate_at_depth = (depth_accurate_tensor >= depth).sum()
            setattr(self, rank_name, getattr(self, rank_name) + accurate_at_depth)

    def compute(self):
        # Compute final metric values
        return {
            rank_name: getattr(self, rank_name) / self.total
            for rank_name in self.ranks.values()
        }
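
# Example (an illustrative sketch; `root`, `predictions` and `targets` as in the
# greedy_accuracy example above, and the rank names below are hypothetical): `ranks`
# maps a depth in the tree to the name under which accuracy at that depth is reported.
#
#     metric = RankAccuracyTorchMetric(root, ranks={1: "depth_1_accuracy", 2: "depth_2_accuracy"})
#     metric.update(predictions, targets)
#     metric.compute()   # {"depth_1_accuracy": ..., "depth_2_accuracy": ...}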


class LeafAccuracyTorchMetric(HierarchicalSoftmaxTorchMetric):
    def __init__(self, root:nodes.SoftmaxNode, name:str="", max_depth=None):
        super().__init__()
        self.root = root
        self.max_depth = max_depth
        self.name = name or (f"leaf_accuracy_{max_depth}" if max_depth else "leaf_accuracy")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.node_indexes = torch.as_tensor([node.best_index_in_softmax_layer() for node in self.root.node_list])
        self.leaf_indexes = torch.as_tensor(self.root.leaf_indexes)

    def update(self, predictions, targets):
        self.total += targets.size(0)

        # Make sure the tensors are on the same device
        self.node_indexes = self.node_indexes.to(predictions.device)
        self.leaf_indexes = self.leaf_indexes.to(predictions.device)

        target_indices = torch.index_select(self.node_indexes.to(targets.device), 0, targets)

        # Get the indices of the maximum values along the last dimension
        probabilities = inference.leaf_probabilities(prediction_tensor=predictions, root=self.root)
        _, max_indices = torch.max(probabilities, dim=1)
        predicted_leaf_indices = torch.index_select(self.root.leaf_indexes.to(targets.device), 0, max_indices)

        self.correct += (predicted_leaf_indices == target_indices).sum()

    def compute(self):
        return self.correct / self.total
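
# Example (an illustrative sketch; `root`, `predictions` and `targets` as in the
# greedy_accuracy example above, with targets given as leaf node ids): this metric
# scores only the final leaf prediction, taking the most probable leaf from
# inference.leaf_probabilities rather than following the greedy path level by level.
#
#     metric = LeafAccuracyTorchMetric(root)
#     metric.update(predictions, targets)
#     metric.compute()   # fraction of samples whose most probable leaf matches the target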