Coverage for hierarchicalsoftmax/tensors.py : 100.00%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from torch import Tensor, Size
2from functools import cached_property
3import torch.nn.functional as F
4from torch.nn.parameter import Parameter
7class LazyLinearTensor(Tensor):
8 """
9 A tensor that is designed to be used with HierarchicalSoftmaxLazyLinear layers.
10 """
11 @staticmethod
12 def __new__(cls, x, weight:Parameter, bias:Parameter, *args, **kwargs):
13 return super().__new__(cls, x, *args, **kwargs)
15 def __init__(self, x:Tensor, weight:Parameter, bias:Parameter, *args, **kwargs):
16 super().__init__(*args, **kwargs)
17 self.input = x
18 self.weight = weight
19 self.bias = bias
21 @cached_property
22 def result(self):
23 return F.linear(self.input, self.weight, self.bias)
25 def __add__(self, other):
26 return self.result + other
28 def __sub__(self, other):
29 return self.result - other
31 def __mul__(self, other):
32 return self.result * other
34 def __truediv__(self, other):
35 return self.result / other
37 def __matmul__(self, other):
38 return self.result @ other
40 def __radd__(self, other):
41 return other + self.result
43 def __rsub__(self, other):
44 return other - self.result
46 def __rmul__(self, other):
47 return other * self.result
49 def __rtruediv__(self, other):
50 return other / self.result
52 def __rmatmul__(self, other):
53 return other @ self.result
55 def __getitem__(self, index):
56 assert isinstance(index, int) or isinstance(index, slice) or isinstance(index, tuple)
57 if not isinstance(index, tuple) or isinstance(index, slice):
58 index = (index,)
60 my_shape = self.shape
61 if len(index) < len(my_shape):
62 return LazyLinearTensor(self.input[index], weight=self.weight, bias=self.bias)
63 if len(index) > len(my_shape):
64 raise IndexError(f"Cannot get index '{index}' for LazyLinearTensor of shape {len(my_shape)}")
66 input = self.input[index[:-1]]
67 weight = self.weight[index[-1]]
68 bias = self.bias[index[-1]]
69 return F.linear(input, weight, bias)
71 @property
72 def shape(self) -> Size:
73 return Size( self.input.shape[:-1] + (self.weight.shape[0],) )
75 def __str__(self) -> str:
76 return f"LazyLinearTensor (shape={tuple(self.shape)})"
78 def __repr__(self) -> str:
79 return str(self)
81 def __len__(self) -> int:
82 return self.shape[0]
84 def __iter__(self):
85 for i in range(len(self)):
86 yield self[i]
88 def float(self):
89 x = super().float()
90 x.input = self.input.float()
91 x.weight = self.weight.float()
92 x.bias = self.bias.float()
93 return x
95 def half(self):
96 x = super().half()
97 x.input = self.input.half()
98 x.weight = self.weight.half()
99 x.bias = self.bias.half()
100 return x