Coverage for hierarchicalsoftmax/metrics.py: 100.00%

160 statements  


from sklearn.metrics import f1_score, precision_score, recall_score
import torch
from typing import Callable
from collections.abc import Sequence
from torch.nn import Module
from torch import Tensor
from torchmetrics.metric import Metric, apply_to_collection

from . import inference, nodes
from .inference import ShapeError


def target_max_depth(target_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:int):
    """ Truncates each target node ID so that it refers to the target's ancestor at `max_depth` in the tree. """
    if max_depth:
        max_depth_target_nodes = [root.node_list[target].path[:max_depth+1][-1] for target in target_tensor]
        target_tensor = root.get_node_ids_tensor(max_depth_target_nodes)

    return target_tensor

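# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The helper below builds a
# small hypothetical tree so that the later sketches can show how these metrics
# are called. `SoftmaxNode(..., parent=...)` and `set_indexes()` are assumptions
# inferred from the error messages and attributes referenced in this file;
# treat this as an illustration, not a definitive description of the library API.
# ---------------------------------------------------------------------------
def _example_tree():
    """Builds a hypothetical two-level tree: root -> {a -> {aa, ab}, b -> {ba, bb}}."""
    root = nodes.SoftmaxNode("root")
    a = nodes.SoftmaxNode("a", parent=root)
    aa = nodes.SoftmaxNode("aa", parent=a)
    ab = nodes.SoftmaxNode("ab", parent=a)
    b = nodes.SoftmaxNode("b", parent=root)
    ba = nodes.SoftmaxNode("ba", parent=b)
    bb = nodes.SoftmaxNode("bb", parent=b)
    root.set_indexes()
    return root, (aa, ab, ba, bb)


def _example_target_max_depth():
    """With max_depth=1, a leaf target such as `aa` is truncated to its parent `a`."""
    root, (aa, _, ba, _) = _example_tree()
    targets = root.get_node_ids_tensor([aa, ba])          # leaf node ids, shape: (2,)
    return target_max_depth(targets, root, max_depth=1)   # ids of `a` and `b` instead
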

def greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=None):
    """
    Gives the accuracy of predicting the target in a hierarchy tree.

    Predictions use the `greedy` method which means that it chooses the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.

    Returns:
        float: The accuracy value (i.e. the number that are correct divided by the total number of samples)
    """
    prediction_node_ids = inference.greedy_prediction_node_ids_tensor(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    target_tensor = target_max_depth(target_tensor, root, max_depth)

    return (prediction_node_ids.to(target_tensor.device) == target_tensor).float().mean()

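# Usage sketch (illustrative, not part of the original module): scores random
# predictions against leaf targets for the hypothetical tree built by
# `_example_tree` above.
def _example_greedy_accuracy():
    root, (aa, _, ba, _) = _example_tree()
    targets = root.get_node_ids_tensor([aa, ba])      # shape: (samples,)
    predictions = torch.randn(2, root.layer_size)     # raw scores, shape: (samples, root.layer_size)
    return greedy_accuracy(predictions, targets, root)  # 0-dim tensor in [0, 1]
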

def greedy_accuracy_depth_one(prediction_tensor, target_tensor, root):
    return greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=1)


def greedy_accuracy_depth_two(prediction_tensor, target_tensor, root):
    return greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=2)


def depth_accurate(prediction_tensor, target_tensor, root:nodes.SoftmaxNode, max_depth:int=0, threshold:float|None=None):
    """ Returns a tensor of shape (samples,) with the depth to which each prediction is accurate. """
    depths = []

    if root.softmax_start_index is None:
        raise nodes.IndexNotSetError(f"The index of the root node {root} has not been set. Call `set_indexes` on this object.")

    if isinstance(prediction_tensor, tuple) and len(prediction_tensor) == 1:
        prediction_tensor = prediction_tensor[0]

    if prediction_tensor.shape[-1] != root.layer_size:
        raise ShapeError(
            f"The predictions tensor given to {__name__} has a final dimension of {prediction_tensor.shape[-1]}. "
            f"That is not compatible with the root node which expects prediction tensors to have a final dimension of {root.layer_size}."
        )

    for predictions, target in zip(prediction_tensor, target_tensor):
        node = root
        depth = 0
        target_node = root.node_list[target]
        target_path = target_node.path
        target_path_length = len(target_path)

        while node.children:
            # This would be better if we could use torch.argmax but it doesn't work with MPS in the production version of pytorch
            # See https://github.com/pytorch/pytorch/issues/98191
            # https://github.com/pytorch/pytorch/pull/104374
            if len(node.children) == 1:
                # If this node has just one child, then we don't check the prediction
                prediction_child_index = 0
            else:
                prediction_child_index = torch.max(predictions[node.softmax_start_index:node.softmax_end_index], dim=0).indices

            node = node.children[prediction_child_index]
            depth += 1

            if depth < target_path_length and node != target_path[depth]:
                depth -= 1
                break

            # Stop if we have reached the maximum depth
            if max_depth and depth >= max_depth:
                break

        depths.append(depth)

    return torch.tensor(depths, dtype=int)

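# Usage sketch (illustrative, not part of the original module): for each sample,
# `depth_accurate` reports how deep the greedy prediction stays on the target's
# path, e.g. 2 if the correct leaf is reached, 1 if only the first-level ancestor
# is correct, and 0 if the prediction diverges at the first level.
def _example_depth_accurate():
    root, (aa, _, ba, _) = _example_tree()
    targets = root.get_node_ids_tensor([aa, ba])
    predictions = torch.randn(2, root.layer_size)
    return depth_accurate(predictions, targets, root)  # e.g. tensor([2, 1])
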

def greedy_f1_score(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the F1 score of predicting the target, i.e. the harmonic mean of the precision and recall.

    Predictions use the `greedy` method which means that it chooses the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes.
            Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html for details.
            Defaults to "macro".

    Returns:
        float: The F1 score
    """
    prediction_node_ids = inference.greedy_prediction_node_ids_tensor(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    target_tensor = target_max_depth(target_tensor, root, max_depth)

    return f1_score(target_tensor.cpu(), prediction_node_ids.cpu(), average=average)


def greedy_precision(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the precision score of predicting the target.

    Predictions use the `greedy` method which means that it chooses the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes.
            Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html for details.
            Defaults to "macro".

    Returns:
        float: The precision
    """
    prediction_node_ids = inference.greedy_prediction_node_ids_tensor(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    target_tensor = target_max_depth(target_tensor, root, max_depth)

    return precision_score(target_tensor.cpu(), prediction_node_ids.cpu(), average=average)


def greedy_recall(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor, root:nodes.SoftmaxNode, average:str="macro", max_depth=None) -> float:
    """
    Gives the recall score of predicting the target.

    Predictions use the `greedy` method which means that it chooses the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.
        average (str, optional): The type of averaging over the different classes.
            Options are: 'micro', 'macro', 'samples', 'weighted', 'binary' or None.
            See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html for details.
            Defaults to "macro".

    Returns:
        float: The recall
    """
    prediction_node_ids = inference.greedy_prediction_node_ids_tensor(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    target_tensor = target_max_depth(target_tensor, root, max_depth)

    return recall_score(target_tensor.cpu(), prediction_node_ids.cpu(), average=average)

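# Usage sketch (illustrative, not part of the original module): the three
# scikit-learn based scores share the same calling convention, so one example
# covers `greedy_f1_score`, `greedy_precision` and `greedy_recall`.
def _example_greedy_scores():
    root, (aa, ab, ba, _) = _example_tree()
    targets = root.get_node_ids_tensor([aa, ab, ba])
    predictions = torch.randn(3, root.layer_size)
    return (
        greedy_f1_score(predictions, targets, root, average="micro"),
        greedy_precision(predictions, targets, root, average="micro"),
        greedy_recall(predictions, targets, root, average="micro"),
    )
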

def greedy_accuracy_parent(prediction_tensor, target_tensor, root, max_depth=None):
    """
    Gives the accuracy of predicting the parent of the target in a hierarchy tree.

    Predictions use the `greedy` method which means that it chooses the greatest prediction score at each level of the tree.

    Args:
        prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
        target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
        root (SoftmaxNode): The root of the hierarchy tree.

    Returns:
        float: The accuracy value (i.e. the number that are correct divided by the total number of samples)
    """
    prediction_nodes = inference.greedy_predictions(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
    prediction_parents = [node.parent for node in prediction_nodes]
    prediction_parent_ids = root.get_node_ids_tensor(prediction_parents)

    target_tensor = target_max_depth(target_tensor, root, max_depth)
    target_parents = [root.node_list[target].parent for target in target_tensor]
    target_parent_ids = root.get_node_ids_tensor(target_parents)

    return (prediction_parent_ids.to(target_parent_ids.device) == target_parent_ids).float().mean()


class GreedyAccuracy:
    name:str = "greedy"

    def __init__(self, root:nodes.SoftmaxNode, name="greedy_accuracy", max_depth=None):
        self.max_depth = max_depth
        self.name = name
        self.root = root

    @property
    def __name__(self):
        """ For using as a FastAI metric. """
        return self.name

    def __call__(self, predictions, targets):
        return greedy_accuracy(predictions, targets, self.root, max_depth=self.max_depth)

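# Usage sketch (illustrative, not part of the original module): `GreedyAccuracy`
# wraps `greedy_accuracy` in a callable with a `__name__` property so it can be
# passed to frameworks (such as fastai) that expect named metric callables.
def _example_greedy_accuracy_callable():
    root, (aa, _, ba, _) = _example_tree()
    metric = GreedyAccuracy(root, name="greedy_accuracy_depth_1", max_depth=1)
    targets = root.get_node_ids_tensor([aa, ba])
    predictions = torch.randn(2, root.layer_size)
    return metric.__name__, metric(predictions, targets)
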

class HierarchicalSoftmaxTorchMetric(Metric):
    def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
        """Overwrite `_apply` function such that we can also move metric states to the correct device.

        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
        are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.

        Overriding because there is a device issue in the parent class.

        Args:
            fn: the function to apply
            exclude_state: list of state variables to exclude from applying the function, that then needs to be handled
                by the metric class itself.
        """
        this = super(Metric, self)._apply(fn)
        fs = str(fn)
        cond = any(f in fs for f in ["Module.type", "Module.half", "Module.float", "Module.double", "Module.bfloat16"])
        if not self._dtype_convert and cond:
            return this

        # Also apply fn to metric states and defaults
        for key, value in this._defaults.items():
            if key in exclude_state:
                continue

            if isinstance(value, Tensor):
                this._defaults[key] = fn(value)
            elif isinstance(value, Sequence):
                this._defaults[key] = [fn(v) for v in value]

            current_val = getattr(this, key)
            if isinstance(current_val, Tensor):
                setattr(this, key, fn(current_val))
            elif isinstance(current_val, Sequence):
                setattr(this, key, [fn(cur_v) for cur_v in current_val])
            else:
                raise TypeError(
                    f"Expected metric state to be either a Tensor or a list of Tensor, but encountered {current_val}"
                )

        # Additional apply to forward cache and computed attributes (may be nested)
        if this._computed is not None:
            this._computed = apply_to_collection(this._computed, Tensor, fn)
        if this._forward_cache is not None:
            this._forward_cache = apply_to_collection(this._forward_cache, Tensor, fn)

        return this


class GreedyAccuracyTorchMetric(HierarchicalSoftmaxTorchMetric):
    def __init__(self, root:nodes.SoftmaxNode, name:str="", max_depth=None):
        super().__init__()
        self.root = root
        self.max_depth = max_depth
        self.name = name or (f"greedy_accuracy_{max_depth}" if max_depth else "greedy_accuracy")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, predictions, targets):
        self.total += targets.size(0)
        self.correct += int(greedy_accuracy(predictions, targets, self.root, max_depth=self.max_depth) * targets.size(0))

    def compute(self):
        return self.correct / self.total

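# Usage sketch (illustrative, not part of the original module): the torchmetrics
# subclasses accumulate state across batches with `update` and reduce it with
# `compute`, so they can be used directly in a training loop or with Lightning.
def _example_greedy_accuracy_torchmetric():
    root, (aa, ab, ba, bb) = _example_tree()
    metric = GreedyAccuracyTorchMetric(root)
    for batch_targets in ([aa, ab], [ba, bb]):
        targets = root.get_node_ids_tensor(batch_targets)
        predictions = torch.randn(len(batch_targets), root.layer_size)
        metric.update(predictions, targets)
    return metric.compute()  # correct / total over both batches
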

class RankAccuracyTorchMetric(HierarchicalSoftmaxTorchMetric):
    def __init__(self, root, ranks: dict[int, str], name: str = "rank_accuracy"):
        super().__init__()
        self.root = root
        self.ranks = ranks
        self.name = name

        # Use `add_state` for metrics to handle distributed reduction and device placement
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        for rank_name in ranks.values():
            self.add_state(rank_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, predictions, targets):
        if isinstance(predictions, tuple) and len(predictions) == 1:
            predictions = predictions[0]

        # Ensure tensors match the device
        predictions = predictions.to(self.device)
        targets = targets.to(self.device)

        self.total += targets.size(0)
        depth_accurate_tensor = depth_accurate(predictions, targets, self.root)

        for depth, rank_name in self.ranks.items():
            accurate_at_depth = (depth_accurate_tensor >= depth).sum()
            setattr(self, rank_name, getattr(self, rank_name) + accurate_at_depth)

    def compute(self):
        # Compute final metric values
        return {
            rank_name: getattr(self, rank_name) / self.total
            for rank_name in self.ranks.values()
        }

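# Usage sketch (illustrative, not part of the original module): `ranks` maps a
# depth in the tree to the name of the state/result for that depth, so with the
# hypothetical two-level tree the metric reports accuracy at depths 1 and 2.
def _example_rank_accuracy_torchmetric():
    root, (aa, _, ba, _) = _example_tree()
    metric = RankAccuracyTorchMetric(root, ranks={1: "level_1_accuracy", 2: "level_2_accuracy"})
    targets = root.get_node_ids_tensor([aa, ba])
    predictions = torch.randn(2, root.layer_size)
    metric.update(predictions, targets)
    return metric.compute()  # e.g. {"level_1_accuracy": ..., "level_2_accuracy": ...}
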

class LeafAccuracyTorchMetric(HierarchicalSoftmaxTorchMetric):
    def __init__(self, root:nodes.SoftmaxNode, name:str="", max_depth=None):
        super().__init__()
        self.root = root
        self.max_depth = max_depth
        self.name = name or (f"leaf_accuracy_{max_depth}" if max_depth else "leaf_accuracy")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.node_indexes = torch.as_tensor([node.best_index_in_softmax_layer() for node in self.root.node_list])
        self.leaf_indexes = torch.as_tensor(self.root.leaf_indexes)

    def update(self, predictions, targets):
        self.total += targets.size(0)

        # Make sure the cached index tensors are on the same device as the predictions
        self.node_indexes = self.node_indexes.to(predictions.device)
        self.leaf_indexes = self.leaf_indexes.to(predictions.device)

        target_indices = torch.index_select(self.node_indexes.to(targets.device), 0, targets)

        # Get the indices of the maximum leaf probabilities along the last dimension
        probabilities = inference.leaf_probabilities(prediction_tensor=predictions, root=self.root)
        _, max_indices = torch.max(probabilities, dim=1)
        predicted_leaf_indices = torch.index_select(self.leaf_indexes.to(targets.device), 0, max_indices)

        self.correct += (predicted_leaf_indices == target_indices).sum()

    def compute(self):
        return self.correct / self.total

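# Usage sketch (illustrative, not part of the original module): leaf accuracy
# compares the most probable leaf (via `inference.leaf_probabilities`) against
# each target's position in the softmax layer.
def _example_leaf_accuracy_torchmetric():
    root, (aa, _, ba, _) = _example_tree()
    metric = LeafAccuracyTorchMetric(root)
    targets = root.get_node_ids_tensor([aa, ba])
    predictions = torch.randn(2, root.layer_size)
    metric.update(predictions, targets)
    return metric.compute()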