Coverage for polytorch/metrics.py : 100.00%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import torch
2import torch.nn.functional as F
4from .util import squeeze_prediction
def get_predictions_target_for_index(predictions, *targets, data_index=None, feature_axis=-1):
    """Select the prediction/target pair for one output head.

    Args:
        predictions: A single tensor or a tuple/list of tensors (one per head).
            A bare tensor is treated as a one-element tuple.
        *targets: Target tensors, one per prediction.
        data_index: Which prediction/target pair to return. Required.
        feature_axis: Unused here; kept so all metrics share one signature.

    Returns:
        The ``(prediction, target)`` pair at ``data_index``.
    """
    if isinstance(predictions, (tuple, list)):
        prediction_group = predictions
    else:
        prediction_group = (predictions,)

    assert len(prediction_group) == len(targets)
    assert data_index is not None, "data_index must be specified"

    return prediction_group[data_index], targets[data_index]
def function_metric(predictions, *targets, data_index=None, feature_axis=-1, function=None) -> torch.Tensor:
    """Apply ``function(prediction, target)`` to the pair at ``data_index``.

    The prediction is squeezed along ``feature_axis`` first so its shape
    matches the target's before ``function`` is called.
    """
    prediction, target = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    prediction = squeeze_prediction(prediction, target, feature_axis)
    return function(prediction, target)
def categorical_accuracy(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """Fraction of elements whose argmax class along ``feature_axis`` matches the target."""
    prediction, target = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    predicted_classes = prediction.argmax(dim=feature_axis)
    return (predicted_classes == target).float().mean()
def binary_accuracy(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """Accuracy for a binary head: a logit >= 0 counts as a positive prediction."""
    prediction, target = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    thresholded = prediction >= 0.0
    thresholded = squeeze_prediction(thresholded, target, feature_axis)
    return (thresholded == target).float().mean()
def mse(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """Mean squared error for the prediction/target pair at ``data_index``."""
    return function_metric(
        predictions,
        *targets,
        data_index=data_index,
        feature_axis=feature_axis,
        function=F.mse_loss,
    )
def l1(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """Mean absolute (L1) error for the prediction/target pair at ``data_index``."""
    return function_metric(
        predictions,
        *targets,
        data_index=data_index,
        feature_axis=feature_axis,
        function=F.l1_loss,
    )
def smooth_l1(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """Smooth L1 (Huber-like) error for the prediction/target pair at ``data_index``."""
    return function_metric(
        predictions,
        *targets,
        data_index=data_index,
        feature_axis=feature_axis,
        function=F.smooth_l1_loss,
    )
def calc_dice_score(predictions, target, smooth: float = 1.) -> torch.Tensor:
    """Sørensen–Dice coefficient between two (binary or soft) masks.

    Args:
        predictions: Predicted mask of any shape; flattened internally.
        target: Ground-truth mask, elementwise-compatible with ``predictions``.
        smooth: Additive smoothing term that avoids division by zero when
            both masks are empty.

    Returns:
        Scalar tensor ``(2*|P∩T| + smooth) / (|P| + |T| + smooth)``.
    """
    # reshape(-1) (rather than view(-1)) also accepts non-contiguous inputs,
    # e.g. transposed or sliced tensors, where view() raises a RuntimeError.
    predictions = predictions.reshape(-1)
    target = target.reshape(-1)
    intersection = (predictions * target).sum()
    return ((2. * intersection + smooth) /
            (predictions.sum() + target.sum() + smooth)
            )
def calc_generalized_dice_score(predictions, target, power: float = 2.0, smooth: float = 1.0, feature_axis: int = -1) -> torch.Tensor:
    """
    A generalized Dice score for multi-class segmentation.

    If power=0.0, this is equivalent to normal Dice score (i.e. volume "implicit" weighting)
    If power=1.0, this is equivalent to 'equal' weighting.
    If power=2.0, this is equivalent to 'inverse volume' weighting.

    See:
        - https://www.sciencedirect.com/science/article/pii/S2590005619300049#bib73
        - https://arxiv.org/pdf/1707.03237.pdf
        - https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=1717643

    Args:
        predictions: Per-class probabilities with class channels along ``feature_axis``.
        target: Integer class labels (same shape as ``predictions`` minus the class axis).
        power: Exponent on each class volume used for the inverse weighting.
        smooth: Additive term that keeps the per-class weight finite for empty classes.
        feature_axis: Axis of ``predictions`` holding the class channels.

    Returns:
        Scalar tensor with the weighted Dice score over all classes.
    """
    n_classes = predictions.shape[feature_axis]
    numerator = 0.0
    denominator = 0.0
    for i in range(n_classes):
        my_target = (target == i)
        my_target_sum = my_target.sum()
        # Class weight ~ 1 / volume**power; smoothing avoids division by zero
        # for classes absent from the target.
        weight = 1 / (my_target_sum ** power + smooth)

        # Tensor.select(dim, i) takes channel i along feature_axis (removing
        # that dim) — equivalent to integer indexing, but avoids indexing with
        # a Python list of slices, which is deprecated/removed non-tuple
        # sequence indexing in modern torch/numpy.
        my_predictions = predictions.select(feature_axis, i)

        numerator += weight * (my_predictions * my_target).sum()
        denominator += weight * (my_predictions.sum() + my_target_sum)

    return 2. * numerator / denominator
def calc_iou(predictions, target, smooth: float = 1.):
    """Intersection-over-union (Jaccard index) between two masks.

    Args:
        predictions: Predicted mask of any shape; flattened internally.
        target: Ground-truth mask, elementwise-compatible with ``predictions``.
        smooth: Additive smoothing term that avoids division by zero when
            both masks are empty.

    Returns:
        Scalar tensor ``(|P∩T| + smooth) / (|P∪T| + smooth)``.
    """
    # reshape(-1) (rather than view(-1)) also accepts non-contiguous inputs,
    # e.g. transposed or sliced tensors, where view() raises a RuntimeError.
    predictions = predictions.reshape(-1)
    target = target.reshape(-1)
    intersection = (predictions * target).sum()
    union = predictions.sum() + target.sum() - intersection
    return (intersection + smooth) / (union + smooth)
def binary_dice(predictions, *targets, data_index=None, feature_axis=-1):
    """Dice score for a binary head: logits are thresholded at zero first."""
    prediction, target = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    thresholded = prediction >= 0.0
    thresholded = squeeze_prediction(thresholded, target, feature_axis)
    return calc_dice_score(thresholded, target)
def binary_iou(predictions, *targets, data_index=None, feature_axis=-1):
    """IoU for a binary head: logits are thresholded at zero first."""
    prediction, target = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    thresholded = prediction >= 0.0
    thresholded = squeeze_prediction(thresholded, target, feature_axis)
    return calc_iou(thresholded, target)
def generalized_dice(predictions, *targets, data_index=None, feature_axis=-1, perform_softmax: bool = True, power: float = 2.0) -> torch.Tensor:
    """
    Calculate the generalized Dice score for a single data index (multi-class segmentation).

    See:
        - https://www.sciencedirect.com/science/article/pii/S2590005619300049#bib73
        - https://arxiv.org/pdf/1707.03237.pdf
        - https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=1717643

    Args:
        predictions: Logits or probabilities. If probabilities, perform_softmax must be False.
        data_index: The index of the prediction/target pair to score. Defaults to None.
        feature_axis (int, optional): Axis holding the class channels. Defaults to -1.
        perform_softmax (bool, optional): Whether or not to normalize the predictions using
            the softmax function. Defaults to True.
        power (float, optional): The power to use for the generalized dice score.
            Defaults to 2.0 (i.e. inverse volume weighting).
            If power=0.0, this is equivalent to normal Dice score (i.e. volume "implicit" weighting)
            If power=1.0, this is equivalent to 'equal' weighting.
            If power=2.0, this is equivalent to 'inverse volume' weighting.

    Returns:
        torch.Tensor: The generalized dice score for the given data index.
    """
    prediction, target = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    if perform_softmax:
        prediction = prediction.softmax(dim=feature_axis)
    return calc_generalized_dice_score(
        prediction,
        target,
        feature_axis=feature_axis,
        power=power,
    )