Coverage for polytorch/loss.py : 100.00%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from typing import List
2from torch import nn
4from .data import PolyData
5from .util import split_tensor
class PolyLoss(nn.Module):
    """Combine the losses of several data types into a single value.

    Each entry of ``data_types`` is paired with one prediction and one target.
    The entry's ``calculate_loss`` method is called on the pair, the result is
    reduced with ``.mean()``, and the reduced values are summed.

    Args:
        data_types: One object per output; each must provide a
            ``calculate_loss(prediction, target, feature_axis=...)`` method.
        feature_axis: Axis forwarded to ``split_tensor`` and (after being
            normalised to a non-negative index) to ``calculate_loss``.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        data_types: List[PolyData],
        feature_axis: int = -1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.data_types = data_types
        self.feature_axis = feature_axis

    def forward(self, predictions, *targets):
        """Return the summed mean loss over all data types.

        Args:
            predictions: Either a tuple/list holding one prediction per data
                type, or a single object that is handed to ``split_tensor``
                to obtain that sequence.
                NOTE(review): presumably a tensor split along
                ``feature_axis`` -- confirm against ``split_tensor``.
            *targets: One target per data type, in the same order as
                ``data_types``.

        Returns:
            The sum of ``calculate_loss(...).mean()`` over every
            (prediction, target, data type) triple; ``0.0`` when there are
            no data types.

        Raises:
            AssertionError: If the numbers of predictions, targets and data
                types differ.
            ValueError: If a data type has no ``calculate_loss`` method.
        """
        if not isinstance(predictions, (tuple, list)):
            predictions = split_tensor(predictions, self.data_types, feature_axis=self.feature_axis)

        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; without it ``zip`` below would silently truncate to
        # the shortest input. AssertionError is kept for backward
        # compatibility with callers that already expect that type.
        if not (len(predictions) == len(targets) == len(self.data_types)):
            raise AssertionError(
                f"Expected the same number of predictions, targets and data types, "
                f"got {len(predictions)}, {len(targets)} and {len(self.data_types)}"
            )

        loss = 0.0
        for prediction, target, data_type in zip(predictions, targets, self.data_types):
            # Normalise a possibly negative axis to a non-negative index
            # for this prediction's rank.
            feature_axis = self.feature_axis % len(prediction.shape)

            if not hasattr(data_type, "calculate_loss"):
                raise ValueError(f"Data type {data_type} does not have a calculate_loss method")

            loss += data_type.calculate_loss(prediction, target, feature_axis=feature_axis).mean()

        return loss