Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from typing import List 

2from torch import nn 

3 

4from .data import PolyData 

5from .util import split_tensor 

6 

class PolyLoss(nn.Module):
    """Combine the losses of several data-type outputs into a single value.

    Each entry of ``data_types`` computes its own loss through a
    ``calculate_loss`` method. This module pairs every prediction with its
    target and data type, reduces each per-type loss with ``.mean()``, and
    sums the results.
    """

    def __init__(
        self,
        data_types: List[PolyData],
        feature_axis: int = -1,
        **kwargs,
    ):
        """
        Args:
            data_types: One entry per model output, in the same order as the
                predictions and targets later passed to ``forward``.
            feature_axis: Axis passed to each ``calculate_loss`` call. It may
                be negative; it is resolved against each prediction's rank
                in ``forward``.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.data_types = data_types
        self.feature_axis = feature_axis

    def forward(self, predictions, *targets):
        """Return the sum of the per-data-type mean losses.

        Args:
            predictions: Either a tuple/list with one prediction per data
                type, or a single tensor. A single tensor is handed to
                ``split_tensor`` together with ``data_types`` and
                ``feature_axis``. It is presumably split into one chunk per
                data type; the count check below enforces that.
            *targets: One target per data type, in ``data_types`` order.

        Returns:
            The sum over data types of
            ``data_type.calculate_loss(prediction, target, feature_axis=...).mean()``.
            This is the float ``0.0`` when there is nothing to sum.

        Raises:
            ValueError: If the numbers of predictions, targets and data types
                differ, or if a data type has no ``calculate_loss`` method.
        """
        if not isinstance(predictions, (tuple, list)):
            predictions = split_tensor(predictions, self.data_types, feature_axis=self.feature_axis)

        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``. ``zip`` below would then silently drop unmatched
        # outputs and under-report the loss.
        if not (len(predictions) == len(targets) == len(self.data_types)):
            raise ValueError(
                f"Expected one prediction and one target per data type "
                f"({len(self.data_types)} data types), got "
                f"{len(predictions)} predictions and {len(targets)} targets"
            )

        # Fail fast, before doing any loss computation.
        for data_type in self.data_types:
            if not hasattr(data_type, "calculate_loss"):
                raise ValueError(f"Data type {data_type} does not have a calculate_loss method")

        loss = 0.0
        for prediction, target, data_type in zip(predictions, targets, self.data_types):
            # Normalise a possibly negative axis against this prediction's
            # rank, so calculate_loss always receives a non-negative index.
            feature_axis = self.feature_axis % len(prediction.shape)
            loss += data_type.calculate_loss(prediction, target, feature_axis=feature_axis).mean()

        return loss