Coverage for polytorch/modules.py : 100.00%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from torch import nn
2from typing import List
4from .data import PolyData
5from .util import total_size, split_tensor
class PolyLayerError(RuntimeError):
    """
    Error raised by poly layers when they are constructed with invalid arguments,
    e.g. when `out_features` is passed explicitly to a layer whose output size is
    derived from its output types.
    """
class PolyLayerMixin:
    """
    Mixin that derives a layer's `out_features` from a list of output types and
    splits the layer's output accordingly.

    Combine it with a layer class that accepts an `out_features` keyword argument
    (e.g. `nn.Linear`). List this mixin *before* the layer class in the bases so
    its `__init__` and `forward` run first and delegate to the layer through `super()`.
    """

    def __init__(self, output_types: List[PolyData], out_features=None, **kwargs):
        """
        Args:
            output_types: the output types this layer produces. The layer's
                `out_features` is computed from them with `total_size`.
            out_features: must not be given. It is accepted only so that an explicit
                value is rejected instead of being forwarded to the base layer.
            **kwargs: forwarded to the next class in the MRO (e.g. `in_features`
                and `bias` for `nn.Linear`).

        Raises:
            PolyLayerError: if `out_features` is passed explicitly.
        """
        self.output_types = output_types

        if out_features is not None:
            # Name the concrete class (e.g. PolyLinear) instead of hard-coding it,
            # so the message stays accurate for every layer built on this mixin.
            raise PolyLayerError(
                f"Trying to create a {type(self).__name__} Layer by explicitly setting `out_features`. "
                "This value should be determined from the list of output types and not the `out_features` argument."
            )

        super().__init__(out_features=total_size(self.output_types), **kwargs)

    def forward(self, *inputs):
        """
        Run the base layer's forward pass and split its output by output type.

        Args:
            *inputs: passed unchanged to the base layer's `forward`.

        Returns:
            The result of `split_tensor` applied to the base layer's output along the
            last axis (`feature_axis=-1`) using `self.output_types`. Presumably this is
            one tensor per output type; confirm against `split_tensor`.
        """
        outputs = super().forward(*inputs)
        return split_tensor(outputs, self.output_types, feature_axis=-1)
class PolyLinear(PolyLayerMixin, nn.Linear):
    """
    Creates a linear layer designed to be the final layer in a neural network model that produces unnormalized scores given to PolyLoss.

    The `out_features` value is computed internally from the `output_types` argument
    (via `total_size`) and cannot be given as an argument; passing it raises `PolyLayerError`.
    The forward pass returns the linear output split per output type (via `split_tensor`)
    rather than a single tensor.

    Remaining arguments such as `in_features` and `bias` are passed through to `nn.Linear`.
    Give them as keywords: a second positional argument binds to `out_features` and is rejected.
    """