Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from sklearn.metrics import f1_score, precision_score, recall_score 

2import torch 

3from . import inference, nodes 

4 

5 

def target_max_depth(target_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:int):
    """
    Converts the target tensor to the max depth of the tree.

    Each target node is replaced by the node at position ``max_depth`` of its ``path``
    (or by the last node of the path if the path is shorter than that).

    Args:
        target_tensor (torch.Tensor): A tensor with the target node indexes into ``root.node_list``. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        max_depth (int): The depth to truncate the targets to. If falsy (``None`` or ``0``),
            ``target_tensor`` is returned unchanged.

    Returns:
        The unchanged ``target_tensor`` when ``max_depth`` is falsy, otherwise the result of
        ``root.get_node_ids_tensor`` for the truncated target nodes.
    """
    if not max_depth:
        return target_tensor

    # Convert to Python ints once. Indexing a list with individual tensor elements costs one
    # tensor->int conversion per sample, which forces a device sync each time on GPU.
    target_ids = target_tensor.tolist() if isinstance(target_tensor, torch.Tensor) else target_tensor

    # NOTE(review): assumes `path` lists the ancestors starting from the root, so the slice keeps
    # the node at depth `max_depth`, or the target itself if it is shallower — confirm in nodes.py.
    max_depth_target_nodes = [root.node_list[target].path[:max_depth+1][-1] for target in target_ids]
    return root.get_node_ids_tensor(max_depth_target_nodes)

13 

14 

def greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=None):
    """
    Gives the accuracy of predicting the target in a hierarchy tree.

    Predictions use the `greedy` method: at each level of the tree the child with the
    greatest prediction score is chosen.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        max_depth (int, optional): Passed to the greedy inference, and used to truncate the
            targets via ``target_max_depth``. Defaults to None (no truncation).

    Returns:
        torch.Tensor: A scalar float tensor holding the accuracy (the number that are correct
        divided by the total number of samples).
    """
    predicted_ids = inference.greedy_prediction_node_ids_tensor(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
    )
    truncated_targets = target_max_depth(target_tensor, root, max_depth)

    # Put the predictions on the targets' device so the element-wise comparison is valid.
    is_correct = predicted_ids.to(truncated_targets.device) == truncated_targets
    return is_correct.float().mean()

33 

34 

def greedy_accuracy_depth_one(prediction_tensor, target_tensor, root):
    """
    Greedy accuracy evaluated at depth 1 of the tree.

    Shorthand for ``greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=1)``.
    """
    return greedy_accuracy(
        prediction_tensor=prediction_tensor,
        target_tensor=target_tensor,
        root=root,
        max_depth=1,
    )

37 

38 

def greedy_accuracy_depth_two(prediction_tensor, target_tensor, root):
    """
    Greedy accuracy evaluated at depth 2 of the tree.

    Shorthand for ``greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=2)``.
    """
    return greedy_accuracy(
        prediction_tensor=prediction_tensor,
        target_tensor=target_tensor,
        root=root,
        max_depth=2,
    )

41 

42 

def greedy_f1_score(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the f1 score of predicting the target, i.e. the harmonic mean of the precision and recall.

    Predictions use the `greedy` method: at each level of the tree the child with the
    greatest prediction score is chosen.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes, forwarded to
            ``sklearn.metrics.f1_score``. Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html for details.
            Defaults to "macro".
        max_depth (int, optional): Passed to the greedy inference, and used to truncate the
            targets via ``target_max_depth``. Defaults to None (no truncation).

    Returns:
        float: The f1 score. With ``average=None`` scikit-learn returns an array of per-class scores instead.
    """
    predicted_ids = inference.greedy_prediction_node_ids_tensor(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
    )
    truncated_targets = target_max_depth(target_tensor, root, max_depth)

    # scikit-learn consumes CPU data, so move both label tensors off any accelerator.
    y_true, y_pred = truncated_targets.cpu(), predicted_ids.cpu()
    return f1_score(y_true, y_pred, average=average)

65 

66 

def greedy_precision(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the precision score of predicting the target.

    Predictions use the `greedy` method: at each level of the tree the child with the
    greatest prediction score is chosen.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes, forwarded to
            ``sklearn.metrics.precision_score``. Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html for details.
            Defaults to "macro".
        max_depth (int, optional): Passed to the greedy inference, and used to truncate the
            targets via ``target_max_depth``. Defaults to None (no truncation).

    Returns:
        float: The precision. With ``average=None`` scikit-learn returns an array of per-class scores instead.
    """
    predicted_ids = inference.greedy_prediction_node_ids_tensor(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
    )
    truncated_targets = target_max_depth(target_tensor, root, max_depth)

    # scikit-learn consumes CPU data, so move both label tensors off any accelerator.
    y_true, y_pred = truncated_targets.cpu(), predicted_ids.cpu()
    return precision_score(y_true, y_pred, average=average)

89 

90 

def greedy_recall(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the recall score of predicting the target.

    Predictions use the `greedy` method: at each level of the tree the child with the
    greatest prediction score is chosen.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes, forwarded to
            ``sklearn.metrics.recall_score``. Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html for details.
            Defaults to "macro".
        max_depth (int, optional): Passed to the greedy inference, and used to truncate the
            targets via ``target_max_depth``. Defaults to None (no truncation).

    Returns:
        float: The recall. With ``average=None`` scikit-learn returns an array of per-class scores instead.
    """
    predicted_ids = inference.greedy_prediction_node_ids_tensor(
        prediction_tensor=prediction_tensor,
        root=root,
        max_depth=max_depth,
    )
    truncated_targets = target_max_depth(target_tensor, root, max_depth)

    # scikit-learn consumes CPU data, so move both label tensors off any accelerator.
    y_true, y_pred = truncated_targets.cpu(), predicted_ids.cpu()
    return recall_score(y_true, y_pred, average=average)

113 

114 

def greedy_accuracy_parent(prediction_tensor, target_tensor, root, max_depth=None):
    """
    Gives the accuracy of predicting the parent of the target in a hierarchy tree.

    Predictions use the `greedy` method: at each level of the tree the child with the
    greatest prediction score is chosen. A sample counts as correct when the parent of the
    predicted node equals the parent of the target node.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        max_depth (int, optional): Passed to the greedy inference, and used to truncate the
            targets via ``target_max_depth`` before taking parents. Defaults to None (no truncation).

    Returns:
        torch.Tensor: A scalar float tensor holding the accuracy (the number that are correct
        divided by the total number of samples).
    """
    prediction_nodes = inference.greedy_predictions(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    prediction_parent_ids = root.get_node_ids_tensor([node.parent for node in prediction_nodes])

    target_tensor = target_max_depth(target_tensor, root, max_depth)
    # Convert to Python ints once rather than indexing `node_list` with one tensor element
    # per sample (a separate tensor->int conversion, and device sync on GPU, each time).
    target_ids = target_tensor.tolist() if isinstance(target_tensor, torch.Tensor) else target_tensor
    target_parent_ids = root.get_node_ids_tensor([root.node_list[target].parent for target in target_ids])

    # Put the predictions on the targets' device so the element-wise comparison is valid.
    return (prediction_parent_ids.to(target_parent_ids.device) == target_parent_ids).float().mean()

138 

139 

class GreedyAccuracy:
    """
    Callable greedy-accuracy metric bound to a hierarchy tree, usable as a FastAI metric.

    Calling an instance with ``(predictions, targets)`` evaluates
    ``greedy_accuracy(predictions, targets, root, max_depth=max_depth)``.
    """
    # Class-level default; every instance replaces it in ``__init__``.
    name:str = "greedy"

    def __init__(self, root:nodes.SoftmaxNode, name="greedy_accuracy", max_depth=None):
        """
        Args:
            root (SoftmaxNode): The root of the hierarchy tree.
            name (str, optional): The metric name exposed through ``__name__``. Defaults to "greedy_accuracy".
            max_depth (int, optional): Forwarded to ``greedy_accuracy``. Defaults to None.
        """
        self.root = root
        self.name = name
        self.max_depth = max_depth

    @property
    def __name__(self):
        """ Metric name, exposed as ``__name__`` for using this object as a FastAI metric. """
        return self.name

    def __call__(self, predictions, targets):
        """Return the greedy accuracy of ``predictions`` against ``targets``."""
        return greedy_accuracy(
            prediction_tensor=predictions,
            target_tensor=targets,
            root=self.root,
            max_depth=self.max_depth,
        )

155