Coverage for polytorch/data.py : 100.00%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from torch import nn
2import abc
3from typing import List, Optional
4from attrs import define, Factory, field
5import torch.nn.functional as F
7from .util import permute_feature_axis, squeeze_prediction
8from .enums import ContinuousLossType, BinaryLossType, CategoricalLossType
@define(kw_only=True)
class PolyData(abc.ABC):
    """Abstract base for one data modality handled by the model.

    Subclasses describe how a feature is embedded, how many output
    features its prediction head needs, and how its loss is computed.
    """

    # Defaults to the concrete subclass name (e.g. "BinaryData").
    name: str = Factory(lambda self: self.__class__.__name__, takes_self=True)

    @abc.abstractmethod
    def embedding_module(self, embedding_size:int) -> nn.Module:
        """Build the module that embeds raw values into `embedding_size` dims."""

    @abc.abstractmethod
    def size(self) -> int:
        """Number of output features required by the prediction head."""

    @abc.abstractmethod
    def calculate_loss(self, prediction, target, feature_axis:int=-1):
        """Compute the loss between `prediction` and `target`."""
def binary_default_factory():
    """Default label pair for a binary feature, in falsy-first order."""
    return [str(flag) for flag in (False, True)]
@define
class BinaryData(PolyData):
    """A boolean feature with a selectable loss function."""

    loss_type:BinaryLossType = BinaryLossType.CROSS_ENTROPY
    labels:List[str] = field(factory=binary_default_factory)
    colors:Optional[List[str]] = None

    def embedding_module(self, embedding_size:int) -> nn.Module:
        # One embedding row per truth value.
        return nn.Embedding(2, embedding_size)

    def size(self) -> int:
        # A single logit suffices for a binary decision.
        return 1

    def calculate_loss(self, prediction, target, feature_axis:int=-1):
        """Loss between a (possibly unsqueezed) logit prediction and target."""
        prediction = squeeze_prediction(prediction, target, feature_axis)
        loss_type = self.loss_type

        if loss_type == BinaryLossType.CROSS_ENTROPY:
            # NOTE(review): this uses the default ('mean') reduction, while the
            # categorical/continuous losses pass reduction="none" — confirm the
            # asymmetry is intended.
            return F.binary_cross_entropy_with_logits(prediction, target.float())

        if loss_type == BinaryLossType.IOU:
            # Imported lazily to avoid a metrics dependency at module load.
            from .metrics import calc_iou
            return 1 - calc_iou(prediction.sigmoid(), target)

        if loss_type == BinaryLossType.DICE:
            from .metrics import calc_dice_score
            return 1 - calc_dice_score(prediction.sigmoid(), target)

        raise NotImplementedError(f"Unknown loss type: {self.loss_type} for {self.__class__.__name__}")
@define
class CategoricalData(PolyData):
    """A feature taking one of `category_count` discrete values."""

    category_count:int
    loss_type:CategoricalLossType = CategoricalLossType.CROSS_ENTROPY
    labels:Optional[List[str]] = None
    colors:Optional[List[str]] = None

    def embedding_module(self, embedding_size:int) -> nn.Module:
        # One embedding row per category.
        return nn.Embedding(self.category_count, embedding_size)

    def size(self) -> int:
        # One logit per category.
        return self.category_count

    def calculate_loss(self, prediction, target, feature_axis:int=-1):
        """Per-element loss between category logits and integer targets."""
        loss_type = self.loss_type

        if loss_type == CategoricalLossType.CROSS_ENTROPY:
            # F.cross_entropy expects the class axis at position 1.
            prediction = permute_feature_axis(prediction, old_axis=feature_axis, new_axis=1)
            return F.cross_entropy(
                prediction,
                target.long(),
                reduction="none",
                # label_smoothing=self.label_smoothing,
            )

        if loss_type == CategoricalLossType.DICE:
            # Imported lazily to avoid a metrics dependency at module load.
            from .metrics import calc_generalized_dice_score
            return 1. - calc_generalized_dice_score(
                prediction.softmax(dim=feature_axis),
                target,
                feature_axis=feature_axis,
            )

        raise NotImplementedError(f"Unknown loss type: {self.loss_type} for {self.__class__.__name__}")
@define
class OrdinalData(CategoricalData):
    """A categorical feature whose categories have a natural ordering."""

    color:str = ""
    # add in option to estimate distances or to set them?

    def embedding_module(self, embedding_size:int) -> nn.Module:
        # Imported lazily to avoid an embedding dependency at module load.
        from .embedding import OrdinalEmbedding
        return OrdinalEmbedding(self.category_count, embedding_size)
@define
class ContinuousData(PolyData):
    """A real-valued feature with a torch.nn.functional regression loss."""

    loss_type:ContinuousLossType = ContinuousLossType.SMOOTH_L1_LOSS
    color:str = ""

    def embedding_module(self, embedding_size:int) -> nn.Module:
        # Imported lazily to avoid an embedding dependency at module load.
        from .embedding import ContinuousEmbedding
        return ContinuousEmbedding(embedding_size)

    def size(self) -> int:
        # A single scalar output per element.
        return 1

    def calculate_loss(self, prediction, target, feature_axis:int=-1):
        """Per-element regression loss between prediction and target."""
        squeezed = squeeze_prediction(prediction, target, feature_axis)
        return self.loss_func(squeezed, target, reduction="none")

    @property
    def loss_func(self):
        # Map the enum name onto the matching functional loss,
        # e.g. SMOOTH_L1_LOSS -> F.smooth_l1_loss.
        return getattr(F, self.loss_type.name.lower())