Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import torch 

2import torch.nn.functional as F 

3 

4from .util import squeeze_prediction 

5 

6 

def get_predictions_target_for_index(predictions, *targets, data_index=None, feature_axis=-1):
    """
    Select the prediction/target pair at ``data_index``.

    ``predictions`` may be a single object, which is treated as a one-element tuple, or a
    tuple/list holding one entry per target.

    Args:
        predictions: A single prediction or a tuple/list of predictions.
        *targets: One target per prediction.
        data_index: Index of the pair to return. Required.
        feature_axis: Not used by this function. It is accepted because every metric in this
            module forwards it here.

    Returns:
        tuple: ``(predictions[data_index], targets[data_index])``.

    Raises:
        AssertionError: If the number of predictions and targets differ, or if ``data_index``
            is None. These are raised explicitly rather than via ``assert`` so the checks still
            run under ``python -O``. AssertionError is kept for backward compatibility.
    """
    if not isinstance(predictions, (tuple, list)):
        predictions = (predictions,)

    if len(predictions) != len(targets):
        raise AssertionError(
            f"got {len(predictions)} prediction(s) but {len(targets)} target(s)"
        )
    if data_index is None:
        raise AssertionError("data_index must be specified")

    return predictions[data_index], targets[data_index]

18 

19 

def function_metric(predictions, *targets, data_index=None, feature_axis=-1, function=None) -> torch.Tensor:
    """
    Apply ``function`` to the prediction/target pair selected by ``data_index``.

    The prediction is first passed through ``squeeze_prediction`` (together with the target
    and ``feature_axis``). The result is then evaluated as ``function(prediction, target)``.

    Args:
        predictions: A single prediction or a tuple/list of predictions.
        *targets: One target per prediction.
        data_index: Index of the prediction/target pair to evaluate. Required.
        feature_axis: Forwarded to the selection and squeeze helpers.
        function: Callable taking ``(prediction, target)``, e.g. ``F.mse_loss``. Required,
            despite the ``None`` default kept for interface compatibility.

    Returns:
        torch.Tensor: Whatever ``function`` returns.

    Raises:
        TypeError: If ``function`` is not callable (e.g. left as ``None``).
    """
    # Fail early with a clear message instead of "'NoneType' object is not callable" at the end.
    if not callable(function):
        raise TypeError(f"function_metric requires a callable `function`, got {function!r}")

    my_predictions, my_targets = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    my_predictions = squeeze_prediction(my_predictions, my_targets, feature_axis)
    return function(my_predictions, my_targets)

24 

25 

def categorical_accuracy(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """
    Fraction of positions whose arg-max class along ``feature_axis`` equals the target.

    Args:
        predictions: A single prediction tensor or a tuple/list of them.
        *targets: One target per prediction, compared directly against the arg-max indices.
        data_index: Which prediction/target pair to score. Required.
        feature_axis: Axis holding the per-class scores. Defaults to -1.

    Returns:
        torch.Tensor: Scalar mean of the element-wise matches, as a float tensor.
    """
    selected_predictions, selected_targets = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    predicted_classes = selected_predictions.argmax(dim=feature_axis)
    matches = predicted_classes == selected_targets
    return matches.float().mean()

32 

33 

def binary_accuracy(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """
    Fraction of positions where the thresholded prediction equals the target.

    A prediction counts as positive when it is ``>= 0.0``, so inputs are presumably logits
    (TODO confirm against callers). The boolean predictions are passed through
    ``squeeze_prediction`` before comparison.

    Args:
        predictions: A single prediction tensor or a tuple/list of them.
        *targets: One target per prediction.
        data_index: Which prediction/target pair to score. Required.
        feature_axis: Forwarded to the selection and squeeze helpers.

    Returns:
        torch.Tensor: Scalar mean of the element-wise matches, as a float tensor.
    """
    selected_predictions, selected_targets = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    is_positive = squeeze_prediction(selected_predictions >= 0.0, selected_targets, feature_axis)
    hits = (is_positive == selected_targets).float()
    return hits.mean()

41 

42 

def mse(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """
    Evaluate ``F.mse_loss`` (default arguments) on the pair at ``data_index``.

    Selection and squeezing are delegated to ``function_metric``.
    """
    loss_fn = F.mse_loss
    return function_metric(
        predictions,
        *targets,
        data_index=data_index,
        feature_axis=feature_axis,
        function=loss_fn,
    )

45 

46 

def l1(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """
    Evaluate ``F.l1_loss`` (default arguments) on the pair at ``data_index``.

    Selection and squeezing are delegated to ``function_metric``.
    """
    loss_fn = F.l1_loss
    return function_metric(
        predictions,
        *targets,
        data_index=data_index,
        feature_axis=feature_axis,
        function=loss_fn,
    )

49 

50 

def smooth_l1(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """
    Evaluate ``F.smooth_l1_loss`` (default arguments) on the pair at ``data_index``.

    Selection and squeezing are delegated to ``function_metric``.
    """
    loss_fn = F.smooth_l1_loss
    return function_metric(
        predictions,
        *targets,
        data_index=data_index,
        feature_axis=feature_axis,
        function=loss_fn,
    )

53 

54 

def calc_dice_score(predictions, target, smooth: float = 1.) -> torch.Tensor:
    """
    Smoothed Dice score between two tensors, computed over all of their elements.

    dice = (2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth)

    Args:
        predictions: Predicted mask or probabilities. Flattened before use, so it must have
            the same number of elements as ``target``.
        target: Ground-truth mask. Also flattened.
        smooth: Added to numerator and denominator. Avoids 0/0 when both tensors are all zero.

    Returns:
        torch.Tensor: A scalar tensor.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed or permuted tensors, work too.
    predictions = predictions.reshape(-1)
    target = target.reshape(-1)
    intersection = (predictions * target).sum()

    return (2. * intersection + smooth) / (predictions.sum() + target.sum() + smooth)

63 

64 

def calc_generalized_dice_score(predictions, target, power: float = 2.0, smooth: float = 1.0, feature_axis: int = -1) -> torch.Tensor:
    """
    A generalized Dice score for multi-class segmentation.

    For each class ``i``, with ``T_i = (target == i)`` and ``P_i`` the class-``i`` slice of
    ``predictions`` along ``feature_axis``:

        w_i = 1 / (sum(T_i) ** power + smooth)
        score = 2 * sum_i w_i * sum(P_i * T_i) / sum_i w_i * (sum(P_i) + sum(T_i))

    If power=0.0, this is equivalent to normal Dice score (i.e. volume "implicit" weighting)
    If power=1.0, this is equivalent to 'equal' weighting.
    If power=2.0, this is equivalent to 'inverse volume' weighting.

    Args:
        predictions: Per-class scores, with the class axis at ``feature_axis``.
        target: Integer class labels, shaped like ``predictions`` with ``feature_axis`` removed.
        power: Exponent applied to each class's target volume when computing its weight.
        smooth: Added to each weight's denominator, keeping weights finite for absent classes.
        feature_axis: The class axis of ``predictions``.

    Returns:
        torch.Tensor: A scalar tensor.

    See:
    - https://www.sciencedirect.com/science/article/pii/S2590005619300049#bib73
    - https://arxiv.org/pdf/1707.03237.pdf
    - https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=1717643
    """
    n_classes = predictions.shape[feature_axis]
    numerator = 0.0
    denominator = 0.0
    for class_index in range(n_classes):
        class_target = target == class_index
        class_target_sum = class_target.sum()
        weight = 1 / (class_target_sum**power + smooth)

        # select() drops the class axis explicitly. It replaces indexing with a *list* of
        # slices, which relied on a legacy sequence-as-tuple heuristic.
        class_predictions = predictions.select(feature_axis, class_index)

        numerator += weight * (class_predictions * class_target).sum()
        denominator += weight * (class_predictions.sum() + class_target_sum)

    return 2. * numerator / denominator

94 

95 

def calc_iou(predictions, target, smooth: float = 1.) -> torch.Tensor:
    """
    Smoothed intersection-over-union between two tensors, computed over all their elements.

    iou = (sum(P * T) + smooth) / (sum(P) + sum(T) - sum(P * T) + smooth)

    Args:
        predictions: Predicted mask or probabilities. Flattened before use, so it must have
            the same number of elements as ``target``.
        target: Ground-truth mask. Also flattened.
        smooth: Added to numerator and denominator. Avoids 0/0 when both tensors are all zero.

    Returns:
        torch.Tensor: A scalar tensor.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed or permuted tensors, work too.
    predictions = predictions.reshape(-1)
    target = target.reshape(-1)
    intersection = (predictions * target).sum()
    union = predictions.sum() + target.sum() - intersection

    return (intersection + smooth) / (union + smooth)

103 

104 

def binary_dice(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """
    Dice score between thresholded predictions and the target at ``data_index``.

    Predictions ``>= 0.0`` count as positive, so inputs are presumably logits (TODO confirm
    against callers). The boolean mask is passed through ``squeeze_prediction`` and then
    scored with ``calc_dice_score`` using its default smoothing.

    Args:
        predictions: A single prediction tensor or a tuple/list of them.
        *targets: One target per prediction.
        data_index: Which prediction/target pair to score. Required.
        feature_axis: Forwarded to the selection and squeeze helpers.

    Returns:
        torch.Tensor: A scalar tensor.
    """
    my_predictions, my_targets = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    my_predictions = my_predictions >= 0.0
    my_predictions = squeeze_prediction(my_predictions, my_targets, feature_axis)

    return calc_dice_score(my_predictions, my_targets)

111 

112 

def binary_iou(predictions, *targets, data_index=None, feature_axis=-1) -> torch.Tensor:
    """
    Intersection-over-union between thresholded predictions and the target at ``data_index``.

    Predictions ``>= 0.0`` count as positive, so inputs are presumably logits (TODO confirm
    against callers). The boolean mask is passed through ``squeeze_prediction`` and then
    scored with ``calc_iou`` using its default smoothing.

    Args:
        predictions: A single prediction tensor or a tuple/list of them.
        *targets: One target per prediction.
        data_index: Which prediction/target pair to score. Required.
        feature_axis: Forwarded to the selection and squeeze helpers.

    Returns:
        torch.Tensor: A scalar tensor.
    """
    my_predictions, my_targets = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    my_predictions = my_predictions >= 0.0
    my_predictions = squeeze_prediction(my_predictions, my_targets, feature_axis)

    return calc_iou(my_predictions, my_targets)

119 

120 

def generalized_dice(predictions, *targets, data_index=None, feature_axis=-1, perform_softmax: bool = True, power: float = 2.0, smooth: float = 1.0) -> torch.Tensor:
    """
    Calculate the generalized Dice score for a single data index. Used for multi-class segmentation.

    See:
    - https://www.sciencedirect.com/science/article/pii/S2590005619300049#bib73
    - https://arxiv.org/pdf/1707.03237.pdf
    - https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=1717643

    Args:
        predictions: Logits or probabilities; a single tensor or a tuple/list of tensors
            selected by ``data_index``. If they are already probabilities, set
            ``perform_softmax`` to False.
        *targets: Integer class-label targets, one per entry of ``predictions``.
        data_index (int): Index of the prediction/target pair to score. Required.
        feature_axis (int, optional): The class axis of the predictions. Defaults to -1.
        perform_softmax (bool, optional): Whether to normalize the predictions with a softmax
            over ``feature_axis``. Defaults to True.
        power (float, optional): Exponent for the class weights. Defaults to 2.0 (i.e. inverse
            volume weighting).
            If power=0.0, this is equivalent to normal Dice score (i.e. volume "implicit" weighting)
            If power=1.0, this is equivalent to 'equal' weighting.
            If power=2.0, this is equivalent to 'inverse volume' weighting.
        smooth (float, optional): Added to each class weight's denominator. Defaults to 1.0,
            the value previously hard-coded via ``calc_generalized_dice_score``'s default.

    Returns:
        torch.Tensor: The generalized Dice score for the given data index.
    """
    my_predictions, my_targets = get_predictions_target_for_index(
        predictions, *targets, data_index=data_index, feature_axis=feature_axis
    )
    if perform_softmax:
        my_predictions = my_predictions.softmax(dim=feature_axis)

    return calc_generalized_dice_score(
        my_predictions,
        my_targets,
        feature_axis=feature_axis,
        power=power,
        smooth=smooth,
    )

155 

156