Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2import typing as typing___ 

3 

4from pathlib import Path 

5from .embedding import PolyEmbedding 

6import plotly .graph_objects as go 

7from sklearn .decomposition import PCA 

8from itertools import cycle 

9 

10import torch 

11 

12from .data import CategoricalData ,OrdinalData 

13 

14 

def format_fig(fig) -> go.Figure:
    """
    Apply the house style to a plotly figure in place.

    Sets a fixed 1000x800 canvas with a white plot background, black 18pt
    "Linux Libertine Display O" text, light-grey grid lines, and a black
    mirrored frame with outside ticks and black zero lines on both axes.

    Args:
        fig: The plotly figure to restyle (it is mutated).

    Returns:
        go.Figure: The same figure object, for chaining.
    """
    text_font = {
        "family": "Linux Libertine Display O",
        "size": 18,
        "color": "black",
    }
    fig.update_layout(
        width=1000,
        height=800,
        plot_bgcolor="white",
        title_font_color="black",
        font=text_font,
    )

    # Black box around the plotting area plus visible zero lines; identical
    # settings are used for the x and y axes.
    frame_style = {
        "showline": True,
        "linewidth": 1,
        "linecolor": 'black',
        "mirror": True,
        "ticks": 'outside',
        "zeroline": True,
        "zerolinewidth": 1,
        "zerolinecolor": 'black',
    }
    axis_updaters = (fig.update_xaxes, fig.update_yaxes)

    grid_rgb = "#dddddd"
    for update_axes in axis_updaters:
        update_axes(gridcolor=grid_rgb)
    for update_axes in axis_updaters:
        update_axes(**frame_style)

    return fig

36 

37 

def plot_embedding(embedding: PolyEmbedding, n_components: int = 2, show: bool = False, output_path: typing___.Union[typing___.Union[str, Path], None] = None) -> go.Figure:
    """
    Plots the embedding in 2D or 3D.

    The weight of every module in ``embedding.embedding_modules`` is stacked
    row-wise (1-D weights become a single row) and projected onto its first
    ``n_components`` principal components with scikit-learn's ``PCA``. Each
    projected row is drawn as its own trace, so every label gets a separate
    legend entry.

    Labels and colours: for categorical inputs that are not ordinal, one label
    and one colour are produced per category (``labels`` / ``colors`` from the
    input type if set, otherwise ``<name>_<i>`` and the default palette). Every
    other input produces a single label (its ``name``) and a single colour.

    Args:
        embedding (PolyEmbedding): The embedding to plot.
        n_components (int, optional): The number of principal components to plot. Can be 2 or 3. Defaults to 2.
        show (bool, optional): Whether to show the plot. Defaults to False.
        output_path (str|Path|None, optional): The path to save the plot to.
            Can be in HTML, PNG, JPEG, SVG or PDF. A ``.html``/``.htm`` suffix
            (any letter case) is written with ``write_html``; any other suffix
            goes to ``write_image``. Missing parent directories are created.
            Defaults to None.

    Returns:
        go.Figure: The plotly figure

    Raises:
        ValueError: If ``n_components`` is not 2 or 3.
    """
    if n_components not in [2, 3]:
        raise ValueError(f"n_components must be 2 or 3, not {n_components}")

    # get embedding weights
    weights = []
    labels = []
    colors = []

    # The default colors are same as px.colors.qualitative.Plotly
    # I'm not using that directly because that requires plotly express
    # which requires pandas to be installed
    cmap = cycle(['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A', '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52'])

    for input_type, module in zip(embedding.input_types, embedding.embedding_modules):
        weight = module.weight
        if len(weight.shape) == 1:
            # A 1-D weight is treated as a single row so it can be concatenated.
            weight = weight.unsqueeze(0)

        weights.append(weight)

        if isinstance(input_type, CategoricalData) and not isinstance(input_type, OrdinalData):
            my_labels = (
                input_type.labels if input_type.labels is not None
                else [f"{input_type.name}_{i}" for i in range(input_type.category_count)]
            )
            labels.extend(my_labels)

            my_colors = (
                input_type.colors if input_type.colors is not None
                else [next(cmap) for _ in range(input_type.category_count)]
            )
            colors.extend(my_colors)
        else:
            labels.append(input_type.name)
            colors.append(getattr(input_type, "color", '') or next(cmap))

    # Detach from autograd and move to host memory before handing the matrix to
    # scikit-learn: it converts input via numpy, which cannot read tensors that
    # live on a GPU.
    weights = torch.cat(weights, dim=0).detach().cpu().numpy()

    # Perform a principal component analysis
    pca = PCA(n_components=n_components)
    weights_reduced = pca.fit_transform(weights)

    # plot
    fig = go.Figure()

    # This is done as a loop so that the legend has all the different categorical labels separate
    # This will be a large list in the legend potentially
    # This could be an option in the future
    # NOTE(review): zip() stops at the shortest sequence, so if a module's row
    # count differs from the labels/colours generated for it, later points are
    # silently mislabelled or dropped — confirm rows always match categories.
    for vector, label, color in zip(weights_reduced, labels, colors):
        if n_components == 2:
            fig.add_trace(go.Scatter(
                x=[vector[0]],
                y=[vector[1]],
                mode='markers',
                name=label,
                marker_color=color,
            ))
        else:
            fig.add_trace(go.Scatter3d(
                x=[vector[0]],
                y=[vector[1]],
                z=[vector[2]],
                mode='markers',
                name=label,
                marker_color=color,
            ))

    # Axis titles are figure-level settings: apply them once, not per point.
    if n_components == 2:
        fig.update_xaxes(title_text="Component 1")
        fig.update_yaxes(title_text="Component 2")
    else:
        fig.update_layout(scene=dict(
            xaxis_title='Component 1',
            yaxis_title='Component 2',
            zaxis_title='Component 3',
        ))

    fig.update_layout(showlegend=True)
    format_fig(fig)

    # output image as file
    if output_path:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        if output_path.suffix.lower() in ('.html', '.htm'):
            fig.write_html(str(output_path))
        else:
            fig.write_image(str(output_path))

    if show:
        fig.show()

    return fig

66 weight =module .weight 

67 if len (weight .shape )==1 : 

68 weight =weight .unsqueeze (0 ) 

69 

70 weights .append (weight ) 

71 

72 if isinstance (input_type ,CategoricalData )and not isinstance (input_type ,OrdinalData ): 

73 my_labels =( 

74 input_type .labels if input_type .labels is not None 

75 else [f"{input_type.name}_{i}"for i in range (input_type .category_count )] 

76 ) 

77 labels .extend (my_labels ) 

78 

79 my_colors =( 

80 input_type .colors if input_type .colors is not None 

81 else [next (cmap )for _ in range (input_type .category_count )] 

82 ) 

83 colors .extend (my_colors ) 

84 else : 

85 labels .append (input_type .name ) 

86 colors .append (getattr (input_type ,"color",'')or next (cmap )) 

87 

88 weights =torch .cat (weights ,dim =0 ).detach () 

89 

90 # Perform a principal component analysis 

91 pca =PCA (n_components =n_components ) 

92 weights_reduced =pca .fit_transform (weights ) 

93 

94 # plot 

95 fig =go .Figure () 

96 

97 # This is done as a loop so that the legend has all the different categorical labels separate 

98 # This will be a large list in the legend potentially 

99 # This could be an option in the future 

100 for vector ,label ,color in zip (weights_reduced ,labels ,colors ): 

101 if n_components ==2 : 

102 fig .add_trace (go .Scatter ( 

103 x =[vector [0 ]], 

104 y =[vector [1 ]], 

105 mode ='markers', 

106 name =label , 

107 marker_color =color , 

108 )) 

109 fig .update_xaxes (title_text ="Component 1") 

110 fig .update_yaxes (title_text ="Component 2") 

111 elif n_components ==3 : 

112 fig .add_trace (go .Scatter3d ( 

113 x =[vector [0 ]], 

114 y =[vector [1 ]], 

115 z =[vector [2 ]], 

116 mode ='markers', 

117 name =label , 

118 marker_color =color , 

119 )) 

120 fig .update_layout (scene =dict ( 

121 xaxis_title ='Component 1', 

122 yaxis_title ='Component 2', 

123 zaxis_title ='Component 3', 

124 )) 

125 

126 fig .update_layout (showlegend =True ) 

127 format_fig (fig ) 

128 

129 # output image as file 

130 if output_path : 

131 output_path =Path (output_path ) 

132 output_path .parent .mkdir (parents =True ,exist_ok =True ) 

133 if output_path .suffix =='.html': 

134 fig .write_html (str (output_path )) 

135 else : 

136 fig .write_image (str (output_path )) 

137 

138 if show : 

139 fig .show () 

140 

141 return fig