Coverage for hierarchicalsoftmax/metrics.py : 100.00%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from sklearn.metrics import f1_score, precision_score, recall_score
2import torch
3from . import inference, nodes
def target_max_depth(target_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:int):
    """
    Truncates each target node to at most `max_depth` levels below the root.

    For every target index, the node is looked up in `root.node_list` and replaced by the
    element at position `max_depth` of its `path` (or the last element of `path` when the
    path is shorter). Presumably `path` runs from the root down to the node itself, so this
    selects the node's ancestor at depth `max_depth` — TODO confirm against `nodes.SoftmaxNode`.

    Args:
        target_tensor (torch.Tensor): Target node indexes into `root.node_list`.
        root (SoftmaxNode): The root of the hierarchy tree.
        max_depth (int): Depth to truncate to. Any falsy value (None or 0) leaves the
            targets untouched.

    Returns:
        torch.Tensor: `target_tensor` itself when `max_depth` is falsy, otherwise the
        tensor of node ids produced by `root.get_node_ids_tensor` for the truncated nodes.
    """
    # Falsy check (not `is None`): a depth of 0 is treated the same as "no limit".
    if not max_depth:
        return target_tensor

    truncated_nodes = []
    for target in target_tensor:
        ancestry = root.node_list[target].path
        # Slicing then taking [-1] also copes with paths shorter than max_depth + 1.
        truncated_nodes.append(ancestry[:max_depth + 1][-1])

    return root.get_node_ids_tensor(truncated_nodes)
def greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=None):
    """
    Gives the accuracy of predicting the target in a hierarchy tree.

    Predictions use the `greedy` method, which picks the highest prediction score at each
    level of the tree.

    Args:
        prediction_tensor (torch.Tensor): Raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): Target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        max_depth (int, optional): If given, it is forwarded to the greedy inference, and the
            targets are truncated to this depth with `target_max_depth` before comparison.

    Returns:
        torch.Tensor: A 0-dim float tensor holding the fraction of samples whose predicted
        node id equals the target id.
    """
    predicted_ids = inference.greedy_prediction_node_ids_tensor(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
    )
    truncated_targets = target_max_depth(target_tensor, root, max_depth)

    # Move predictions to the targets' device so the elementwise comparison is valid.
    is_correct = predicted_ids.to(truncated_targets.device) == truncated_targets
    return is_correct.float().mean()
def greedy_accuracy_depth_one(prediction_tensor, target_tensor, root):
    """
    Greedy accuracy with both predictions and targets limited to depth 1 of the tree.

    Thin wrapper around `greedy_accuracy` with `max_depth=1`, with a fixed signature
    suitable for use as a metric callback.
    """
    return greedy_accuracy(
        prediction_tensor=prediction_tensor,
        target_tensor=target_tensor,
        root=root,
        max_depth=1,
    )
def greedy_accuracy_depth_two(prediction_tensor, target_tensor, root):
    """
    Greedy accuracy with both predictions and targets limited to depth 2 of the tree.

    Thin wrapper around `greedy_accuracy` with `max_depth=2`, with a fixed signature
    suitable for use as a metric callback.
    """
    return greedy_accuracy(
        prediction_tensor=prediction_tensor,
        target_tensor=target_tensor,
        root=root,
        max_depth=2,
    )
def greedy_f1_score(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the f1 score of predicting the target, i.e. the harmonic mean of precision and recall.

    Predictions use the `greedy` method, which picks the highest prediction score at each
    level of the tree.

    Args:
        prediction_tensor (torch.Tensor): Raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): Target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes.
            Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html for details.
            Defaults to "macro".
        max_depth (int, optional): If given, it is forwarded to the greedy inference, and the
            targets are truncated to this depth with `target_max_depth`.

    Returns:
        float: The f1 score. With `average=None`, sklearn returns a per-class array instead.
    """
    predicted_ids = inference.greedy_prediction_node_ids_tensor(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
    )
    truncated_targets = target_max_depth(target_tensor, root, max_depth)

    # sklearn works on host memory, so both tensors are brought to the CPU first.
    y_true = truncated_targets.cpu()
    y_pred = predicted_ids.cpu()
    return f1_score(y_true, y_pred, average=average)
def greedy_precision(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the precision score of predicting the target.

    Predictions use the `greedy` method, which picks the highest prediction score at each
    level of the tree.

    Args:
        prediction_tensor (torch.Tensor): Raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): Target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes.
            Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html for details.
            Defaults to "macro".
        max_depth (int, optional): If given, it is forwarded to the greedy inference, and the
            targets are truncated to this depth with `target_max_depth`.

    Returns:
        float: The precision. With `average=None`, sklearn returns a per-class array instead.
    """
    predicted_ids = inference.greedy_prediction_node_ids_tensor(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
    )
    truncated_targets = target_max_depth(target_tensor, root, max_depth)

    # sklearn works on host memory, so both tensors are brought to the CPU first.
    y_true = truncated_targets.cpu()
    y_pred = predicted_ids.cpu()
    return precision_score(y_true, y_pred, average=average)
def greedy_recall(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the recall score of predicting the target.

    Predictions use the `greedy` method, which picks the highest prediction score at each
    level of the tree.

    Args:
        prediction_tensor (torch.Tensor): Raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): Target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes.
            Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html for details.
            Defaults to "macro".
        max_depth (int, optional): If given, it is forwarded to the greedy inference, and the
            targets are truncated to this depth with `target_max_depth`.

    Returns:
        float: The recall. With `average=None`, sklearn returns a per-class array instead.
    """
    predicted_ids = inference.greedy_prediction_node_ids_tensor(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
    )
    truncated_targets = target_max_depth(target_tensor, root, max_depth)

    # sklearn works on host memory, so both tensors are brought to the CPU first.
    y_true = truncated_targets.cpu()
    y_pred = predicted_ids.cpu()
    return recall_score(y_true, y_pred, average=average)
def greedy_accuracy_parent(prediction_tensor, target_tensor, root, max_depth=None):
    """
    Gives the accuracy of predicting the parent of the target in a hierarchy tree.

    Predictions use the `greedy` method, which picks the highest prediction score at each
    level of the tree. A sample counts as correct when the parent of the predicted node
    equals the parent of the target node.

    Args:
        prediction_tensor (torch.Tensor): Raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): Target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        max_depth (int, optional): If given, it is forwarded to the greedy inference, and the
            targets are truncated to this depth with `target_max_depth` before their parents
            are taken.

    Returns:
        torch.Tensor: A 0-dim float tensor holding the fraction of samples whose predicted
        parent matches the target's parent.
    """
    predicted_nodes = inference.greedy_predictions(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
    )
    predicted_parent_ids = root.get_node_ids_tensor([node.parent for node in predicted_nodes])

    truncated_targets = target_max_depth(target_tensor, root, max_depth)
    target_parent_ids = root.get_node_ids_tensor(
        [root.node_list[index].parent for index in truncated_targets]
    )

    # Align devices before the elementwise comparison.
    matches = predicted_parent_ids.to(target_parent_ids.device) == target_parent_ids
    return matches.float().mean()
class GreedyAccuracy:
    """
    Callable greedy-accuracy metric bound to a fixed tree root and optional depth limit.

    Calling an instance with `(predictions, targets)` returns `greedy_accuracy(...)` for the
    stored `root` and `max_depth`. The instance exposes a `__name__` so frameworks that label
    metrics by function name (e.g. FastAI) can display it.
    """
    # Class-level default; every instance overwrites it in __init__.
    name:str = "greedy"

    def __init__(self, root:nodes.SoftmaxNode, name="greedy_accuracy", max_depth=None):
        self.root = root
        self.name = name
        self.max_depth = max_depth

    @property
    def __name__(self):
        """ For using as a FastAI metric. """
        return self.name

    def __call__(self, predictions, targets):
        return greedy_accuracy(
            prediction_tensor=predictions,
            target_tensor=targets,
            root=self.root,
            max_depth=self.max_depth,
        )