Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from torch import nn 

2from typing import List 

3 

4from .data import PolyData 

5from .util import total_size, split_tensor 

6 

7 

class PolyLayerError(RuntimeError):
    """Raised when a poly layer is constructed with invalid arguments.

    For example, this is raised when ``out_features`` is supplied explicitly,
    because that size is always derived from the layer's output types.
    """

10 

11 

class PolyLayerMixin:
    """Mixin that derives a layer's output width from a list of output types.

    List it *before* a ``torch.nn`` layer class in the bases, as
    ``PolyLinear(PolyLayerMixin, nn.Linear)`` does. This class's ``__init__``
    and ``forward`` then run first and delegate to the layer through
    ``super()``.
    """

    def __init__(self, output_types: List[PolyData], out_features=None, **kwargs):
        """Record ``output_types`` and construct the wrapped layer.

        Args:
            output_types: Output types this layer produces. The wrapped layer
                is built with ``out_features=total_size(output_types)``.
            out_features: Must be left as ``None``. The value is always
                derived from ``output_types``.
            **kwargs: Forwarded unchanged to the next ``__init__`` in the MRO,
                for example ``in_features`` / ``bias`` for ``nn.Linear``.

        Raises:
            PolyLayerError: If ``out_features`` is given. This includes a
                value passed as the second positional argument.
        """
        if out_features is not None:
            # Name the concrete class, since this mixin is not specific to
            # PolyLinear. Also echo the value: a common mistake is passing
            # ``in_features`` positionally, which lands in ``out_features``.
            raise PolyLayerError(
                f"Trying to create a {type(self).__name__} Layer by explicitly setting `out_features` "
                f"(got {out_features!r}). "
                "This value should be determined from the list of output types and not the `out_features` argument. "
                "Pass `in_features` and other layer arguments as keyword arguments."
            )

        # Assigned before super().__init__(), as in the original code.
        # NOTE(review): presumably output_types is a plain list rather than an
        # nn.Module/Parameter, which could not be assigned before
        # Module.__init__. Confirm this against PolyData.
        self.output_types = output_types

        super().__init__(out_features=total_size(self.output_types), **kwargs)

    def forward(self, *inputs):
        """Run the wrapped layer, then split its output by ``output_types``.

        Args:
            *inputs: Passed unchanged to the wrapped layer's ``forward``.

        Returns:
            ``split_tensor(outputs, self.output_types, feature_axis=-1)``,
            i.e. the layer output split along its last axis. This is
            presumably one piece per output type; confirm against
            ``.util.split_tensor``.
        """
        outputs = super().forward(*inputs)
        return split_tensor(outputs, self.output_types, feature_axis=-1)

27 

28 

class PolyLinear(PolyLayerMixin, nn.Linear):
    """Linear layer meant as the final layer of a model whose unnormalized scores are given to PolyLoss.

    The first constructor argument is the list of output types. ``out_features``
    is computed internally as ``total_size(output_types)`` and cannot be given
    as an argument; supplying it raises ``PolyLayerError``. Other keyword
    arguments (e.g. ``in_features``, ``bias``) are forwarded to ``nn.Linear``,
    for example ``PolyLinear(output_types, in_features=128)``.

    ``forward`` applies the linear transform and then splits the result along
    its last axis with ``split_tensor`` according to ``output_types``.
    """