# -*- coding: utf-8 -*-
from typing import List

import torch
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import plotly.graph_objects as go

from .data import PolyData
from .util import permute_feature_axis

class ContinuousEmbedding(nn.Module):
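    """Embeds a scalar continuous feature by scaling a learned direction vector.

    Each scalar input ``x`` is mapped to ``bias + x * weight``, giving an
    ``embedding_size``-dimensional vector per input element.
    """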

    def __init__(
        self,
        embedding_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.embedding_size = embedding_size

        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = Parameter(torch.empty((embedding_size,), **factory_kwargs), requires_grad=True)
        if bias:
            self.bias = Parameter(torch.empty((embedding_size,), **factory_kwargs), requires_grad=True)
        else:
            # Keep a fixed zero bias so forward() works uniformly when bias is disabled.
            self.bias = Parameter(torch.zeros((embedding_size,), **factory_kwargs), requires_grad=False)

        self.reset_parameters()

    def forward(self, input):
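        """Maps ``input`` of any shape to an embedding of shape ``input.shape + (embedding_size,)``."""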

        x = input.flatten().unsqueeze(1)                     # (n, 1)
        embedded = self.bias + x * self.weight.unsqueeze(0)  # (n, embedding_size)
        embedded = embedded.reshape(input.shape + (-1,))

        return embedded

    def reset_parameters(self) -> None:
        torch.nn.init.normal_(self.weight)
        torch.nn.init.constant_(self.bias, 0.0)


class OrdinalEmbedding(ContinuousEmbedding):
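    """Embeds ordered categories at learned, monotonically increasing positions.

    A softmax over ``distance_scores`` followed by a cumulative sum yields
    strictly increasing category positions in [0, 1] (category 0 is pinned at
    0), so the embedding respects the category order while the gaps between
    neighbouring categories are learned.
    """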

    def __init__(
        self,
        category_count: int,
        embedding_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(
            embedding_size,
            bias=bias,
            device=device,
            dtype=dtype,
            **kwargs,
        )
        factory_kwargs = {'device': device, 'dtype': dtype}
        # One learnable score per gap between consecutive categories.
        self.distance_scores = Parameter(torch.ones((category_count - 1,), **factory_kwargs), requires_grad=True)

    def forward(self, x):
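        """Maps integer category codes ``x`` to embeddings of shape ``x.shape + (embedding_size,)``."""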

        distances = torch.cumsum(F.softmax(self.distance_scores, dim=0), dim=0)

        # prepend zero so that category 0 sits at position 0
        distances = torch.cat([torch.zeros((1,), device=distances.device, dtype=distances.dtype), distances])
        distance = torch.gather(distances, 0, x.flatten())
        embedded = self.bias + distance.unsqueeze(1) * self.weight.unsqueeze(0)
        embedded = embedded.reshape(x.shape + (-1,))

        return embedded


class PolyEmbedding(nn.Module):
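    """Sums per-feature embeddings produced by one embedding module per input type.

    Each element of ``input_types`` supplies its own embedding module via its
    ``embedding_module`` factory; the resulting embeddings are added together,
    and the feature axis of the sum is moved to ``feature_axis``.
    """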

    def __init__(
        self,
        input_types: List[PolyData],
        embedding_size: int,
        feature_axis: int = -1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_types = input_types
        self.embedding_size = embedding_size
        # One embedding module per input type, built by the type itself.
        self.embedding_modules = nn.ModuleList([
            input_type.embedding_module(embedding_size) for input_type in input_types
        ])
        self.feature_axis = feature_axis

    def forward(self, *inputs):
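        """Embeds each input with its matching module and sums the results."""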

        shape = inputs[0].shape + (self.embedding_size,)
        embedded = torch.zeros(shape, device=inputs[0].device)

        # Accumulate the embedding of each input type into a shared tensor.
        for input_tensor, module in zip(inputs, self.embedding_modules):
            embedded += module(input_tensor)

        return permute_feature_axis(embedded, old_axis=-1, new_axis=self.feature_axis)

    def plot(self, **kwargs) -> go.Figure:
        # Deferred import, presumably to avoid a circular import with .plots.
        from .plots import plot_embedding
        return plot_embedding(self, **kwargs)
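
# Minimal usage sketch (illustrative shapes; ``PolyEmbedding`` additionally
# requires ``PolyData`` input types that implement ``embedding_module``):
#
#     embedding = ContinuousEmbedding(embedding_size=8)
#     values = torch.rand(32, 4)              # (batch, features)
#     out = embedding(values)                 # shape (32, 4, 8)
#
#     ordinal = OrdinalEmbedding(category_count=5, embedding_size=8)
#     codes = torch.randint(0, 5, (32,))      # integer category codes
#     out = ordinal(codes)                    # shape (32, 8)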