Coverage for polytorch/embedding.py : 100.00%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2import typing as typing___
4from typing import List
6import torch
7from torch import nn
8from torch .nn .parameter import Parameter
9import torch .nn .functional as F
10import plotly .graph_objects as go
12from .data import PolyData
13from .util import permute_feature_axis
class ContinuousEmbedding(nn.Module):
    """Embed a real-valued feature as ``bias + value * weight``.

    Every scalar of the input is mapped to a vector of length ``embedding_size``:
    the learned ``weight`` vector scaled by the value, offset by the ``bias`` vector.

    Args:
        embedding_size: Length of the embedding vector produced for each value.
        bias: If True, ``bias`` is a trainable parameter; otherwise it is a frozen
            (``requires_grad=False``) vector of zeros.
        device: Device on which to allocate the parameters.
        dtype: Dtype of the parameters.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        embedding_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.embedding_size = embedding_size

        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = Parameter(torch.empty((embedding_size,), **factory_kwargs), requires_grad=True)
        if bias:
            self.bias = Parameter(torch.empty((embedding_size,), **factory_kwargs), requires_grad=True)
        else:
            # A frozen all-zero Parameter keeps ``self.bias`` always defined, so
            # forward() needs no special case for the bias-free configuration.
            self.bias = Parameter(torch.zeros((embedding_size,), **factory_kwargs), requires_grad=False)

        self.reset_parameters()

    def forward(self, input):
        """Embed every element of ``input``.

        Args:
            input: Tensor of any shape (including 0-d and empty) holding the values to embed.

        Returns:
            Tensor of shape ``input.shape + (embedding_size,)``.
        """
        # Broadcasting (..., 1) against (embedding_size,) yields (..., embedding_size)
        # directly. The previous flatten + reshape(..., -1) raised on empty inputs,
        # because a -1 dimension is ambiguous when the tensor has zero elements.
        return self.bias + input.unsqueeze(-1) * self.weight

    def reset_parameters(self) -> None:
        """Draw ``weight`` from a standard normal distribution and set ``bias`` to zero."""
        torch.nn.init.normal_(self.weight)
        torch.nn.init.constant_(self.bias, 0.0)
class OrdinalEmbedding(ContinuousEmbedding):
    """Embed ordered categories at learned, increasing positions along a line.

    Each category index is mapped to a scalar position and then embedded like a
    continuous value (``bias + position * weight``). Category 0 is placed at 0.
    The gaps between consecutive categories are the softmax of the learnable
    ``distance_scores``. They are positive and sum to 1, so with two or more
    categories the positions increase from 0 up to 1 for the last category.

    Args:
        category_count: Number of ordered categories.
        embedding_size: Length of the embedding vector produced for each value.
        bias: Forwarded to ``ContinuousEmbedding.__init__``.
        device: Device on which to allocate the parameters.
        dtype: Dtype of the parameters.
        **kwargs: Forwarded to ``ContinuousEmbedding.__init__``.
    """

    def __init__(
        self,
        category_count: int,
        embedding_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(
            embedding_size,
            bias=bias,
            device=device,
            dtype=dtype,
            **kwargs,
        )
        factory_kwargs = {'device': device, 'dtype': dtype}
        # One score per gap between consecutive categories; equal scores start
        # the categories evenly spaced.
        self.distance_scores = Parameter(torch.ones((category_count - 1,), **factory_kwargs), requires_grad=True)

    def forward(self, x):
        """Embed a tensor of category indices.

        Args:
            x: Tensor of category indices in ``[0, category_count)``, of any shape
                including empty (presumably int64, as ``torch.gather`` expects).

        Returns:
            Tensor of shape ``x.shape + (embedding_size,)``.
        """
        distances = torch.cumsum(F.softmax(self.distance_scores, dim=0), dim=0)

        # prepend zero so that category 0 sits at the origin
        distances = torch.cat([torch.zeros((1,), device=distances.device, dtype=distances.dtype), distances])

        # Reshape the looked-up positions back to x's shape before broadcasting.
        # This avoids reshape(..., -1), which raises for empty inputs.
        distance = torch.gather(distances, 0, x.flatten()).reshape(x.shape)
        return self.bias + distance.unsqueeze(-1) * self.weight
class PolyEmbedding(nn.Module):
    """Sum the per-feature embeddings of several inputs into one tensor.

    One embedding module is built per entry of ``input_types`` by calling
    ``embedding_module(embedding_size)`` on it. In ``forward``, the i-th input is
    passed through the i-th module and the results are added together.

    Args:
        input_types: Descriptions of each input feature, in the same order the
            inputs will be passed to ``forward``.
        embedding_size: Embedding length shared by every feature.
        feature_axis: Passed to ``permute_feature_axis`` as ``new_axis`` when
            producing the output.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        input_types: List[PolyData],
        embedding_size: int,
        feature_axis: int = -1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_types = input_types
        self.embedding_size = embedding_size
        self.embedding_modules = nn.ModuleList([
            input_type.embedding_module(embedding_size) for input_type in input_types
        ])
        self.feature_axis = feature_axis

    def forward(self, *inputs):
        """Embed each input with its matching module and sum the results.

        Args:
            *inputs: One tensor per entry of ``input_types``. Inputs and modules are
                paired with ``zip``, so any surplus on either side is ignored.

        Returns:
            The summed embeddings (embedding dimension last), passed through
            ``permute_feature_axis(..., old_axis=-1, new_axis=feature_axis)``.

        Raises:
            IndexError: If called with no inputs.
        """
        embedded = None
        for feature, module in zip(inputs, self.embedding_modules):
            contribution = module(feature)
            # Out-of-place addition keeps the modules' output dtype (e.g. float64
            # or float16). In-place accumulation into a default-dtype zeros tensor
            # silently cast the result to that default dtype.
            embedded = contribution if embedded is None else embedded + contribution

        if embedded is None:
            # No (input, module) pairs: fall back to an all-zero embedding shaped
            # like the first input. inputs[0] raises IndexError for no inputs,
            # as before.
            shape = inputs[0].shape + (self.embedding_size,)
            embedded = torch.zeros(shape, device=inputs[0].device)

        return permute_feature_axis(embedded, old_axis=-1, new_axis=self.feature_axis)

    def plot(self, **kwargs) -> go.Figure:
        """Plot this embedding with ``plot_embedding`` from the sibling ``plots`` module.

        Args:
            **kwargs: Forwarded to ``plot_embedding``.

        Returns:
            The figure returned by ``plot_embedding``.
        """
        # Imported lazily, presumably to avoid a circular import or a
        # plotting-time-only dependency at module load.
        from .plots import plot_embedding
        return plot_embedding(self, **kwargs)