Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from torch import Tensor, Size 

2from functools import cached_property 

3import torch.nn.functional as F 

4from torch.nn.parameter import Parameter 

5 

6 

class LazyLinearTensor(Tensor):
    """
    A tensor that delays a final linear projection until the value is needed.

    Designed to be used with HierarchicalSoftmaxLazyLinear layers: it stores the
    input activations together with the layer's ``weight`` and ``bias``, and only
    computes ``F.linear(input, weight, bias)`` when the result is actually
    consumed (arithmetic, full indexing, etc.). Partial indexing stays lazy and
    full indexing materializes only the selected output entries.
    """

    @staticmethod
    def __new__(cls, x, weight: Parameter, bias: Parameter, *args, **kwargs):
        # weight/bias are accepted here only so they can travel through the
        # constructor call; the tensor storage itself is created from `x` alone.
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x: Tensor, weight: Parameter, bias: Parameter, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input = x        # activations feeding the linear layer
        self.weight = weight  # presumably (out_features, in_features) — standard nn.Linear layout
        self.bias = bias      # presumably (out_features,)

    @cached_property
    def result(self) -> Tensor:
        """The materialized projection; computed once and cached on the instance."""
        return F.linear(self.input, self.weight, self.bias)

    # Arithmetic forces materialization and then defers to plain tensor ops.
    def __add__(self, other):
        return self.result + other

    def __sub__(self, other):
        return self.result - other

    def __mul__(self, other):
        return self.result * other

    def __truediv__(self, other):
        return self.result / other

    def __matmul__(self, other):
        return self.result @ other

    def __radd__(self, other):
        return other + self.result

    def __rsub__(self, other):
        return other - self.result

    def __rmul__(self, other):
        return other * self.result

    def __rtruediv__(self, other):
        return other / self.result

    def __rmatmul__(self, other):
        return other @ self.result

    def __getitem__(self, index):
        """
        Index into the (virtual) projected tensor.

        A partial index (fewer entries than ``self.shape`` has dimensions) slices
        the stored input and stays lazy; a full index materializes only the
        selected rows of ``weight``/``bias`` instead of the whole projection.

        Raises:
            IndexError: if the index has more entries than the tensor has dims.
        """
        # NOTE(review): assert is stripped under `python -O`; kept (rather than
        # raising TypeError) so callers catching AssertionError are unaffected.
        assert isinstance(index, (int, slice, tuple))
        if not isinstance(index, tuple):
            # An int or slice addresses the leading dimension only.
            # (The original also tested `isinstance(index, slice)` after the
            # tuple check — dead code, since a slice is never a tuple.)
            index = (index,)

        my_shape = self.shape
        if len(index) < len(my_shape):
            # Partial index: slice the input, keep the projection lazy.
            return LazyLinearTensor(self.input[index], weight=self.weight, bias=self.bias)
        if len(index) > len(my_shape):
            # Bug fix: the message previously printed len(my_shape) while
            # claiming to print the shape.
            raise IndexError(f"Cannot get index '{index}' for LazyLinearTensor of shape {tuple(my_shape)}")

        # Full index: the last entry selects output features, the rest select
        # into the input; compute only that slice of the projection.
        input = self.input[index[:-1]]
        weight = self.weight[index[-1]]
        bias = self.bias[index[-1]]
        return F.linear(input, weight, bias)

    @property
    def shape(self) -> Size:
        """Shape of the projected result: input batch dims + (out_features,)."""
        return Size(self.input.shape[:-1] + (self.weight.shape[0],))

    def __str__(self) -> str:
        return f"LazyLinearTensor (shape={tuple(self.shape)})"

    def __repr__(self) -> str:
        return str(self)

    def __len__(self) -> int:
        return self.shape[0]

    def __iter__(self):
        # Yields lazy row views (see __getitem__ with a partial index).
        for i in range(len(self)):
            yield self[i]

    def float(self):
        # Convert storage and all deferred operands so `result` stays consistent.
        x = super().float()
        x.input = self.input.float()
        x.weight = self.weight.float()
        x.bias = self.bias.float()
        return x

    def half(self):
        # Convert storage and all deferred operands so `result` stays consistent.
        x = super().half()
        x.input = self.input.half()
        x.weight = self.weight.half()
        x.bias = self.bias.half()
        return x