Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from pathlib import Path
2import numpy as np
3import plotly.graph_objects as go
4import PIL
5from PIL import Image
6from plotly.subplots import make_subplots
7import numpy as np
9from .transforms import read3D
def format_fig(fig):
    """
    Apply the shared figure styling to ``fig`` in place.

    Sets a white plot background, black title text, and an 18 pt black
    "Linux Libertine Display O" font through a single ``fig.update_layout``
    call. Nothing is returned.
    """
    shared_font = {"family": "Linux Libertine Display O", "size": 18, "color": "black"}
    fig.update_layout(plot_bgcolor="white", title_font_color="black", font=shared_font)
def render_volume(volume, width: int = 600, height: int = 600, title: str = "Volume") -> go.Figure:
    """
    Render a volume as an animated stack of horizontal cross-sections.

    Slice ``k`` of ``volume`` is drawn as a flat ``go.Surface`` at height
    ``z = k``, coloured by that slice's values (flipped vertically with
    ``np.flipud``) on a gray colour scale fixed to ``[0, 1]``. Each slice is
    also an animation frame; a play/pause button pair and a slider step
    through the frames.

    Based on https://plotly.com/python/visualizing-mri-volume-slices/ by Emilia Petrisor

    Args:
        volume (ndarray): The volume to be rendered, indexed as ``volume[k]`` for
            slice ``k``. All slices must share the 2D shape of ``volume[0]``.
            The colour range is fixed to ``[0, 1]``, so values are presumably
            normalised to that interval; values outside it saturate.
        width (int, optional): The width of the figure. Defaults to 600.
        height (int, optional): The height of the figure. Defaults to 600.
        title (str, optional): The title of the figure. Defaults to 'Volume'.

    Returns:
        plotly.graph_objects.Figure: A plotly figure representing the volume.
    """
    r, c = volume[0].shape
    nb_frames = len(volume)

    fig = go.Figure(
        frames=[
            go.Frame(
                data=go.Surface(
                    z=k * np.ones((r, c)),
                    surfacecolor=np.flipud(volume[k]),
                    # Repeat the colour scale so every frame is self-contained.
                    colorscale="Gray",
                    cmin=0,
                    cmax=1.0,
                ),
                # you need to name the frame for the animation to behave properly
                name=str(k),
            )
            for k in range(nb_frames)
        ]
    )

    # Add data to be displayed before animation starts
    fig.add_trace(
        go.Surface(
            z=0.0 * np.ones((r, c)),
            surfacecolor=np.flipud(volume[0]),
            colorscale="Gray",
            cmin=0,
            cmax=1.0,
            colorbar=dict(thickness=20, ticklen=4),
        )
    )

    def frame_args(duration):
        """Animation arguments shared by the slider steps and the buttons."""
        return {
            "frame": {"duration": duration},
            "mode": "immediate",
            "fromcurrent": True,
            "transition": {"duration": duration, "easing": "linear"},
        }

    # One slider step per named frame; selecting a step jumps straight to it.
    sliders = [
        {
            "pad": {"b": 10, "t": 60},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": [
                {
                    "args": [[f.name], frame_args(0)],
                    "label": str(k),
                    "method": "animate",
                }
                for k, f in enumerate(fig.frames)
            ],
        }
    ]

    # Layout: fixed z range so the camera does not rescale between frames.
    fig.update_layout(
        title=title,
        width=width,
        height=height,
        scene=dict(
            zaxis=dict(range=[0, nb_frames - 1], autorange=False),
            aspectratio=dict(x=1, y=1, z=1),
        ),
        updatemenus=[
            {
                "buttons": [
                    {
                        "args": [None, frame_args(50)],
                        "label": "▶",  # play symbol
                        "method": "animate",
                    },
                    {
                        "args": [[None], frame_args(0)],
                        "label": "◼",  # pause symbol
                        "method": "animate",
                    },
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 70},
                "type": "buttons",
                "x": 0.1,
                "y": 0,
            }
        ],
        sliders=sliders,
    )

    return fig
123 # upscaled_images = app(
124 # items=downscaled_images,
125 # item_dir=[],
126 # pretrained=str(pretrained),
127 # width=500,
128 # height=500,
129 # return_images=True,
130 # )
def comparison_plot(originals, downscaled_images, upscaled_images, titles, crops):
    """
    Build a grid figure comparing original, downscaled and upscaled images.

    Each input row produces one figure row with five panels:

    1. the full original image, with the crop region outlined in red;
    2. the original image zoomed to the crop;
    3. the downscaled image, resized back to the original size with
       nearest-neighbour resampling, zoomed to the crop;
    4. the upscaled image, zoomed to the crop;
    5. a heatmap of ``(upscaled - original) / 255`` on a single channel,
       zoomed to the crop.

    Args:
        originals (sequence): Original images, as anything ``PIL.Image.open`` accepts.
        downscaled_images (sequence): Downscaled images, as anything ``PIL.Image.open`` accepts.
        upscaled_images (sequence): Upscaled images, either paths (``str``/``Path``)
            or already-opened ``PIL.Image.Image`` objects.
        titles (sequence of str): Row titles, shown as the y-axis title of column 1.
        crops (sequence): Per-row crop boxes; ``crop[0:2]`` is the x range and
            ``crop[2:4]`` the y range, in pixels.

    Returns:
        plotly.graph_objects.Figure: The comparison figure.
    """
    n_cols = 5
    fig = make_subplots(
        rows=len(originals),
        cols=n_cols,
        subplot_titles=(
            "Original",
            "Cropped",
            "Downscaled",
            "Upscaled",
            "Difference",
        ),
        vertical_spacing=0.02,
        horizontal_spacing=0.02,
    )

    rows = zip(originals, downscaled_images, upscaled_images, titles, crops)
    for row, (original, downscaled, upscaled, title, crop) in enumerate(rows):
        plot_row = row + 1

        original_im = Image.open(original)
        # Converted once: used for panels 1 and 2 and as the difference reference.
        original_rgb = np.asarray(original_im.convert("RGB"))
        # Resize to the original size (nearest neighbour, no interpolation) so all panels align.
        downscaled_im = Image.open(downscaled).resize(
            original_im.size, resample=PIL.Image.Resampling.NEAREST
        )

        if isinstance(upscaled, (Path, str)):
            upscaled = Image.open(upscaled)

        # The difference compares one channel: the red channel of the original
        # against the first channel of the upscaled image. Reducing a
        # multi-channel (e.g. RGB) upscaled array to its first channel keeps this
        # a 2D - 2D subtraction; a single-channel upscaled image is used as-is.
        upscaled_values = np.asarray(upscaled).astype(int)
        if upscaled_values.ndim == 3:
            upscaled_values = upscaled_values[:, :, 0]
        difference = upscaled_values - original_rgb[:, :, 0].astype(int)

        fig.add_trace(go.Image(z=original_rgb), row=plot_row, col=1)
        fig.add_trace(go.Image(z=original_rgb), row=plot_row, col=2)
        fig.add_trace(go.Image(z=np.asarray(downscaled_im.convert("RGB"))), row=plot_row, col=3)
        fig.add_trace(go.Image(z=np.asarray(upscaled.convert("RGB")).astype(int)), row=plot_row, col=4)
        fig.add_trace(
            go.Heatmap(z=difference.astype(float) / 255, coloraxis="coloraxis"),
            row=plot_row,
            col=5,
        )

        # This row's axes are numbered row * n_cols + 1 .. row * n_cols + n_cols.
        first_axis = row * n_cols + 1
        crop_x = (crop[0], crop[1])
        # Given high-to-low, presumably so zoomed panels keep image orientation
        # (pixel row 0 at the top) — confirm.
        crop_y = (crop[3], crop[2])
        layout_update = {f"yaxis{first_axis}_title": title}
        for axis_number in range(first_axis + 1, first_axis + n_cols):
            layout_update[f"xaxis{axis_number}_range"] = crop_x
            layout_update[f"yaxis{axis_number}_range"] = crop_y
        fig.update_layout(**layout_update)

        # Outline the crop region on the full-size original.
        fig.add_shape(
            type="rect",
            x0=crop_x[0],
            y0=crop_y[0],
            x1=crop_x[1],
            y1=crop_y[1],
            line=dict(color="Red"),
            row=plot_row,
            col=1,
        )

    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.update_layout(
        height=150 + 200 * len(originals),
        width=1200,
    )
    # NOTE(review): format_fig below sets plot_bgcolor="white", which overrides
    # this transparent background; confirm which of the two is intended.
    fig.update_layout(plot_bgcolor="rgba(0,0,0,0)")
    format_fig(fig)

    fig.update_layout(coloraxis=dict(colorscale="Rainbow"), showlegend=False)
    fig.update_annotations(font_size=24)

    return fig
def add_volume_face_traces(fig, volume, coloraxis="coloraxis", **kwargs):
    """
    Add the six outer faces of a 3D volume to a plotly figure.

    Each face is a ``go.Surface`` on one boundary plane of the volume's index
    grid (x = 0 / max, y = 0 / max, z = 0 / max). It is coloured by the voxel
    values lying on that plane, which are also used as hover text. Traces are
    named "left", "right", "back", "front", "bottom", "top" and added in that
    order.

    Args:
        fig (plotly.graph_objects.Figure): Figure the traces are added to (modified in place).
        volume (ndarray): 3D array indexed as ``volume[x, y, z]``.
        coloraxis (str, optional): Colour axis shared by all six faces. Defaults to "coloraxis".
        **kwargs: Forwarded to ``fig.add_trace`` (e.g. ``row``/``col`` for subplots).

    Returns:
        plotly.graph_objects.Figure: ``fig``.
    """
    nx, ny, nz = volume.shape
    xs = np.arange(nx)
    ys = np.arange(ny)
    zs = np.arange(nz)

    def add_face(x, y, z, surfacecolor, name):
        """Add one face as a Surface trace on the shared colour axis."""
        fig.add_trace(
            go.Surface(
                x=x, y=y, z=z,
                surfacecolor=surfacecolor, text=surfacecolor,
                coloraxis=coloraxis, name=name,
            ),
            **kwargs,
        )

    # plotly Surface grid convention: with 1D x of length m and y of length n,
    # z and surfacecolor are (n, m) arrays and entry [i, j] sits at (x[j], y[i]).
    # Every grid below is sized accordingly, so non-cubic volumes work too.

    # LEFT (x = 0) and RIGHT (x = max): grid columns run along z, rows along y.
    z_on_yz = np.tile(zs, (ny, 1))  # (ny, nz); entry [i, j] = j
    for x_index, name in ((0, "left"), (nx - 1, "right")):
        add_face(np.full(nz, x_index), ys, z_on_yz, volume[x_index, :, :], name)

    # BACK (y = 0) and FRONT (y = max): grid columns run along x, rows along z.
    z_on_xz = np.tile(zs[:, np.newaxis], (1, nx))  # (nz, nx); entry [i, j] = i
    for y_index, name in ((0, "back"), (ny - 1, "front")):
        add_face(xs, np.full(nz, y_index), z_on_xz, volume[:, y_index, :].T, name)

    # BOTTOM (z = 0) and TOP (z = max): flat planes over the (x, y) grid.
    for z_index, name in ((0, "bottom"), (nz - 1, "top")):
        add_face(xs, ys, np.full((ny, nx), z_index), volume[:, :, z_index].T, name)

    return fig
def comparison_plot3D(originals, downscaled_volumes, upscaled_volumes, titles):
    """
    Build a grid of 3D scenes comparing original, downscaled and upscaled volumes.

    Each row shows the six outer faces (via ``add_volume_face_traces``) of the
    original, downscaled and upscaled volume in three 3D scenes whose axes,
    grids and backgrounds are hidden. The row title is written vertically at
    the left edge. All faces share ``coloraxis``: gray, fixed to ``[0, 1]``,
    with no colour bar.

    Args:
        originals (sequence): Original volumes, as arrays or paths (``str``/``Path``) loaded with ``read3D``.
        downscaled_volumes (sequence): Downscaled volumes, same accepted forms.
        upscaled_volumes (sequence): Upscaled volumes, same accepted forms.
        titles (sequence of str): Row titles.

    Returns:
        plotly.graph_objects.Figure: The comparison figure.
    """
    n_rows = len(originals)
    n_cols = 3
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=(
            "Original",
            "Downscaled",
            "Upscaled",
        ),
        vertical_spacing=0.02,
        horizontal_spacing=0.02,
        # Fresh spec dicts per cell: every cell hosts a 3D scene.
        specs=[[{"type": "surface"} for _ in range(n_cols)] for _ in range(n_rows)],
    )

    # Hide grid lines, tick labels, axis titles and background panes.
    axis = dict(showgrid=False, showticklabels=False, showaxeslabels=False, title="", showbackground=False)
    scene = dict(
        xaxis=axis,
        yaxis=axis,
        zaxis=axis,
    )

    rows = zip(originals, downscaled_volumes, upscaled_volumes, titles)
    for row, (original, downscaled, upscaled, title) in enumerate(rows):
        for col, volume in enumerate((original, downscaled, upscaled), start=1):
            if isinstance(volume, (str, Path)):
                volume = read3D(volume)
            add_volume_face_traces(fig, volume, row=row + 1, col=col)

        # Scenes are numbered row-major: row r, column c -> scene{r * n_cols + c}.
        scenes = {f"scene{row * n_cols + col}": scene for col in range(1, n_cols + 1)}
        fig.update_layout(**scenes)

        fig.add_annotation(
            text=title,
            xref="paper",
            yref="paper",
            x=0.0,
            # Vertical centre of this row's band of the figure.
            y=1.0 - 1.0 * row / n_rows - 0.5 / n_rows,
            showarrow=False,
            xanchor="right",
            yanchor="middle",
            textangle=-90,
        )

    fig.update_layout(coloraxis=dict(colorscale="gray", cmin=0.0, cmax=1.0, showscale=False), showlegend=False)
    # coloraxis2 is not referenced by any trace added here; presumably kept for
    # a (currently disabled) difference column — confirm before removing.
    fig.update_layout(coloraxis2=dict(colorscale="Rainbow"), showlegend=False)

    fig.update_layout(
        height=250 + 200 * n_rows,
        width=1200,
    )
    format_fig(fig)
    return fig