Coverage for polytorch/plots.py : 100.00%
Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2import typing as typing___
4from pathlib import Path
5from .embedding import PolyEmbedding
6import plotly .graph_objects as go
7from sklearn .decomposition import PCA
8from itertools import cycle
10import torch
12from .data import CategoricalData ,OrdinalData
def format_fig(fig) -> go.Figure:
    """Apply the package's house style to a plotly figure.

    The layout is set to 1000x800 px with a white plot background and black
    text in the "Linux Libertine Display O" font at size 18. Both the x and
    y axes get light grey grid lines, a black mirrored axis line, ticks drawn
    outside the plot area and a black zero line.

    Args:
        fig: The plotly figure to restyle. It is updated through its own
            ``update_*`` methods, i.e. in place.

    Returns:
        go.Figure: The same figure object that was passed in.
    """
    fig.update_layout(
        width=1000,
        height=800,
        plot_bgcolor="white",
        title_font_color="black",
        font=dict(family="Linux Libertine Display O", size=18, color="black"),
    )

    # Both axes share exactly the same styling: a framed box (mirrored lines),
    # outside ticks, a visible zero line and a light grid.
    axis_style = dict(
        gridcolor="#dddddd",
        showline=True,
        linewidth=1,
        linecolor='black',
        mirror=True,
        ticks='outside',
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor='black',
    )
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(**axis_style)

    return fig
# Same hex codes as plotly.express.colors.qualitative.Plotly. They are spelled
# out here because importing plotly.express would require pandas to be installed.
_PLOTLY_QUALITATIVE_COLORS = (
    '#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
    '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52',
)


def _collect_embedding_points(embedding: PolyEmbedding) -> typing___.Tuple[torch.Tensor, list, list]:
    """Gather all embedding vectors of ``embedding`` with a legend label and colour each.

    Args:
        embedding (PolyEmbedding): The embedding whose ``input_types`` and
            ``embedding_modules`` are walked in lockstep.

    Returns:
        tuple: ``(weights, labels, colors)``. ``weights`` is every module's
        ``weight`` concatenated along dim 0, with a 1-D weight treated as a
        single row. Non-ordinal categorical inputs contribute one label and
        one colour per category. Every other input contributes a single label
        (its name) and a single colour.
    """
    # One shared cycle, so automatically assigned colours continue across inputs.
    palette = cycle(_PLOTLY_QUALITATIVE_COLORS)
    rows = []
    labels = []
    colors = []

    for input_type, module in zip(embedding.input_types, embedding.embedding_modules):
        weight = module.weight
        if weight.dim() == 1:
            # A single vector becomes a one-row matrix so it can be concatenated.
            weight = weight.unsqueeze(0)
        rows.append(weight)

        # NOTE(review): this assumes each non-ordinal categorical input has
        # exactly `category_count` weight rows (and labels/colours of that
        # length), and every other input exactly one row. Otherwise later
        # labels drift out of alignment with the plotted points.
        if isinstance(input_type, CategoricalData) and not isinstance(input_type, OrdinalData):
            count = input_type.category_count
            if input_type.labels is not None:
                labels.extend(input_type.labels)
            else:
                labels.extend(f"{input_type.name}_{i}" for i in range(count))
            if input_type.colors is not None:
                colors.extend(input_type.colors)
            else:
                colors.extend(next(palette) for _ in range(count))
        else:
            labels.append(input_type.name)
            # A missing or empty `color` attribute falls back to the palette.
            colors.append(getattr(input_type, "color", '') or next(palette))

    return torch.cat(rows, dim=0), labels, colors


def _write_figure(fig: go.Figure, output_path: Path) -> None:
    """Save ``fig`` to ``output_path``, creating missing parent directories.

    ``.html`` / ``.htm`` suffixes (in any letter case) are written as
    interactive HTML. Any other suffix is passed to ``fig.write_image``, which
    infers the image format (PNG, JPEG, SVG, PDF, ...) from the suffix.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if output_path.suffix.lower() in ('.html', '.htm'):
        fig.write_html(str(output_path))
    else:
        fig.write_image(str(output_path))


def plot_embedding(
    embedding: PolyEmbedding,
    n_components: int = 2,
    show: bool = False,
    output_path: typing___.Union[typing___.Union[str, Path], None] = None,
) -> go.Figure:
    """
    Plots the embedding in 2D or 3D.

    The weight rows of all embedding modules are projected together onto
    their first ``n_components`` principal components. Each row is then drawn
    as its own trace, so every category / input gets a separate legend entry.

    Args:
        embedding (PolyEmbedding): The embedding to plot.
        n_components (int, optional): The number of principal components to plot. Can be 2 or 3. Defaults to 2.
        show (bool, optional): Whether to show the plot. Defaults to False.
        output_path (str|Path|None, optional): The path to save the plot to.
            Can be in HTML, PNG, JPEG, SVG or PDF. Defaults to None.

    Returns:
        go.Figure: The plotly figure

    Raises:
        ValueError: If ``n_components`` is neither 2 nor 3.
    """
    if n_components not in (2, 3):
        raise ValueError(f"n_components must be 2 or 3, not {n_components}")

    weights, labels, colors = _collect_embedding_points(embedding)

    # sklearn converts its input through numpy, which cannot read tensors that
    # live on an accelerator (e.g. CUDA) or that still track gradients. So
    # detach and move to CPU explicitly before the PCA.
    weights_np = weights.detach().cpu().numpy()
    weights_reduced = PCA(n_components=n_components).fit_transform(weights_np)

    fig = go.Figure()

    # One trace per point so that every categorical label gets its own legend
    # entry. The legend can get long for large embeddings; grouping could be
    # made an option in the future.
    for vector, label, color in zip(weights_reduced, labels, colors):
        if n_components == 2:
            trace = go.Scatter(
                x=[vector[0]],
                y=[vector[1]],
                mode='markers',
                name=label,
                marker_color=color,
            )
        else:
            trace = go.Scatter3d(
                x=[vector[0]],
                y=[vector[1]],
                z=[vector[2]],
                mode='markers',
                name=label,
                marker_color=color,
            )
        fig.add_trace(trace)

    # Axis titles do not depend on the individual points: set them once.
    if n_components == 2:
        fig.update_xaxes(title_text="Component 1")
        fig.update_yaxes(title_text="Component 2")
    else:
        fig.update_layout(scene=dict(
            xaxis_title='Component 1',
            yaxis_title='Component 2',
            zaxis_title='Component 3',
        ))

    fig.update_layout(showlegend=True)
    format_fig(fig)

    if output_path:
        _write_figure(fig, Path(output_path))

    if show:
        fig.show()

    return fig