Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from torch import nn 

2import abc 

3from typing import List, Optional 

4from attrs import define, Factory, field 

5import torch.nn.functional as F 

6 

7from .util import permute_feature_axis, squeeze_prediction 

8from .enums import ContinuousLossType, BinaryLossType, CategoricalLossType 

9 

10 

@define(kw_only=True)
class PolyData(abc.ABC):
    """
    Abstract base for a single feature (input or output) described to polytorch.

    Concrete subclasses decide how values of the feature are embedded, how
    wide a prediction for the feature is, and how a prediction is scored
    against a target.

    Attributes:
        name: Identifier for the feature. Keyword-only. When omitted it
            defaults to the name of the concrete subclass (for example
            ``"BinaryData"``).
    """

    name: str = Factory(lambda instance: type(instance).__name__, takes_self=True)

    @abc.abstractmethod
    def embedding_module(self, embedding_size: int) -> nn.Module:
        """Build a module that maps values of this feature to vectors of length ``embedding_size``."""

    @abc.abstractmethod
    def size(self) -> int:
        """Report how many prediction values this feature needs (e.g. 1 for binary, one per category for categorical)."""

    @abc.abstractmethod
    def calculate_loss(self, prediction, target, feature_axis: int = -1):
        """
        Score ``prediction`` against ``target`` for this feature.

        Args:
            prediction: Model output for this feature.
            target: Ground-truth values for this feature.
            feature_axis: Axis of ``prediction`` that holds this feature's
                values; the last axis by default.
        """

26 

27 

def binary_default_factory():
    """
    Produce the default labels for a binary feature: ``["False", "True"]``.

    Used as an attrs ``factory`` so every instance receives its own new list
    rather than sharing (and possibly mutating) a single default object.
    """
    return list(("False", "True"))

30 

31 

@define
class BinaryData(PolyData):
    """
    A two-valued feature, predicted by a model as a single logit.

    Attributes:
        loss_type: Which loss ``calculate_loss`` computes.
        labels: Display names for the two values, index 0 then index 1.
            Defaults to a fresh ``["False", "True"]`` for every instance.
        colors: Optional list of colour strings (presumably one per label).
    """

    loss_type: BinaryLossType = BinaryLossType.CROSS_ENTROPY
    labels: List[str] = field(factory=binary_default_factory)
    colors: Optional[List[str]] = None

    def embedding_module(self, embedding_size: int) -> nn.Module:
        """Return an embedding table with one row for each of the two values."""
        return nn.Embedding(2, embedding_size)

    def size(self) -> int:
        """A binary feature is predicted with a single value (a logit)."""
        return 1

    def calculate_loss(self, prediction, target, feature_axis: int = -1):
        """
        Score the logit ``prediction`` against the binary ``target``.

        Args:
            prediction: Logits for this feature; first passed through
                ``squeeze_prediction`` (helper not shown here — presumably it
                removes the size-1 feature axis so the shape matches ``target``).
            target: Binary ground truth; cast to float for cross-entropy.
            feature_axis: Axis of ``prediction`` holding this feature.

        Returns:
            For cross-entropy, the unreduced per-element loss. This matches
            ``CategoricalData`` and ``ContinuousData``, so callers can mask or
            weight elements before reducing. For IoU and Dice,
            ``1 - calc_iou(...)`` / ``1 - calc_dice_score(...)`` computed on
            the sigmoid of the logits.

        Raises:
            NotImplementedError: if ``loss_type`` is not one handled here.
        """
        prediction = squeeze_prediction(prediction, target, feature_axis)

        if self.loss_type == BinaryLossType.CROSS_ENTROPY:
            return F.binary_cross_entropy_with_logits(
                prediction,
                target.float(),
                # Keep per-element losses, like the other PolyData types;
                # reduction is left to the caller.
                reduction="none",
            )
        if self.loss_type == BinaryLossType.IOU:
            from .metrics import calc_iou
            return 1 - calc_iou(
                prediction.sigmoid(),
                target,
            )
        if self.loss_type == BinaryLossType.DICE:
            from .metrics import calc_dice_score
            return 1 - calc_dice_score(
                prediction.sigmoid(),
                target,
            )

        raise NotImplementedError(f"Unknown loss type: {self.loss_type} for {self.__class__.__name__}")

66 

67 

@define
class CategoricalData(PolyData):
    """
    A feature taking one of ``category_count`` discrete categories.

    Attributes:
        category_count: Number of categories. This is the embedding table
            size and the number of values per prediction.
        loss_type: Which loss ``calculate_loss`` computes.
        labels: Optional display names for the categories.
        colors: Optional list of colour strings (presumably one per category).
        label_smoothing: Label-smoothing amount forwarded to
            ``F.cross_entropy`` for the cross-entropy loss; the Dice loss
            ignores it. Keyword-only; defaults to ``0.0`` (no smoothing).
    """

    category_count: int
    loss_type: CategoricalLossType = CategoricalLossType.CROSS_ENTROPY
    labels: Optional[List[str]] = None
    colors: Optional[List[str]] = None
    # Keyword-only so adding it does not shift the positional argument order
    # of this class or of subclasses that add their own fields.
    label_smoothing: float = field(default=0.0, kw_only=True)

    def embedding_module(self, embedding_size: int) -> nn.Module:
        """Return an embedding table with one row per category."""
        return nn.Embedding(self.category_count, embedding_size)

    def size(self) -> int:
        """One prediction value (logit) per category."""
        return self.category_count

    def calculate_loss(self, prediction, target, feature_axis: int = -1):
        """
        Score category logits ``prediction`` against ``target``.

        Args:
            prediction: Logits with the categories along ``feature_axis``.
            target: Category indices; cast to ``long`` for cross-entropy.
            feature_axis: Axis of ``prediction`` holding the categories.

        Returns:
            For cross-entropy, the unreduced per-element loss. For Dice,
            ``1 - calc_generalized_dice_score(...)`` computed on the softmax
            over ``feature_axis``.

        Raises:
            NotImplementedError: if ``loss_type`` is not one handled here.
        """
        if self.loss_type == CategoricalLossType.CROSS_ENTROPY:
            # F.cross_entropy expects the class dimension at axis 1.
            prediction = permute_feature_axis(prediction, old_axis=feature_axis, new_axis=1)
            return F.cross_entropy(
                prediction,
                target.long(),
                reduction="none",
                label_smoothing=self.label_smoothing,
            )
        if self.loss_type == CategoricalLossType.DICE:
            from .metrics import calc_generalized_dice_score
            return 1. - calc_generalized_dice_score(
                prediction.softmax(dim=feature_axis),
                target,
                feature_axis=feature_axis,
            )

        raise NotImplementedError(f"Unknown loss type: {self.loss_type} for {self.__class__.__name__}")

99 

100 

@define
class OrdinalData(CategoricalData):
    """
    A categorical feature treated as ordinal.

    Size and loss calculation are inherited unchanged from CategoricalData;
    only the embedding differs (``OrdinalEmbedding`` instead of a plain table).

    Attributes:
        color: A single colour string for the feature; empty by default.
    """

    color: str = ""
    # TODO: consider an option to either estimate the distances between
    # categories or to set them explicitly.

    def embedding_module(self, embedding_size: int) -> nn.Module:
        """Return an ``OrdinalEmbedding`` sized for ``category_count`` categories."""
        from .embedding import OrdinalEmbedding

        n_categories = self.category_count
        return OrdinalEmbedding(n_categories, embedding_size)

109 

110 

@define
class ContinuousData(PolyData):
    """
    A real-valued feature, predicted by a model as a single value.

    Attributes:
        loss_type: Selects the loss function. The lower-cased member name
            must be a function in ``torch.nn.functional``; the default
            ``SMOOTH_L1_LOSS`` maps to ``F.smooth_l1_loss``.
        color: A single colour string for the feature; empty by default.
    """

    loss_type: ContinuousLossType = ContinuousLossType.SMOOTH_L1_LOSS
    color: str = ""

    def embedding_module(self, embedding_size: int) -> nn.Module:
        """Return a ``ContinuousEmbedding`` producing vectors of length ``embedding_size``."""
        from .embedding import ContinuousEmbedding
        return ContinuousEmbedding(embedding_size)

    def size(self) -> int:
        """A continuous feature is predicted with a single value."""
        return 1

    def calculate_loss(self, prediction, target, feature_axis: int = -1):
        """
        Return the unreduced per-element loss of ``prediction`` against ``target``.

        Args:
            prediction: Predictions for this feature; first passed through
                ``squeeze_prediction`` (helper not shown here — presumably it
                removes the size-1 feature axis so the shape matches ``target``).
            target: Ground-truth values. Non-floating (e.g. integer) targets
                are cast to float; floating targets are used as given.
            feature_axis: Axis of ``prediction`` holding this feature.
        """
        prediction = squeeze_prediction(prediction, target, feature_axis)
        # The regression losses expect a floating-point target; an integer
        # target can trigger a dtype-mismatch error (typically in backward).
        # This mirrors BinaryData's ``target.float()`` but leaves float
        # targets untouched.
        if not target.is_floating_point():
            target = target.float()
        return self.loss_func(prediction, target, reduction="none")

    @property
    def loss_func(self):
        """The ``torch.nn.functional`` function named by ``loss_type`` (lower-cased member name)."""
        return getattr(F, self.loss_type.name.lower())

130