Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from pathlib import Path 

2import numpy as np 

3import plotly.graph_objects as go 

4import PIL 

5from PIL import Image 

6from plotly.subplots import make_subplots 

7import numpy as np 

8 

9from .transforms import read3D 

10 

11 

def format_fig(fig):
    """Apply the shared house style to a plotly figure, in place.

    Sets a white plot background, black title text, and a size-18 black
    "Linux Libertine Display O" font for the figure's text, all through a
    single ``fig.update_layout`` call. Returns ``None``.
    """
    font_style = {"family": "Linux Libertine Display O", "size": 18, "color": "black"}
    layout_style = {
        "plot_bgcolor": "white",
        "title_font_color": "black",
        "font": font_style,
    }
    fig.update_layout(**layout_style)

22 

23 

def render_volume(
    volume,
    width: int = 600,
    height: int = 600,
    title: str = "Volume",
    cmin: float = 0,
    cmax: float = 1.0,
) -> go.Figure:
    """
    Renders a volume with a number of cross-sections that can be animated.

    Slice ``volume[k]`` is drawn as a flat surface at height ``z = k`` and
    becomes animation frame ``k``. A play/pause button and a slider step
    through the frames. Each slice is flipped vertically (``np.flipud``)
    before it is drawn.

    Based on https://plotly.com/python/visualizing-mri-volume-slices/ by Emilia Petrisor

    Args:
        volume (ndarray): The volume to be rendered, indexed slice-first
            (``volume[k]`` is a 2D slice). Every slice is assumed to have the
            shape of ``volume[0]``.
        width (int, optional): The width of the figure. Defaults to 600.
        height (int, optional): The height of the figure. Defaults to 600.
        title (str, optional): The title of the figure. Defaults to 'Volume'.
        cmin (float, optional): Slice value mapped to the low end of the
            colour scale. Defaults to 0.
        cmax (float, optional): Slice value mapped to the high end of the
            colour scale. Defaults to 1.0.

    Returns:
        plotly.graph_objects.Figure: A plotly figure representing the volume.
    """
    rows, cols = volume[0].shape
    nb_frames = len(volume)

    def slice_surface(k, **extra):
        """Flat surface at height ``k`` coloured by (vertically flipped) slice ``k``."""
        return go.Surface(
            z=k * np.ones((rows, cols)),
            surfacecolor=np.flipud(volume[k]),
            cmin=cmin,
            cmax=cmax,
            **extra,
        )

    fig = go.Figure(
        frames=[
            # You need to name the frame for the animation to behave properly.
            go.Frame(data=slice_surface(k), name=str(k))
            for k in range(nb_frames)
        ]
    )

    # Data displayed before the animation starts: slice 0 with a grey colour
    # scale and a colour bar.
    fig.add_trace(
        slice_surface(0, colorscale="Gray", colorbar=dict(thickness=20, ticklen=4))
    )

    def frame_args(duration):
        return {
            "frame": {"duration": duration},
            "mode": "immediate",
            "fromcurrent": True,
            "transition": {"duration": duration, "easing": "linear"},
        }

    sliders = [
        {
            "pad": {"b": 10, "t": 60},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": [
                {
                    "args": [[f.name], frame_args(0)],
                    "label": str(k),
                    "method": "animate",
                }
                for k, f in enumerate(fig.frames)
            ],
        }
    ]

    # Layout
    fig.update_layout(
        title=title,
        width=width,
        height=height,
        scene=dict(
            # Slices sit at z = 0 .. nb_frames - 1.
            zaxis=dict(range=[0, nb_frames - 1], autorange=False),
            aspectratio=dict(x=1, y=1, z=1),
        ),
        updatemenus=[
            {
                "buttons": [
                    {
                        "args": [None, frame_args(50)],
                        "label": "▶",  # play symbol
                        "method": "animate",
                    },
                    {
                        "args": [[None], frame_args(0)],
                        "label": "◼",  # pause symbol
                        "method": "animate",
                    },
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 70},
                "type": "buttons",
                "x": 0.1,
                "y": 0,
            }
        ],
        sliders=sliders,
    )

    return fig

121 

122 

123 # upscaled_images = app( 

124 # items=downscaled_images,  

125 # item_dir=[],  

126 # pretrained=str(pretrained), 

127 # width=500,  

128 # height=500, 

129 # return_images=True, 

130 # ) 

131 

def comparison_plot(originals, downscaled_images, upscaled_images, titles, crops):
    """
    Build a grid of plots comparing original, downscaled and upscaled images.

    Row ``i`` corresponds to ``originals[i]`` and has five columns:

    1. the original image with the crop region outlined in red;
    2. the original image, zoomed to the crop region;
    3. the downscaled image, resized back to the original size with
       nearest-neighbour resampling, zoomed to the crop region;
    4. the upscaled image, zoomed to the crop region;
    5. a heatmap of ``(upscaled - original) / 255`` on a single channel,
       zoomed to the crop region.

    Args:
        originals: Original images, anything accepted by ``PIL.Image.open``.
        downscaled_images: Downscaled images, anything accepted by
            ``PIL.Image.open``.
        upscaled_images: Upscaled images, either as a path (``str`` or
            ``Path``) or as an already-opened PIL image. They must have the
            same pixel size as the originals, otherwise the difference array
            cannot be computed.
        titles: One label per row, shown as the y-axis title of column 1.
        crops: One crop box per row. ``crop[0:2]`` is used as the x range and
            ``(crop[3], crop[2])`` as the y range (reversed, presumably so the
            images keep their top-down orientation — TODO confirm).

    Returns:
        plotly.graph_objects.Figure: The comparison figure.
    """
    n_cols = 5
    fig = make_subplots(
        rows=len(originals),
        cols=n_cols,
        subplot_titles=(
            "Original",
            "Cropped",
            "Downscaled",
            "Upscaled",
            "Difference",
        ),
        vertical_spacing=0.02,
        horizontal_spacing=0.02,
    )

    rows = zip(originals, downscaled_images, upscaled_images, titles, crops)
    for row, (original, downscaled, upscaled, title, crop) in enumerate(rows):
        original_im = Image.open(original)
        original_rgb = np.asarray(original_im.convert("RGB"))
        # Nearest-neighbour resampling keeps each low-resolution pixel visible
        # as a block at the original size.
        downscaled_im = Image.open(downscaled).resize(
            original_im.size, resample=PIL.Image.Resampling.NEAREST
        )

        if isinstance(upscaled, (Path, str)):
            upscaled = Image.open(upscaled)

        # The difference is computed on a single channel: the original's first
        # channel against the upscaled image's values. A multi-channel upscaled
        # image (e.g. RGB) is reduced to its first channel as well, so both
        # arrays are 2D and subtract element-wise.
        upscaled_values = np.asarray(upscaled)
        if upscaled_values.ndim == 3:
            upscaled_values = upscaled_values[:, :, 0]
        difference = upscaled_values.astype(int) - original_rgb[:, :, 0].astype(int)

        subplot_row = row + 1
        fig.add_trace(go.Image(z=original_rgb), row=subplot_row, col=1)
        fig.add_trace(go.Image(z=original_rgb), row=subplot_row, col=2)
        fig.add_trace(go.Image(z=np.asarray(downscaled_im.convert("RGB"))), row=subplot_row, col=3)
        fig.add_trace(go.Image(z=np.asarray(upscaled.convert("RGB")).astype(int)), row=subplot_row, col=4)
        fig.add_trace(
            go.Heatmap(z=difference.astype(float) / 255, coloraxis="coloraxis"),
            row=subplot_row,
            col=5,
        )

        crop_x = (crop[0], crop[1])
        crop_y = (crop[3], crop[2])

        # make_subplots numbers axes row-major, n_cols per row: this row's
        # first subplot uses axis number row * n_cols + 1. Column 1 gets the
        # row title; columns 2..n_cols are zoomed to the crop.
        first_axis = row * n_cols + 1
        update_dict = {f"yaxis{first_axis}_title": title}
        for axis_number in range(first_axis + 1, first_axis + n_cols):
            update_dict[f"xaxis{axis_number}_range"] = crop_x
            update_dict[f"yaxis{axis_number}_range"] = crop_y
        fig.update_layout(**update_dict)

        # Outline the crop region on the uncropped original.
        fig.add_shape(
            type="rect",
            x0=crop_x[0], y0=crop_y[0], x1=crop_x[1], y1=crop_y[1],
            line=dict(color="Red"),
            row=subplot_row,
            col=1,
        )

    fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.update_layout(
        height=150 + 200 * len(originals),
        width=1200,
    )
    format_fig(fig)

    fig.update_layout(coloraxis=dict(colorscale='Rainbow'), showlegend=False)
    fig.update_annotations(font_size=24)

    return fig

197 

198 

def _volume_faces(volume):
    """Yield ``(name, x, y, z, surfacecolor)`` for the six outer faces of a 3D array.

    For every face, ``z`` and ``surfacecolor`` have shape ``(len(y), len(x))``,
    the layout ``go.Surface`` expects (rows follow ``y``, columns follow ``x``),
    so non-cubic volumes are handled correctly. Faces are yielded in the order
    left, right, back, front, bottom, top.
    """
    nx, ny, nz = volume.shape
    xs = np.arange(nx)
    ys = np.arange(ny)
    zs = np.arange(nz)

    # LEFT (x = 0) and RIGHT (x = max): the face spans (y, z); x is constant.
    for name, i in (("left", 0), ("right", nx - 1)):
        x = np.full(nz, i, dtype=int)
        z = np.tile(zs, (ny, 1))  # (ny, nz): z varies along columns
        yield name, x, ys, z, volume[i, :, :]  # (ny, nz)

    # BACK (y = 0) and FRONT (y = max): the face spans (x, z); y is constant.
    for name, j in (("back", 0), ("front", ny - 1)):
        y = np.full(nz, j, dtype=int)
        z = np.tile(zs[:, None], (1, nx))  # (nz, nx): z varies along rows
        yield name, xs, y, z, volume[:, j, :].T  # (nz, nx)

    # BOTTOM (z = 0) and TOP (z = max): the face spans (x, y); z is constant.
    for name, k in (("bottom", 0), ("top", nz - 1)):
        z = np.full((ny, nx), k, dtype=int)
        yield name, xs, ys, z, volume[:, :, k].T  # (ny, nx)


def add_volume_face_traces(fig, volume, coloraxis="coloraxis", **kwargs):
    """Add the six outer faces of a 3D volume to a plotly figure as surfaces.

    Each face is a ``go.Surface`` coloured (and hover-labelled via ``text``)
    by the volume's values on that face. The traces are named "left",
    "right", "back", "front", "bottom" and "top".

    Args:
        fig: The plotly figure to add the traces to.
        volume: A 3D array indexed ``volume[x, y, z]``; the sizes of the three
            axes may differ.
        coloraxis: Name of the shared colour axis the surfaces use.
        **kwargs: Forwarded to ``fig.add_trace`` for every face (e.g. ``row``
            and ``col`` for subplots).

    Returns:
        The same ``fig``, for chaining.
    """
    for name, x, y, z, surfacecolor in _volume_faces(volume):
        fig.add_trace(
            go.Surface(
                x=x,
                y=y,
                z=z,
                surfacecolor=surfacecolor,
                text=surfacecolor,
                coloraxis=coloraxis,
                name=name,
            ),
            **kwargs,
        )
    return fig

253 

254 

def comparison_plot3D(originals, downscaled_volumes, upscaled_volumes, titles):
    """
    Build a grid of 3D plots comparing original, downscaled and upscaled volumes.

    Row ``i`` shows the outer faces (via ``add_volume_face_traces``) of
    ``originals[i]``, ``downscaled_volumes[i]`` and ``upscaled_volumes[i]`` in
    columns 1-3. ``titles[i]`` is written vertically to the left of the row.
    Every volume may be given as an array or as a path (``str`` / ``Path``),
    which is loaded with ``read3D``.

    All faces share the grey ``coloraxis``, fixed to the range [0, 1] with
    its colour bar hidden. Every scene has its grid, tick labels, axis titles
    and background turned off.

    Returns:
        plotly.graph_objects.Figure: The comparison figure.
    """
    n_rows = len(originals)
    n_cols = 3

    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=(
            "Original",
            "Downscaled",
            "Upscaled",
            # "Difference",
        ),
        vertical_spacing=0.02,
        horizontal_spacing=0.02,
        # Surfaces need 3D "scene" subplots rather than the default 2D axes.
        specs=[[{"type": "surface"} for _ in range(n_cols)] for _ in range(n_rows)],
    )

    rows = zip(originals, downscaled_volumes, upscaled_volumes, titles)
    for row, (original, downscaled, upscaled, title) in enumerate(rows):
        volumes = [
            read3D(volume) if isinstance(volume, (str, Path)) else volume
            for volume in (original, downscaled, upscaled)
        ]
        for col, volume in enumerate(volumes, start=1):
            add_volume_face_traces(fig, volume, row=row + 1, col=col)
        # A fourth "Difference" column (upscaled - original on "coloraxis2")
        # is currently disabled.

        fig.add_annotation(
            text=title,
            xref="paper",
            yref="paper",
            x=0.0,
            # Vertical centre of this row in paper coordinates (row 0 at the top).
            y=1.0 - (row + 0.5) / n_rows,
            showarrow=False,
            xanchor="right",
            yanchor="middle",
            textangle=-90,
        )

    # Strip the axes from every scene. make_subplots creates one scene per
    # cell (n_cols per row); update_scenes reaches all of them without
    # computing scene ids by hand.
    axis = dict(showgrid=False, showticklabels=False, showaxeslabels=False, title="", showbackground=False)
    fig.update_scenes(xaxis=axis, yaxis=axis, zaxis=axis)

    fig.update_layout(coloraxis=dict(colorscale='gray', cmin=0.0, cmax=1.0, showscale=False), showlegend=False)
    # coloraxis2 is reserved for the (disabled) difference column.
    fig.update_layout(coloraxis2=dict(colorscale='Rainbow'), showlegend=False)

    fig.update_layout(
        height=250 + 200 * n_rows,
        width=1200,
    )
    format_fig(fig)
    return fig