Coverage for psyop/viz.py: 38.34%
1158 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-10 06:02 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-10 06:02 +0000
1# -*- coding: utf-8 -*-
2from pathlib import Path
3import numpy as np
4import pandas as pd
5import xarray as xr
7import plotly.graph_objects as go
8from plotly.subplots import make_subplots
9from plotly.colors import get_colorscale, sample_colorscale
11from .model import (
12 kernel_diag_m52,
13 kernel_m52_ard,
14 add_jitter,
15 solve_chol,
16 solve_lower,
17 feature_raw_from_artifact_or_reconstruct,
18)
19from . import opt
def _canon_key_set(ds) -> dict[str, str]:
    """Map every accepted spelling of a feature name to the stored feature name.

    Two spellings are accepted for each feature read from
    ``ds["feature"].values``:

    * the exact name as stored, and
    * a normalized alias: lower-cased with every run of characters outside
      ``[a-z0-9]`` removed (e.g. ``"Batch Size"`` -> ``"batchsize"``).

    Args:
        ds: Mapping-like object whose ``"feature"`` entry exposes
            ``.values.tolist()`` yielding the feature names.

    Returns:
        Dict from accepted key (exact name or normalized alias) to the exact
        feature name. An exact feature name always maps to itself, even when
        another feature's normalized alias is the same string. When several
        features share an alias that is not itself an exact name, the last
        feature in ``ds["feature"]`` wins.
    """
    import re

    non_alnum = re.compile(r"[^a-z0-9]+")
    feats = [str(x) for x in ds["feature"].values.tolist()]

    exact = {f: f for f in feats}
    aliases = {non_alnum.sub("", f.lower()): f for f in feats}

    # Only add aliases that do not shadow an exact name; otherwise a feature
    # called "lr" would be redirected to e.g. "L_R", whose alias is also "lr".
    return {**exact, **{k: v for k, v in aliases.items() if k not in exact}}
def _edges_from_centers(vals: np.ndarray, is_log: bool) -> tuple[float, float]:
    """Return (min_edge, max_edge) that tightly bound a heatmap with given center coords.

    The outer edges sit half a cell beyond the first and last centers. In
    log mode the half-cell is measured in log10 space.

    Args:
        vals: Cell-center coordinates. Non-finite entries are ignored; in log
            mode non-positive entries are ignored too. Order does not matter.
        is_log: Whether the axis is logarithmic.

    Returns:
        ``(lo, hi)`` as Python floats.

        * No usable centers: ``(0.0, 1.0)``, or ``(0.1, 1.0)`` in log mode so
          that ``log10`` of both bounds is finite.
        * One distinct center ``c``: ``c`` padded by 50% (``c/1.5 .. c*1.5``
          in log mode; ``c -/+ max(|c|/2, 1e-9)`` otherwise).
        * Otherwise: first/last center extended by half of the adjacent
          cell width. In log mode ``lo`` is floored at ``1e-12`` and ``hi``
          is kept strictly above ``lo``.
    """
    v = np.asarray(vals, dtype=float)
    v = v[np.isfinite(v)]
    if is_log:
        # log10 of a non-positive center is NaN/-inf and would poison both edges.
        v = v[v > 0]
    if v.size == 0:
        return (0.1, 1.0) if is_log else (0.0, 1.0)

    # Sort so v[0]/v[-1] really are the outermost centers; descending input
    # would otherwise produce inverted (and then collapsed) bounds.
    v = np.sort(v)

    if v.size == 1 or v[0] == v[-1]:
        # A single distinct center has no neighbour to take a cell width
        # from, so pad it by a fixed fraction instead.
        center = v[0]
        if is_log:
            lo = max(center / 1.5, 1e-12)
            hi = max(center * 1.5, lo * 1.0000001)
        else:
            span = max(abs(center) * 0.5, 1e-9)
            lo, hi = center - span, center + span
        return float(lo), float(hi)

    if is_log:
        lv = np.log10(v)
        lo = 10.0 ** (lv[0] - 0.5 * (lv[1] - lv[0]))
        hi = 10.0 ** (lv[-1] + 0.5 * (lv[-1] - lv[-2]))
        lo = max(lo, 1e-12)
        hi = max(hi, lo * 1.0000001)
        return float(lo), float(hi)

    lo = v[0] - 0.5 * (v[1] - v[0])
    hi = v[-1] + 0.5 * (v[-1] - v[-2])
    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        # Padding overflowed or degenerated: fall back to the unpadded extent.
        lo, hi = v[0], v[-1]
    return float(lo), float(hi)
def _update_axis_type_and_range(
    fig, *, row: int, col: int, axis: str, centers: np.ndarray, is_log: bool
):
    """Set axis type and range to heatmap edges so tiles meet the axes exactly.

    The range comes from ``_edges_from_centers(centers, is_log)``. For a log
    axis the axis type is set to ``"log"`` and the range is passed as log10 of
    the bounds; for a linear axis only the range is set and the axis type is
    left as it is.

    Args:
        fig: Figure exposing ``update_xaxes`` / ``update_yaxes``.
        row: Subplot row forwarded to the update call.
        col: Subplot column forwarded to the update call.
        axis: ``"x"`` updates the x axis; any other value updates the y axis.
        centers: Heatmap cell-center coordinates along this axis.
        is_log: Whether the axis is logarithmic.
    """
    edge_lo, edge_hi = _edges_from_centers(centers, is_log)

    # Anything that is not "x" is treated as the y axis.
    apply_update = fig.update_xaxes if axis == "x" else fig.update_yaxes

    if is_log:
        axis_props = dict(type="log", range=[np.log10(edge_lo), np.log10(edge_hi)])
    else:
        axis_props = dict(range=[edge_lo, edge_hi])

    apply_update(**axis_props, row=row, col=col)
81def plot2d(
82 model: xr.Dataset | Path | str,
83 output: Path | None = None,
84 grid_size: int = 70,
85 use_log_scale_for_target: bool = False, # log10 colors for heatmaps
86 log_shift_epsilon: float = 1e-9,
87 colorscale: str = "RdBu",
88 show: bool = False,
89 n_contours: int = 12,
90 optimal: bool = True,
91 suggest: int = 0,
92 seed: int|None = 42,
93 width:int|None = None,
94 height:int|None = None,
95 **kwargs,
96) -> go.Figure:
97 """
98 2D Partial Dependence of E[target|success] (pairwise features), including
99 categorical variables as single axes (one row/column per base).
101 Conditioning via kwargs (original units):
102 - numeric: fixed scalar, slice(lo,hi), list/tuple choices
103 - categorical base (e.g. language="Linear A" or language=("Linear A","Linear B")):
104 * single string: fixed to that label (axis removed and clamped)
105 * list/tuple of labels: restrict the categorical axis to those labels
107 Notes:
108 * one-hot member features (e.g. language=Linear A) never appear as axes.
109 * when a categorical axis is present, we render a heatmap over category index
110 with tick labels set to the category names; data overlays use jitter.
111 """
112 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
113 pred_success, pred_loss = _build_predictors(ds)
115 # --- features & transforms
116 feature_names = [str(x) for x in ds["feature"].values.tolist()]
117 transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
118 X_mean = ds["feature_mean"].values.astype(float)
119 X_std = ds["feature_std"].values.astype(float)
120 name_to_idx = {nm: i for i, nm in enumerate(feature_names)}
122 # one-hot groups
123 groups = opt._onehot_groups(feature_names) # { base: {"labels":[...], "name_by_label":{label->member}, "members":[...] } }
124 bases = set(groups.keys())
125 onehot_member_names = {m for g in groups.values() for m in g["members"]}
127 # raw df + train design
128 df_raw = _raw_dataframe_from_dataset(ds)
129 Xn_train = ds["Xn_train"].values.astype(float)
130 n_rows = Xn_train.shape[0]
132 # --- split kwargs into numeric vs categorical (keys are canonical already when coming from CLI)
133 kw_num: dict[str, object] = {}
134 kw_cat: dict[str, object] = {}
135 for k, v in (kwargs or {}).items():
136 if k in bases:
137 kw_cat[k] = v
138 elif k in name_to_idx:
139 kw_num[k] = v
140 else:
141 # unknown/ignored
142 pass
144 # --- resolve categorical constraints:
145 # - cat_fixed[base] -> fixed single label (axis removed and clamped)
146 # - cat_allowed[base] -> labels that are allowed on that axis (if not fixed)
147 cat_fixed: dict[str, str] = {}
148 cat_allowed: dict[str, list[str]] = {}
149 for base in bases:
150 labels = list(groups[base]["labels"])
151 if base not in kw_cat:
152 cat_allowed[base] = labels # unrestricted axis (if not fixed numerically later)
153 continue
154 val = kw_cat[base]
155 if isinstance(val, str):
156 if val not in labels:
157 raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}")
158 cat_fixed[base] = val
159 elif isinstance(val, (list, tuple, set)):
160 chosen = [x for x in val if isinstance(x, str) and x in labels]
161 if not chosen:
162 raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}")
163 # if only one remains, treat as fixed; else allowed list
164 if len(chosen) == 1:
165 cat_fixed[base] = chosen[0]
166 else:
167 cat_allowed[base] = chosen
168 else:
169 raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.")
171 # --- filter rows to categorical *fixed* selections for medians/percentiles & overlays
172 row_mask = np.ones(n_rows, dtype=bool)
173 for base, label in cat_fixed.items():
174 if base in df_raw.columns:
175 row_mask &= (df_raw[base].astype("string") == pd.Series([label]*len(df_raw), dtype="string")).to_numpy()
176 else:
177 member = groups[base]["name_by_label"][label]
178 j = name_to_idx[member]
179 raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member, transforms[j]).astype(float)
180 row_mask &= (raw_j >= 0.5)
182 df_raw_f = df_raw.loc[row_mask].reset_index(drop=True) if cat_fixed else df_raw
183 Xn_train_f = Xn_train[row_mask, :] if cat_fixed else Xn_train
185 # --- numeric constraints (standardized)
186 def _orig_to_std(j: int, x, transforms, mu, sd):
187 x = np.asarray(x, dtype=float)
188 if transforms[j] == "log10":
189 x = np.where(x <= 0, np.nan, x)
190 x = np.log10(x)
191 return (x - mu[j]) / sd[j]
193 fixed_scalars_std: dict[int, float] = {}
194 range_windows_std: dict[int, tuple[float, float]] = {}
195 choice_values_std: dict[int, np.ndarray] = {}
197 for name, val in kw_num.items():
198 j = name_to_idx[name]
199 if isinstance(val, slice):
200 lo = _orig_to_std(j, float(val.start), transforms, X_mean, X_std)
201 hi = _orig_to_std(j, float(val.stop), transforms, X_mean, X_std)
202 lo, hi = float(min(lo, hi)), float(max(lo, hi))
203 range_windows_std[j] = (lo, hi)
204 elif isinstance(val, (list, tuple, np.ndarray)):
205 arr = _orig_to_std(j, np.asarray(val, dtype=float), transforms, X_mean, X_std)
206 choice_values_std[j] = np.asarray(arr, dtype=float)
207 else:
208 fixed_scalars_std[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std))
210 # --- apply categorical *fixed* selections as standardized 0/1 on their member features
211 for base, label in cat_fixed.items():
212 labels = groups[base]["labels"]
213 for lab in labels:
214 member = groups[base]["name_by_label"][lab]
215 j = name_to_idx[member]
216 raw_val = 1.0 if (lab == label) else 0.0
217 fixed_scalars_std[j] = float(_orig_to_std(j, raw_val, transforms, X_mean, X_std))
219 # --- free axes = numeric features not scalar-fixed & not one-hot members, plus categorical bases not fixed
220 free_numeric_idx = [
221 j for j, nm in enumerate(feature_names)
222 if (j not in fixed_scalars_std) and (nm not in onehot_member_names)
223 ]
224 free_cat_bases = [b for b in bases if b not in cat_fixed] # we already filtered by allowed above
226 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases]
227 if not panels:
228 raise ValueError("All features are fixed (or only single-category categoricals remain); nothing to plot.")
230 # --- base point (median in standardized space of filtered rows), then apply scalar fixes
231 base_std = np.median(Xn_train_f, axis=0)
232 for j, vstd in fixed_scalars_std.items():
233 base_std[j] = vstd
235 # --- per-feature grids (numeric) over filtered 1–99% + respecting ranges/choices
236 p01p99 = [np.percentile(Xn_train_f[:, j], [1, 99]) for j in range(len(feature_names))]
237 def _grid_std_num(j: int) -> np.ndarray:
238 p01, p99 = p01p99[j]
239 if j in choice_values_std:
240 vals = np.asarray(choice_values_std[j], dtype=float)
241 vals = vals[(vals >= p01) & (vals <= p99)]
242 return np.unique(np.sort(vals)) if vals.size else np.array([np.median(Xn_train_f[:, j])])
243 lo, hi = p01, p99
244 if j in range_windows_std:
245 rlo, rhi = range_windows_std[j]
246 lo, hi = max(lo, rlo), min(hi, rhi)
247 if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
248 hi = lo + 1e-9
249 return np.linspace(lo, hi, grid_size)
251 grids_std_num = {j: _grid_std_num(j) for j in free_numeric_idx}
253 # --- helpers for categorical evaluation ---------------------------------
254 def _std_for_member(member_name: str, raw01: float) -> float:
255 j = name_to_idx[member_name]
256 return float(_orig_to_std(j, raw01, transforms, X_mean, X_std))
258 def _apply_onehot_for_base(Xn_block: np.ndarray, base: str, label: str) -> None:
259 # set the whole block's rows to the 0/1 standardized values for this label
260 for lab in groups[base]["labels"]:
261 member = groups[base]["name_by_label"][lab]
262 j = name_to_idx[member]
263 Xn_block[:, j] = _std_for_member(member, 1.0 if lab == label else 0.0)
265 def _denorm_inv(j: int, std_vals: np.ndarray) -> np.ndarray:
266 internal = std_vals * X_std[j] + X_mean[j]
267 return _inverse_transform(transforms[j], internal)
269 # 1) Robustly detect one-hot member columns.
270 # Use both the detector output AND a fallback "base=" prefix scan,
271 # so any columns like "language=Linear A" are guaranteed to be excluded.
272 onehot_member_names: set[str] = set()
273 for base, g in groups.items():
274 # detector-known members
275 onehot_member_names.update(g["members"])
276 # prefix fallback
277 prefix = f"{base}="
278 onehot_member_names.update([nm for nm in feature_names if nm.startswith(prefix)])
280 # 2) Build panel list: keep numeric features that are not scalar-fixed AND
281 # are not one-hot members; plus categorical bases that are not fixed.
282 free_numeric_idx = [
283 j for j, nm in enumerate(feature_names)
284 if (j not in fixed_scalars_std) and (nm not in onehot_member_names)
285 ]
286 free_cat_bases = [b for b in bases if b not in cat_fixed]
288 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases]
289 if not panels:
290 raise ValueError("All features are fixed (or only single-category categoricals remain); nothing to plot.")
292 # 3) Sanity check: no one-hot member should survive as a numeric panel.
293 assert all(
294 (feature_names[key] not in onehot_member_names) if kind == "num" else True
295 for kind, key in panels
296 ), "internal: one-hot member leaked into numeric panels"
298 # 4) Subplot scaffold (matrix layout k x k) with clear titles.
299 def _panel_title(kind: str, key: object) -> str:
300 return feature_names[int(key)] if kind == "num" else str(key)
302 k = len(panels)
303 fig = make_subplots(
304 rows=k,
305 cols=k,
306 shared_xaxes=False,
307 shared_yaxes=False,
308 horizontal_spacing=0.03,
309 vertical_spacing=0.03,
310 subplot_titles=[_panel_title(kind, key) for kind, key in panels],
311 )
313 # (Keep the rest of your cell-evaluation and rendering logic unchanged.
314 # Because we filtered `onehot_member_names`, rows/columns like
315 # "language=Linear A" / "language=Linear B" will no longer appear.
316 # Categorical bases (e.g., "language") will show as a single axis.)
319 # overlays prepared under the SAME constraints (pass original kwargs straight through)
320 optimal_df = opt.optimal(ds, count=1, seed=seed, **kwargs) if optimal else None
321 suggest_df = opt.suggest(ds, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None
323 # masks for data overlays (already filtered if cat_fixed)
324 tgt_col = str(ds.attrs["target"])
325 success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy()
326 fail_mask = ~success_mask
328 # collect Z blocks for global color bounds
329 all_blocks: list[np.ndarray] = []
330 cell_payload: dict[tuple[int,int], dict] = {}
332 # --- build each cell payload (numeric/num, cat/num, num/cat, cat/cat)
333 for r, (kind_r, key_r) in enumerate(panels):
334 for c, (kind_c, key_c) in enumerate(panels):
335 # X axis = column; Y axis = row
336 if kind_r == "num" and kind_c == "num":
337 i = int(key_r); j = int(key_c)
338 xg = grids_std_num[j]; yg = grids_std_num[i]
339 if i == j:
340 grid = grids_std_num[j]
341 Xn_1d = np.repeat(base_std[None, :], len(grid), axis=0)
342 Xn_1d[:, j] = grid
343 mu_1d, _ = pred_loss(Xn_1d, include_observation_noise=True)
344 p_1d = pred_success(Xn_1d)
345 Zmu = 0.5 * (mu_1d[:, None] + mu_1d[None, :])
346 Zp = np.minimum(p_1d[:, None], p_1d[None, :])
347 x_orig = _denorm_inv(j, grid)
348 y_orig = x_orig
349 else:
350 XX, YY = np.meshgrid(xg, yg)
351 Xn_grid = np.repeat(base_std[None, :], XX.size, axis=0)
352 Xn_grid[:, j] = XX.ravel()
353 Xn_grid[:, i] = YY.ravel()
354 mu_flat, _ = pred_loss(Xn_grid, include_observation_noise=True)
355 p_flat = pred_success(Xn_grid)
356 Zmu = mu_flat.reshape(YY.shape)
357 Zp = p_flat.reshape(YY.shape)
358 x_orig = _denorm_inv(j, xg)
359 y_orig = _denorm_inv(i, yg)
360 cell_payload[(r, c)] = dict(kind=("num","num"), i=i, j=j, x=x_orig, y=y_orig, Zmu=Zmu, Zp=Zp)
362 elif kind_r == "cat" and kind_c == "num":
363 base = str(key_r); j = int(key_c)
364 labels = list(cat_allowed.get(base, groups[base]["labels"]))
365 xg = grids_std_num[j]
366 # build rows per label
367 Zmu_rows = []; Zp_rows = []
368 for lab in labels:
369 Xn_grid = np.repeat(base_std[None, :], len(xg), axis=0)
370 Xn_grid[:, j] = xg
371 _apply_onehot_for_base(Xn_grid, base, lab)
372 mu_row, _ = pred_loss(Xn_grid, include_observation_noise=True)
373 p_row = pred_success(Xn_grid)
374 Zmu_rows.append(mu_row[None, :])
375 Zp_rows.append(p_row[None, :])
376 Zmu = np.concatenate(Zmu_rows, axis=0) # (n_labels, n_x)
377 Zp = np.concatenate(Zp_rows, axis=0)
378 x_orig = _denorm_inv(j, xg)
379 y_cats = labels # categorical ticks
380 cell_payload[(r,c)] = dict(kind=("cat","num"), base=base, j=j, x=x_orig, y=y_cats, Zmu=Zmu, Zp=Zp)
382 elif kind_r == "num" and kind_c == "cat":
383 i = int(key_r); base = str(key_c)
384 labels = list(cat_allowed.get(base, groups[base]["labels"]))
385 yg = grids_std_num[i]
386 # columns per label
387 Zmu_cols = []; Zp_cols = []
388 for lab in labels:
389 Xn_grid = np.repeat(base_std[None, :], len(yg), axis=0)
390 Xn_grid[:, i] = yg
391 _apply_onehot_for_base(Xn_grid, base, lab)
392 mu_col, _ = pred_loss(Xn_grid, include_observation_noise=True)
393 p_col = pred_success(Xn_grid)
394 Zmu_cols.append(mu_col[:, None])
395 Zp_cols.append(p_col[:, None])
396 Zmu = np.concatenate(Zmu_cols, axis=1) # (n_y, n_labels)
397 Zp = np.concatenate(Zp_cols, axis=1)
398 x_cats = labels
399 y_orig = _denorm_inv(i, yg)
400 cell_payload[(r,c)] = dict(kind=("num","cat"), i=i, base=base, x=x_cats, y=y_orig, Zmu=Zmu, Zp=Zp)
402 else: # kind_r == "cat" and kind_c == "cat"
403 base_r = str(key_r); base_c = str(key_c)
404 labels_r = list(cat_allowed.get(base_r, groups[base_r]["labels"]))
405 labels_c = list(cat_allowed.get(base_c, groups[base_c]["labels"]))
406 Z = np.zeros((len(labels_r), len(labels_c)), dtype=float)
407 P = np.zeros_like(Z)
408 # evaluate each pair
409 for rr, lab_r in enumerate(labels_r):
410 for cc, lab_c in enumerate(labels_c):
411 Xn_grid = base_std[None, :].copy()
412 _apply_onehot_for_base(Xn_grid, base_r, lab_r)
413 _apply_onehot_for_base(Xn_grid, base_c, lab_c)
414 mu_val, _ = pred_loss(Xn_grid, include_observation_noise=True)
415 p_val = pred_success(Xn_grid)
416 Z[rr, cc] = float(mu_val[0])
417 P[rr, cc] = float(p_val[0])
418 cell_payload[(r,c)] = dict(kind=("cat","cat"), x=labels_c, y=labels_r, Zmu=Z, Zp=P)
420 all_blocks.append(cell_payload[(r,c)]["Zmu"].ravel())
422 # --- color transform bounds
423 def _color_xform(z_raw: np.ndarray) -> tuple[np.ndarray, float]:
424 if not use_log_scale_for_target:
425 return z_raw, 0.0
426 zmin = float(np.nanmin(z_raw))
427 shift = 0.0 if zmin > 0 else -zmin + float(log_shift_epsilon)
428 return np.log10(np.maximum(z_raw + shift, log_shift_epsilon)), shift
430 z_all = np.concatenate(all_blocks) if all_blocks else np.array([0.0, 1.0])
431 z_all_t, global_shift = _color_xform(z_all)
432 cmin_t = float(np.nanmin(z_all_t))
433 cmax_t = float(np.nanmax(z_all_t))
434 cs = get_colorscale(colorscale)
436 def _contour_line_color(level_raw: float) -> str:
437 zt = np.log10(max(level_raw + global_shift, log_shift_epsilon)) if use_log_scale_for_target else level_raw
438 t = 0.5 if cmax_t == cmin_t else (zt - cmin_t) / (cmax_t - cmin_t)
439 rgb = sample_colorscale(cs, [float(np.clip(t, 0.0, 1.0))])[0]
440 r, g, b = _rgb_string_to_tuple(rgb)
441 lum = (0.2126*r + 0.7152*g + 0.0722*b)/255.0
442 grey = int(round((1.0 - lum) * 255))
443 return f"rgba({grey},{grey},{grey},0.9)"
445 # --- render cells
446 def _is_log_feature(j: int) -> bool: return (transforms[j] == "log10")
448 for (r, c), PAY in cell_payload.items():
449 kind = PAY["kind"]; Zmu_raw = PAY["Zmu"]; Zp = PAY["Zp"]
450 Z_t, _ = _color_xform(Zmu_raw)
452 # axes values (numeric arrays or category indices)
453 if kind == ("num","num"):
454 x_vals = PAY["x"]; y_vals = PAY["y"]
455 fig.add_trace(go.Heatmap(
456 x=x_vals, y=y_vals, z=Z_t,
457 coloraxis="coloraxis", zsmooth=False, showscale=False,
458 hovertemplate=(f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
459 f"{feature_names[PAY['i']]}: %{{y:.6g}}"
460 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"),
461 customdata=Zmu_raw
462 ), row=r+1, col=c+1)
464 # p(success) shading + contours
465 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)):
466 mask = np.where(Zp < thr, 1.0, np.nan)
467 fig.add_trace(go.Heatmap(
468 x=x_vals, y=y_vals, z=mask, zmin=0, zmax=1,
469 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]],
470 showscale=False, hoverinfo="skip"
471 ), row=r+1, col=c+1)
473 # contour lines
474 zmin_r, zmax_r = float(np.nanmin(Zmu_raw)), float(np.nanmax(Zmu_raw))
475 levels = np.linspace(zmin_r, zmax_r, max(n_contours, 2))
476 for lev in levels:
477 color = _contour_line_color(lev)
478 fig.add_trace(go.Contour(
479 x=x_vals, y=y_vals, z=Zmu_raw,
480 autocontour=False,
481 contours=dict(coloring="lines", showlabels=False, start=lev, end=lev, size=1e-9),
482 line=dict(width=1),
483 colorscale=[[0, color], [1, color]],
484 showscale=False, hoverinfo="skip"
485 ), row=r+1, col=c+1)
487 # data overlays (success/fail)
488 def _data_vals_for_feature(j_full: int) -> np.ndarray:
489 nm = feature_names[j_full]
490 if nm in df_raw_f.columns:
491 return df_raw_f[nm].to_numpy().astype(float)
492 return feature_raw_from_artifact_or_reconstruct(ds, j_full, nm, transforms[j_full]).astype(float)[row_mask] \
493 if cat_fixed else \
494 feature_raw_from_artifact_or_reconstruct(ds, j_full, nm, transforms[j_full]).astype(float)
496 xd = _data_vals_for_feature(PAY["j"])
497 yd = _data_vals_for_feature(PAY["i"])
498 show_leg = (r == 0 and c == 0)
499 fig.add_trace(go.Scattergl(
500 x=xd[success_mask], y=yd[success_mask], mode="markers",
501 marker=dict(size=4, color="black", line=dict(width=0)),
502 name="data (success)", legendgroup="data_succ", showlegend=show_leg,
503 hovertemplate=("trial_id: %{customdata[0]}<br>"
504 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
505 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>"
506 f"{tgt_col}: %{{customdata[1]:.4f}}<extra></extra>"),
507 customdata=np.column_stack([
508 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask],
509 df_raw_f[tgt_col].to_numpy()[success_mask],
510 ])
511 ), row=r+1, col=c+1)
512 fig.add_trace(go.Scattergl(
513 x=xd[fail_mask], y=yd[fail_mask], mode="markers",
514 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)),
515 name="data (failed)", legendgroup="data_fail", showlegend=show_leg,
516 hovertemplate=("trial_id: %{customdata}<br>"
517 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
518 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>"
519 "status: failed (NaN target)<extra></extra>"),
520 customdata=df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask]
521 ), row=r+1, col=c+1)
523 # overlays (optimal/suggest) on numeric axes only
524 if optimal and (optimal_df is not None):
525 if feature_names[PAY["j"]] in optimal_df.columns and feature_names[PAY["i"]] in optimal_df.columns:
526 ox = np.asarray(optimal_df[feature_names[PAY["j"]]].values, dtype=float)
527 oy = np.asarray(optimal_df[feature_names[PAY["i"]]].values, dtype=float)
528 if np.isfinite(ox).all() and np.isfinite(oy).all():
529 pmu = float(optimal_df["pred_target_mean"].values[0])
530 psd = float(optimal_df["pred_target_sd"].values[0])
531 fig.add_trace(go.Scattergl(
532 x=ox, y=oy, mode="markers",
533 marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
534 name="optimal", legendgroup="optimal", showlegend=(r == 0 and c == 0),
535 hovertemplate=(f"predicted: {pmu:.2g} ± {psd:.2g}<br>"
536 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
537 f"{feature_names[PAY['i']]}: %{{y:.6g}}<extra></extra>")
538 ), row=r+1, col=c+1)
539 if suggest and (suggest_df is not None):
540 have = (feature_names[PAY["j"]] in suggest_df.columns) and (feature_names[PAY["i"]] in suggest_df.columns)
541 if have:
542 sx = np.asarray(suggest_df[feature_names[PAY["j"]]].values, dtype=float)
543 sy = np.asarray(suggest_df[feature_names[PAY["i"]]].values, dtype=float)
544 keep_s = np.isfinite(sx) & np.isfinite(sy)
545 if keep_s.any():
546 sx, sy = sx[keep_s], sy[keep_s]
547 mu_s = suggest_df.loc[keep_s, "pred_target_mean"].values if "pred_target_mean" in suggest_df else None
548 sd_s = suggest_df.loc[keep_s, "pred_target_sd"].values if "pred_target_sd" in suggest_df else None
549 ps_s = suggest_df.loc[keep_s, "pred_p_success"].values if "pred_p_success" in suggest_df else None
550 if (mu_s is not None) and (sd_s is not None) and (ps_s is not None):
551 custom_s = np.column_stack([mu_s, sd_s, ps_s])
552 hover_s = (
553 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
554 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>"
555 "pred: %{customdata[0]:.3g} ± %{customdata[1]:.3g}<br>"
556 "p(success): %{customdata[2]:.2f}<extra>suggested</extra>"
557 )
558 else:
559 custom_s = None
560 hover_s = (
561 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
562 f"{feature_names[PAY['i']]}: %{{y:.6g}}<extra>suggested</extra>"
563 )
564 fig.add_trace(go.Scattergl(
565 x=sx, y=sy, mode="markers",
566 marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
567 name="suggested", legendgroup="suggested",
568 showlegend=(r == 0 and c == 0),
569 customdata=custom_s, hovertemplate=hover_s
570 ), row=r+1, col=c+1)
572 # axis types/ranges
573 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="x", centers=x_vals, is_log=_is_log_feature(PAY["j"]))
574 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="y", centers=y_vals, is_log=_is_log_feature(PAY["i"]))
576 elif kind == ("cat","num"):
577 base = PAY["base"]; x_vals = PAY["x"]; labels = PAY["y"]
578 nlab = len(labels)
579 # heatmap (categories on Y)
580 fig.add_trace(go.Heatmap(
581 x=x_vals, y=np.arange(nlab), z=Z_t,
582 coloraxis="coloraxis", zsmooth=False, showscale=False,
583 hovertemplate=(f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
584 f"{base}: %{{text}}"
585 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"),
586 text=np.array(labels)[:, None].repeat(len(x_vals), axis=1),
587 customdata=Zmu_raw
588 ), row=r+1, col=c+1)
589 # p(success) shading
590 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)):
591 mask = np.where(Zp < thr, 1.0, np.nan)
592 fig.add_trace(go.Heatmap(
593 x=x_vals, y=np.arange(nlab), z=mask, zmin=0, zmax=1,
594 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]],
595 showscale=False, hoverinfo="skip"
596 ), row=r+1, col=c+1)
597 # categorical ticks
598 fig.update_yaxes(tickmode="array", tickvals=list(range(nlab)), ticktext=labels, row=r+1, col=c+1)
599 # data overlays: numeric vs categorical with jitter on Y
600 if base in df_raw_f.columns and feature_names[PAY["j"]] in df_raw_f.columns:
601 cat_series = df_raw_f[base].astype("string")
602 cat_to_idx = {lab: i for i, lab in enumerate(labels)}
603 y_map = cat_series.map(cat_to_idx)
604 ok = y_map.notna().to_numpy()
605 y_idx = y_map.to_numpy(dtype=float)
606 jitter = 0.10 * (np.random.default_rng(0).standard_normal(size=len(y_idx)))
607 yj = y_idx + jitter
608 xd = df_raw_f[feature_names[PAY["j"]]].to_numpy(dtype=float)
609 show_leg = (r == 0 and c == 0)
610 fig.add_trace(go.Scattergl(
611 x=xd[success_mask & ok], y=yj[success_mask & ok], mode="markers",
612 marker=dict(size=4, color="black", line=dict(width=0)),
613 name="data (success)", legendgroup="data_succ", showlegend=show_leg,
614 hovertemplate=("trial_id: %{customdata[0]}<br>"
615 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
616 f"{base}: %{{customdata[1]}}<br>"
617 f"{tgt_col}: %{{customdata[2]:.4f}}<extra></extra>"),
618 customdata=np.column_stack([
619 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask & ok],
620 cat_series.to_numpy()[success_mask & ok],
621 df_raw_f[tgt_col].to_numpy()[success_mask & ok],
622 ])
623 ), row=r+1, col=c+1)
624 fig.add_trace(go.Scattergl(
625 x=xd[fail_mask & ok], y=yj[fail_mask & ok], mode="markers",
626 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)),
627 name="data (failed)", legendgroup="data_fail", showlegend=show_leg,
628 hovertemplate=("trial_id: %{customdata[0]}<br>"
629 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
630 f"{base}: %{{customdata[1]}}<br>"
631 "status: failed (NaN target)<extra></extra>"),
632 customdata=np.column_stack([
633 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask & ok],
634 cat_series.to_numpy()[fail_mask & ok],
635 ])
636 ), row=r+1, col=c+1)
637 # axes: x numeric; y categorical range
638 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="x", centers=x_vals, is_log=_is_log_feature(PAY["j"]))
639 fig.update_yaxes(range=[-0.5, nlab - 0.5], row=r+1, col=c+1)
641 elif kind == ("num","cat"):
642 base = PAY["base"]; y_vals = PAY["y"]; labels = PAY["x"]
643 nlab = len(labels)
644 # heatmap (categories on X)
645 fig.add_trace(go.Heatmap(
646 x=np.arange(nlab), y=y_vals, z=Z_t,
647 coloraxis="coloraxis", zsmooth=False, showscale=False,
648 hovertemplate=(f"{base}: %{{text}}<br>"
649 f"{feature_names[PAY['i']]}: %{{y:.6g}}"
650 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"),
651 text=np.array(labels)[None, :].repeat(len(y_vals), axis=0),
652 customdata=Zmu_raw
653 ), row=r+1, col=c+1)
654 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)):
655 mask = np.where(Zp < thr, 1.0, np.nan)
656 fig.add_trace(go.Heatmap(
657 x=np.arange(nlab), y=y_vals, z=mask, zmin=0, zmax=1,
658 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]],
659 showscale=False, hoverinfo="skip"
660 ), row=r+1, col=c+1)
661 fig.update_xaxes(tickmode="array", tickvals=list(range(nlab)), ticktext=labels, row=r+1, col=c+1)
662 # data overlays with jitter on X
663 if base in df_raw_f.columns and feature_names[PAY["i"]] in df_raw_f.columns:
664 cat_series = df_raw_f[base].astype("string")
665 cat_to_idx = {lab: i for i, lab in enumerate(labels)}
666 x_map = cat_series.map(cat_to_idx)
667 ok = x_map.notna().to_numpy()
668 x_idx = x_map.to_numpy(dtype=float)
669 jitter = 0.10 * (np.random.default_rng(0).standard_normal(size=len(x_idx)))
670 xj = x_idx + jitter
671 yd = df_raw_f[feature_names[PAY["i"]]].to_numpy(dtype=float)
672 show_leg = (r == 0 and c == 0)
673 fig.add_trace(go.Scattergl(
674 x=xj[success_mask & ok], y=yd[success_mask & ok], mode="markers",
675 marker=dict(size=4, color="black", line=dict(width=0)),
676 name="data (success)", legendgroup="data_succ", showlegend=show_leg,
677 hovertemplate=("trial_id: %{customdata[0]}<br>"
678 f"{base}: %{{customdata[1]}}<br>"
679 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>"
680 f"{tgt_col}: %{{customdata[2]:.4f}}<extra></extra>"),
681 customdata=np.column_stack([
682 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask & ok],
683 cat_series.to_numpy()[success_mask & ok],
684 df_raw_f[tgt_col].to_numpy()[success_mask & ok],
685 ])
686 ), row=r+1, col=c+1)
687 fig.add_trace(go.Scattergl(
688 x=xj[fail_mask & ok], y=yd[fail_mask & ok], mode="markers",
689 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)),
690 name="data (failed)", legendgroup="data_fail", showlegend=show_leg,
691 hovertemplate=("trial_id: %{customdata[0]}<br>"
692 f"{base}: %{{customdata[1]}}<br>"
693 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>"
694 "status: failed (NaN target)<extra></extra>"),
695 customdata=np.column_stack([
696 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask & ok],
697 cat_series.to_numpy()[fail_mask & ok],
698 ])
699 ), row=r+1, col=c+1)
700 # axes: x categorical; y numeric
701 fig.update_xaxes(range=[-0.5, nlab - 0.5], row=r+1, col=c+1)
702 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="y", centers=y_vals, is_log=_is_log_feature(PAY["i"]))
704 elif kind == ("cat","cat"):
705 labels_y = PAY["y"]
706 labels_x = PAY["x"]
707 ny, nx = len(labels_y), len(labels_x)
709 # Build customdata carrying (row_label, col_label) for hovertemplate.
710 custom = np.dstack((
711 np.array(labels_y, dtype=object)[:, None].repeat(nx, axis=1),
712 np.array(labels_x, dtype=object)[None, :].repeat(ny, axis=0),
713 ))
715 # Heatmap over categorical indices
716 fig.add_trace(go.Heatmap(
717 x=np.arange(nx),
718 y=np.arange(ny),
719 z=Z_t,
720 coloraxis="coloraxis",
721 zsmooth=False,
722 showscale=False,
723 hovertemplate=(
724 "row: %{customdata[0]}<br>"
725 "col: %{customdata[1]}<br>"
726 "E[target|success]: %{z:.3f}<extra></extra>"
727 ),
728 customdata=custom,
729 ), row=r+1, col=c+1)
731 # p(success) shading overlays
732 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)):
733 mask = np.where(Zp < thr, 1.0, np.nan)
734 fig.add_trace(go.Heatmap(
735 x=np.arange(nx),
736 y=np.arange(ny),
737 z=mask,
738 zmin=0,
739 zmax=1,
740 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]],
741 showscale=False,
742 hoverinfo="skip",
743 ), row=r+1, col=c+1)
745 # Categorical tick labels on both axes
746 fig.update_xaxes(
747 tickmode="array",
748 tickvals=list(range(nx)),
749 ticktext=labels_x,
750 range=[-0.5, nx - 0.5],
751 row=r+1,
752 col=c+1,
753 )
754 fig.update_yaxes(
755 tickmode="array",
756 tickvals=list(range(ny)),
757 ticktext=labels_y,
758 range=[-0.5, ny - 0.5],
759 row=r+1,
760 col=c+1,
761 )
763 # --- outer axis labels
764 def _panel_title(kind: str, key: object) -> str:
765 return feature_names[int(key)] if kind == "num" else str(key)
767 for c, (_, key_c) in enumerate(panels):
768 fig.update_xaxes(title_text=_panel_title(panels[c][0], key_c), row=k, col=c+1)
769 for r, (kind_r, key_r) in enumerate(panels):
770 fig.update_yaxes(title_text=_panel_title(kind_r, key_r), row=r+1, col=1)
772 # --- title
773 def _fmt_c(v):
774 if isinstance(v, slice):
775 a = f"{v.start:g}" if v.start is not None else ""
776 b = f"{v.stop:g}" if v.stop is not None else ""
777 return f"[{a},{b}]"
778 if isinstance(v, (list, tuple, np.ndarray)):
779 try:
780 return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]"
781 except Exception:
782 return "[" + ",".join(map(str, v)) + "]"
783 return str(v)
785 title_parts = [f"2D partial dependence of expected {tgt_col}"]
787 # numeric constraints shown
788 for name, val in kw_num.items():
789 title_parts.append(f"{name}={_fmt_c(val)}")
790 # categorical constraints: fixed shown as base=Label; allowed ranges omitted in title
791 for base, lab in cat_fixed.items():
792 title_parts.append(f"{base}={lab}")
793 title = " — ".join([title_parts[0], ", ".join(title_parts[1:])]) if len(title_parts) > 1 else title_parts[0]
795 # --- layout
796 cell = 250
797 z_title = "E[target|success]" + (" (log10)" if use_log_scale_for_target else "")
798 if use_log_scale_for_target and global_shift > 0:
799 z_title += f" (shift Δ={global_shift:.3g})"
801 width = width if (width and width > 0) else cell * k
802 width = max(width, 400)
803 height = height if (height and height > 0) else cell * k
804 height = max(height, 400)
806 fig.update_layout(
807 template="simple_white",
808 width=width,
809 height=height,
810 title=title,
811 legend_title_text="",
812 coloraxis=dict(
813 colorscale=colorscale,
814 cmin=cmin_t, cmax=cmax_t,
815 colorbar=dict(
816 title=z_title,
817 thickness=10, # thinner bar
818 len=0.55, # shorter bar (fraction of plot height)
819 lenmode="fraction",
820 x=1.02, y=0.5, # just right of plot, vertically centered
821 xanchor="left", yanchor="middle",
822 ),
823 ),
824 legend=dict(
825 orientation="v",
826 x=1.02, xanchor="left", # to the right of the colorbar
827 y=1.0, yanchor="top",
828 bgcolor="rgba(255,255,255,0.85)"
829 ),
830 margin=dict(t=90, r=100), # room for title + legend + colorbar
831 )
833 if output:
834 output = Path(output)
835 output.parent.mkdir(parents=True, exist_ok=True)
836 fig.write_html(str(output), include_plotlyjs="cdn")
837 if show:
838 fig.show("browser")
839 return fig
def plot1d(
    model: xr.Dataset | Path | str,
    output: Path | None = None,
    csv_out: Path | None = None,
    grid_size: int = 300,
    line_color: str = "rgb(31,119,180)",
    band_alpha: float = 0.25,
    figure_height_per_row_px: int = 320,
    show: bool = False,
    use_log_scale_for_target_y: bool = True,  # log-y for target
    log_y_epsilon: float = 1e-9,
    optimal: bool = True,
    suggest: int = 0,
    width: int | None = None,
    height: int | None = None,
    seed: int | None = 42,
    **kwargs,
) -> go.Figure:
    """
    Vertical 1D PD panels of E[target|success] vs each *free* feature.

    Constraints are passed as keyword arguments keyed by feature name (or a
    normalised spelling of it) or by categorical base name:

    * scalar            -> fix the feature at that value and hide its panel
    * ``slice(lo, hi)`` -> restrict the sweep and the x-range; either end may
                           be ``None`` for an open-ended window
    * list/tuple/array  -> sweep only those discrete values
    * categorical base  -> a string fixes the base (and removes its panel);
                           a list/tuple/set of strings restricts the plotted
                           categories

    Categorical bases (e.g. language) are plotted as a single categorical
    subplot when not fixed; passing --language "Linear A" fixes that base and
    removes it from the plotted axes.

    Args:
        model: fitted artifact, either an ``xr.Dataset`` or a path to one.
        output: optional HTML path to write the figure to.
        csv_out: optional CSV path for the tidy per-grid-point predictions.
        grid_size: number of sweep points per numeric feature.
        line_color: ``rgb(...)`` colour of the mean line and the band.
        band_alpha: opacity of the +/-2 sigma band.
        figure_height_per_row_px: default height per panel when ``height`` is unset.
        show: open the figure in a browser.
        use_log_scale_for_target_y: clamp y values at ``log_y_epsilon`` and
            request a log y-axis from ``_set_yaxis_range``.
        log_y_epsilon: lower clamp used when the y-axis is logarithmic.
        optimal: overlay the point returned by ``opt.optimal``.
        suggest: number of points from ``opt.suggest`` to overlay (0 = none).
        width, height: figure size in px; non-positive/None selects defaults.
        seed: seed forwarded to ``opt.optimal`` / ``opt.suggest``.
        **kwargs: feature constraints as described above.

    Returns:
        The assembled Plotly figure.

    Raises:
        ValueError: on unknown categories, malformed categorical constraints,
            when no training rows match the fixed categoricals, or when every
            feature is fixed so there is nothing to plot.
    """
    import re  # function-scope import keeps this block self-contained

    ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
    pred_success, pred_loss = _build_predictors(ds)

    feature_names = [str(n) for n in ds["feature"].values.tolist()]
    transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
    X_mean = ds["feature_mean"].values.astype(float)
    X_std = ds["feature_std"].values.astype(float)

    df_raw = _raw_dataframe_from_dataset(ds)
    Xn_train = ds["Xn_train"].values.astype(float)
    n_rows, p = Xn_train.shape

    # --- one-hot categorical groups ---
    # { base: {"labels":[...], "name_by_label":{label:member}, "members":[...]} }
    groups = opt._onehot_groups(feature_names)
    bases = set(groups.keys())
    name_to_idx = {name: j for j, name in enumerate(feature_names)}

    # --- canonicalize kwargs: numeric vs categorical (base) ---
    idx_map = _canon_key_set(ds)
    kw_num_raw: dict[str, object] = {}
    kw_cat_raw: dict[str, object] = {}
    for k, v in kwargs.items():
        if k in bases:
            kw_cat_raw[k] = v
            continue
        if k in idx_map:
            kw_num_raw[idx_map[k]] = v
            continue
        nk = re.sub(r"[^a-z0-9]+", "", str(k).lower())
        if nk in idx_map:
            kw_num_raw[idx_map[nk]] = v

    # --- resolve categorical constraints: fixed (single) vs allowed (multiple) ---
    cat_fixed: dict[str, str] = {}
    cat_allowed: dict[str, list[str]] = {}
    for base, val in kw_cat_raw.items():
        labels = groups[base]["labels"]
        if isinstance(val, str):
            if val not in labels:
                raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}")
            cat_fixed[base] = val
        elif isinstance(val, (list, tuple, set)):
            chosen = [x for x in val if isinstance(x, str) and x in labels]
            if not chosen:
                raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}")
            # multiple -> treat as allowed subset (NOT fixed); de-duplicate, keep order
            cat_allowed[base] = list(dict.fromkeys(chosen))
        else:
            raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.")

    # --- filter rows by fixed categoricals (affects medians/percentiles & overlays) ---
    row_mask = np.ones(n_rows, dtype=bool)
    for base, label in cat_fixed.items():
        if base in df_raw.columns:
            row_mask &= (df_raw[base].astype("string") == pd.Series([label] * len(df_raw), dtype="string")).to_numpy()
        else:
            member_name = groups[base]["name_by_label"][label]
            j = name_to_idx[member_name]
            raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float)
            row_mask &= (raw_j >= 0.5)

    # Medians/percentiles of an empty selection are undefined; fail clearly instead.
    if not row_mask.any():
        raise ValueError("No training rows match the fixed categorical constraints; nothing to plot.")

    df_raw_f = df_raw.loc[row_mask].reset_index(drop=True)
    Xn_train_f = Xn_train[row_mask, :]

    # --- helpers to transform original <-> standardized for feature j ---
    def _orig_to_std(j: int, x):
        x = np.asarray(x, dtype=float)
        if transforms[j] == "log10":
            x = np.where(x <= 0, np.nan, x)
            x = np.log10(x)
        return (x - X_mean[j]) / X_std[j]

    def _std_to_orig(j: int, val_std: float) -> float:
        vi = val_std * X_std[j] + X_mean[j]
        return float(_inverse_transform(transforms[j], np.array([vi]))[0])

    def _raw_values(j: int) -> np.ndarray:
        """Raw (original-unit) values of feature j for the filtered rows."""
        name = feature_names[j]
        if name in df_raw_f.columns:
            return df_raw_f[name].to_numpy().astype(float)
        full_vals = feature_raw_from_artifact_or_reconstruct(ds, j, name, transforms[j]).astype(float)
        return full_vals[row_mask]

    # --- numeric constraint split (STANDARDIZED) ---
    fixed_scalars: dict[int, float] = {}
    range_windows: dict[int, tuple[float, float]] = {}
    choice_values: dict[int, np.ndarray] = {}
    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if isinstance(val, slice):
            # A missing bound means "unbounded on that side"; the sweep is still
            # clipped to the empirical 1-99% window in `_grid_1d`.
            lo = -np.inf if val.start is None else float(_orig_to_std(j, float(val.start)))
            hi = np.inf if val.stop is None else float(_orig_to_std(j, float(val.stop)))
            lo, hi = float(min(lo, hi)), float(max(lo, hi))
            range_windows[j] = (lo, hi)
        elif isinstance(val, (list, tuple, np.ndarray)):
            arr = _orig_to_std(j, np.asarray(val, dtype=float))
            choice_values[j] = np.asarray(arr, dtype=float)
        else:
            fixed_scalars[j] = float(_orig_to_std(j, float(val)))

    # --- apply categorical fixed as standardized scalar fixes on each one-hot member ---
    for base, label in cat_fixed.items():
        for lab in groups[base]["labels"]:
            member_name = groups[base]["name_by_label"][lab]
            j = name_to_idx[member_name]
            raw_val = 1.0 if lab == label else 0.0
            fixed_scalars[j] = float(_orig_to_std(j, raw_val))

    # --- overlays conditioned on the same kwargs (numeric + categorical) ---
    optimal_df = opt.optimal(model, count=1, seed=seed, **kwargs) if optimal else None
    suggest_df = opt.suggest(model, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None

    # --- base standardized point (median over filtered rows), then apply scalar fixes ---
    base_std = np.median(Xn_train_f, axis=0)
    for j, vstd in fixed_scalars.items():
        base_std[j] = vstd

    # --- one-hot member names (robust) ---
    onehot_member_names: set[str] = set()
    for base, g in groups.items():
        # names recorded by the detector
        onehot_member_names.update(g["members"])
        # fallback pattern match in case detector missed anything
        prefix = f"{base}="
        onehot_member_names.update(nm for nm in feature_names if nm.startswith(prefix))

    # --- plotted panels: numeric free features + categorical bases not fixed ---
    free_numeric_idx = [
        j for j, nm in enumerate(feature_names)
        if (j not in fixed_scalars) and (nm not in onehot_member_names)
    ]
    # Iterate the dict (insertion order), not the `bases` set, so the panel
    # order is the same on every run regardless of string-hash randomisation.
    free_cat_bases = [b for b in groups if b not in cat_fixed]

    panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases]
    if not panels:
        raise ValueError("All features are fixed (or categorical only with single category chosen); nothing to plot.")

    # sanity: ensure we didn't accidentally keep any one-hot member columns
    assert all(
        (feature_names[key] not in onehot_member_names) if kind == "num" else True
        for kind, key in panels
    ), "internal: one-hot member leaked into numeric panels"

    # --- empirical 1-99% from filtered rows for numeric bounds ---
    p01p99 = [np.percentile(Xn_train_f[:, j], [1, 99]) for j in range(p)]

    def _grid_1d(j: int, n: int) -> np.ndarray:
        p01, p99 = p01p99[j]
        if j in choice_values:
            vals = np.asarray(choice_values[j], dtype=float)
            vals = vals[(vals >= p01) & (vals <= p99)]
            return np.unique(np.sort(vals)) if vals.size else np.array([np.median(Xn_train_f[:, j])], dtype=float)
        lo, hi = p01, p99
        if j in range_windows:
            rlo, rhi = range_windows[j]
            lo, hi = max(lo, rlo), min(hi, rhi)
        if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
            lo, hi = p01, max(p01 + 1e-9, p99)
        return np.linspace(lo, hi, n)

    # --- figure scaffold with clean titles ---
    def _panel_title(kind: str, key: object) -> str:
        return feature_names[int(key)] if kind == "num" else str(key)

    subplot_titles = [_panel_title(kind, key) for kind, key in panels]
    fig = make_subplots(
        rows=len(panels),
        cols=1,
        shared_xaxes=False,
        subplot_titles=subplot_titles,
    )

    # --- masks/data from filtered rows ---
    tgt_col = str(ds.attrs["target"])
    success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy()
    fail_mask = ~success_mask
    losses_success = df_raw_f.loc[success_mask, tgt_col].to_numpy().astype(float)
    trial_ids_all = df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()
    trial_ids_success = trial_ids_all[success_mask]
    trial_ids_fail = trial_ids_all[fail_mask]
    band_fill_rgba = _rgb_to_rgba(line_color, band_alpha)

    # Observed targets as drawn: clamped at epsilon when the y-axis is log.
    if use_log_scale_for_target_y and losses_success.size:
        losses_s_plot = np.maximum(losses_success, log_y_epsilon)
    else:
        losses_s_plot = losses_success

    def _y_series(mu: np.ndarray, sd: np.ndarray):
        """Return (mean, lower, upper) of the +/-2 sigma band as drawn."""
        lo = mu - 2.0 * sd
        hi = mu + 2.0 * sd
        if use_log_scale_for_target_y:
            return (
                np.maximum(mu, log_y_epsilon),
                np.maximum(lo, log_y_epsilon),
                np.maximum(hi, log_y_epsilon),
            )
        return mu, lo, hi

    def _y_layout(lo_plot: np.ndarray, hi_plot: np.ndarray):
        """Return (y0, y1, y_failed_band): axis limits and the y level for failed trials."""
        y_arrays = [lo_plot, hi_plot] + ([losses_s_plot] if losses_s_plot.size else [])
        y_low = float(np.nanmin([np.nanmin(a) for a in y_arrays]))
        y_high = float(np.nanmax([np.nanmax(a) for a in y_arrays]))
        span = y_high - y_low + 1e-12
        if use_log_scale_for_target_y:
            y0 = max(y_low / 1.5, log_y_epsilon)
            y_failed = y_high * 1.2 + span * 0.3
            if y_failed <= log_y_epsilon:
                y_failed = max(10.0 * log_y_epsilon, y_high * 2.0)
            y1 = y_failed + 0.05 * span
        else:
            pad = 0.05 * span
            y0 = y_low - pad
            y_failed = (y_high + pad) + span * 0.08
            y1 = y_failed + 0.02 * span
        return y0, y1, y_failed

    tidy_rows: list[dict] = []

    for row_pos, (kind, key) in enumerate(panels, start=1):
        show_legend = (row_pos == 1)

        if kind == "num":
            j = int(key)
            fname = feature_names[j]
            grid = _grid_1d(j, grid_size)
            Xn_grid = np.repeat(base_std[None, :], len(grid), axis=0)
            Xn_grid[:, j] = grid

            p_grid = pred_success(Xn_grid)
            mu_grid, sd_grid = pred_loss(Xn_grid, include_observation_noise=True)

            x_internal = grid * X_std[j] + X_mean[j]
            x_display = _inverse_transform(transforms[j], x_internal)

            mu_plot, lo_plot, hi_plot = _y_series(mu_grid, sd_grid)
            y0_plot, y1_plot, y_failed_band = _y_layout(lo_plot, hi_plot)

            _add_low_success_shading_1d(fig, row_pos, x_display, p_grid, y0_plot, y1_plot)

            fig.add_trace(go.Scatter(x=x_display, y=lo_plot, mode="lines",
                                     line=dict(width=0, color=line_color),
                                     name="±2σ", legendgroup="band", showlegend=False, hoverinfo="skip"),
                          row=row_pos, col=1)
            fig.add_trace(go.Scatter(x=x_display, y=hi_plot, mode="lines", fill="tonexty",
                                     line=dict(width=0, color=line_color), fillcolor=band_fill_rgba,
                                     name="±2σ", legendgroup="band", showlegend=show_legend,
                                     hovertemplate="E[target|success]: %{y:.3f}<extra>±2σ</extra>"),
                          row=row_pos, col=1)
            fig.add_trace(go.Scatter(x=x_display, y=mu_plot, mode="lines",
                                     line=dict(width=2, color=line_color),
                                     name="E[target|success]", legendgroup="mean", showlegend=show_legend,
                                     hovertemplate=f"{fname}: %{{x:.6g}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"),
                          row=row_pos, col=1)

            # experimental points
            x_data_all = _raw_values(j)

            x_succ = x_data_all[success_mask]
            if x_succ.size:
                fig.add_trace(go.Scattergl(
                    x=x_succ, y=losses_s_plot, mode="markers",
                    marker=dict(size=5, color="black", line=dict(width=0)),
                    name="data (success)", legendgroup="data_s", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{fname}: %{{x:.6g}}<br>"
                                   f"{tgt_col}: %{{y:.4f}}<extra></extra>"),
                    customdata=trial_ids_success
                ), row=row_pos, col=1)

            x_fail = x_data_all[fail_mask]
            if x_fail.size:
                y_fail_plot = np.full_like(x_fail, y_failed_band, dtype=float)
                fig.add_trace(go.Scattergl(
                    x=x_fail, y=y_fail_plot, mode="markers",
                    marker=dict(size=6, color="red", line=dict(color="black", width=0.8)),
                    name="data (failed)", legendgroup="data_f", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{fname}: %{{x:.6g}}<br>"
                                   "status: failed (NaN target)<extra></extra>"),
                    customdata=trial_ids_fail
                ), row=row_pos, col=1)

            # overlays: the per-point sd travels as customdata so the hover text
            # is resolved by Plotly for each marker.
            if optimal_df is not None and fname in optimal_df.columns:
                fig.add_trace(go.Scattergl(
                    x=optimal_df[fname].values,
                    y=optimal_df["pred_target_mean"].values,
                    customdata=optimal_df["pred_target_sd"].values,
                    mode="markers",
                    marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                    name="optimal", legendgroup="optimal", showlegend=show_legend,
                    hovertemplate=("predicted: %{y:.3g} ± %{customdata:.3g}<br>"
                                   f"{fname}: %{{x:.6g}}<extra></extra>")
                ), row=row_pos, col=1)

            if suggest_df is not None and fname in suggest_df.columns:
                fig.add_trace(go.Scattergl(
                    x=suggest_df[fname].values,
                    y=suggest_df["pred_target_mean"].values,
                    customdata=suggest_df["pred_target_sd"].values,
                    mode="markers",
                    marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                    name="suggested", legendgroup="suggested", showlegend=show_legend,
                    hovertemplate=("predicted: %{y:.3g} ± %{customdata:.3g}<br>"
                                   f"{fname}: %{{x:.6g}}<extra></extra>")
                ), row=row_pos, col=1)

            # axes
            _maybe_log_axis(fig, row_pos, 1, fname, axis="x", transforms=transforms, j=j)
            fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1)
            _set_yaxis_range(fig, row=row_pos, col=1,
                             y0=y0_plot, y1=y1_plot,
                             log=use_log_scale_for_target_y, eps=log_y_epsilon)

            # restrict x-range if constrained
            is_log_x = (transforms[j] == "log10")
            x_min_override = x_max_override = None
            if j in range_windows:
                lo_std, hi_std = range_windows[j]
                # Open-ended (or otherwise non-finite) bounds fall back to the swept grid.
                if not np.isfinite(lo_std):
                    lo_std = float(np.min(grid))
                if not np.isfinite(hi_std):
                    hi_std = float(np.max(grid))
                x_a, x_b = _std_to_orig(j, lo_std), _std_to_orig(j, hi_std)
                x_min_override = min(x_a, x_b)
                x_max_override = max(x_a, x_b)
            elif j in choice_values:
                ints = choice_values[j] * X_std[j] + X_mean[j]
                origs = _inverse_transform(transforms[j], ints)
                x_min_override = float(np.min(origs))
                x_max_override = float(np.max(origs))

            if (x_min_override is not None) and (x_max_override is not None):
                if is_log_x:
                    x0 = max(x_min_override, 1e-12)
                    x1 = max(x_max_override, x0 * (1 + 1e-9))
                    pad = (x1 / x0) ** 0.03
                    fig.update_xaxes(type="log",
                                     range=[np.log10(x0 / pad), np.log10(x1 * pad)],
                                     row=row_pos, col=1)
                else:
                    span = (x_max_override - x_min_override) or 1.0
                    pad = 0.02 * span
                    fig.update_xaxes(range=[x_min_override - pad, x_max_override + pad],
                                     row=row_pos, col=1)

            fig.update_xaxes(title_text=fname, row=row_pos, col=1)

            # tidy rows
            for xd, xi, mu_i, sd_i, p_i in zip(x_display, x_internal, mu_grid, sd_grid, p_grid):
                tidy_rows.append({
                    "feature": fname,
                    "x_display": float(xd),
                    "x_internal": float(xi),
                    "target_conditional_mean": float(mu_i),
                    "target_conditional_sd": float(sd_i),
                    "success_probability": float(p_i),
                })

        else:
            base = key  # categorical base
            labels_all = groups[base]["labels"]
            labels = cat_allowed.get(base, labels_all)

            # Build standardized design for each label at the base point
            Xn_grid = np.repeat(base_std[None, :], len(labels), axis=0)
            for r, lab in enumerate(labels):
                for lab2 in labels_all:
                    member_name = groups[base]["name_by_label"][lab2]
                    j2 = name_to_idx[member_name]
                    raw_val = 1.0 if (lab2 == lab) else 0.0
                    Xn_grid[r, j2] = (raw_val - X_mean[j2]) / X_std[j2]

            p_vec = pred_success(Xn_grid)
            mu_vec, sd_vec = pred_loss(Xn_grid, include_observation_noise=True)

            # y transform + y-range
            mu_plot, lo_plot, hi_plot = _y_series(mu_vec, sd_vec)
            y0_plot, y1_plot, y_failed_band = _y_layout(lo_plot, hi_plot)

            # x positions are 0..K-1 with tick labels = category names
            x_pos = np.arange(len(labels), dtype=float)

            # shading per-category threshold regions using shapes
            for thr, alpha in ((0.8, 0.40), (0.5, 0.25)):
                for k_i, p_i in enumerate(p_vec):
                    if p_i < thr:
                        fig.add_shape(
                            type="rect",
                            xref=f"x{'' if row_pos==1 else row_pos}",
                            yref=f"y{'' if row_pos==1 else row_pos}",
                            x0=k_i - 0.5, x1=k_i + 0.5,
                            y0=y0_plot, y1=y1_plot,
                            line=dict(width=0),
                            fillcolor=f"rgba(128,128,128,{alpha})",
                            layer="below",
                            row=row_pos, col=1
                        )

            # mean with error bars (±2σ)
            fig.add_trace(go.Scatter(
                x=x_pos, y=mu_plot, mode="lines+markers",
                line=dict(width=2, color=line_color),
                marker=dict(size=7, color=line_color),
                error_y=dict(type="data", array=(hi_plot - mu_plot), arrayminus=(mu_plot - lo_plot), visible=True),
                name="E[target|success]", legendgroup="mean", showlegend=show_legend,
                hovertemplate=(f"{base}: %{{text}}<br>E[target|success]: %{{y:.3f}}"
                               "<br>±2σ shown as error bar<extra></extra>"),
                text=labels
            ), row=row_pos, col=1)

            # experimental points: map each row's label to index
            if base in df_raw_f.columns:
                lab_series = df_raw_f[base].astype("string")
            else:
                # reconstruct from one-hot members; each member column is taken from the
                # raw frame when present, otherwise rebuilt from the artifact
                member_cols = [groups[base]["name_by_label"][lab] for lab in labels_all]
                onehot = np.column_stack([_raw_values(name_to_idx[nm]) for nm in member_cols])
                idx_max = onehot.argmax(axis=1)
                lab_series = pd.Series([labels_all[i] for i in idx_max], dtype="string")

            label_to_idx = {lab: i for i, lab in enumerate(labels)}
            x_idx_all = lab_series.map(lambda s: label_to_idx.get(str(s), np.nan)).to_numpy(dtype=float)
            x_idx_succ = x_idx_all[success_mask]
            x_idx_fail = x_idx_all[fail_mask]

            def _label_text(idx_arr: np.ndarray) -> list[str]:
                return [labels[int(i)] if np.isfinite(i) and int(i) < len(labels) else "?" for i in idx_arr]

            # jitter for visibility (fixed generator seed: identical scatter on every call)
            rng = np.random.default_rng(0)

            def _jitter(n: int) -> np.ndarray:
                return (rng.random(n) - 0.5) * 0.15

            if x_idx_succ.size:
                fig.add_trace(go.Scattergl(
                    x=x_idx_succ + _jitter(x_idx_succ.size),
                    y=losses_s_plot,
                    mode="markers",
                    marker=dict(size=5, color="black", line=dict(width=0)),
                    name="data (success)", legendgroup="data_s", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{base}: %{{text}}<br>"
                                   f"{tgt_col}: %{{y:.4f}}<extra></extra>"),
                    text=_label_text(x_idx_succ),
                    customdata=trial_ids_success
                ), row=row_pos, col=1)

            if x_idx_fail.size:
                y_fail_plot = np.full_like(x_idx_fail, y_failed_band, dtype=float)
                fig.add_trace(go.Scattergl(
                    x=x_idx_fail + _jitter(x_idx_fail.size), y=y_fail_plot, mode="markers",
                    marker=dict(size=6, color="red", line=dict(color="black", width=0.8)),
                    name="data (failed)", legendgroup="data_f", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{base}: %{{text}}<br>"
                                   "status: failed (NaN target)<extra></extra>"),
                    text=_label_text(x_idx_fail),
                    customdata=trial_ids_fail
                ), row=row_pos, col=1)

            # overlays for categorical base: map label to x index
            if optimal_df is not None and (base in optimal_df.columns) and len(optimal_df):
                lab_opt = str(optimal_df[base].values[0])
                if lab_opt in label_to_idx:
                    fig.add_trace(go.Scattergl(
                        x=[float(label_to_idx[lab_opt])],
                        y=optimal_df["pred_target_mean"].values[:1],
                        customdata=optimal_df["pred_target_sd"].values[:1],
                        mode="markers",
                        marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                        name="optimal", legendgroup="optimal", showlegend=show_legend,
                        hovertemplate=("predicted: %{y:.3g} ± %{customdata:.3g}<br>"
                                       f"{base}: {lab_opt}<extra></extra>")
                    ), row=row_pos, col=1)

            if suggest_df is not None and (base in suggest_df.columns):
                labs_sug = suggest_df[base].astype(str).tolist()
                keep_mask = [l in label_to_idx for l in labs_sug]
                xs = [label_to_idx[l] for l, keep in zip(labs_sug, keep_mask) if keep]
                if xs:
                    y_sug = suggest_df.loc[keep_mask, "pred_target_mean"].values
                    fig.add_trace(go.Scattergl(
                        x=np.array(xs, dtype=float), y=y_sug, mode="markers",
                        marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                        name="suggested", legendgroup="suggested", showlegend=show_legend,
                        # second literal is a plain string, so single braces reach Plotly
                        hovertemplate=(f"{base}: %{{text}}<br>"
                                       "predicted: %{y:.3g}<extra>suggested</extra>"),
                        text=[labels[int(i)] for i in xs]
                    ), row=row_pos, col=1)

            # axes: categorical ticks
            fig.update_xaxes(
                tickmode="array",
                tickvals=x_pos.tolist(),
                ticktext=labels,
                row=row_pos, col=1
            )
            fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1)
            _set_yaxis_range(fig, row=row_pos, col=1,
                             y0=y0_plot, y1=y1_plot,
                             log=use_log_scale_for_target_y, eps=log_y_epsilon)
            fig.update_xaxes(title_text=base, row=row_pos, col=1)

            # tidy rows
            for lab, mu_i, sd_i, p_i in zip(labels, mu_vec, sd_vec, p_vec):
                tidy_rows.append({
                    "feature": base,
                    "x_display": str(lab),
                    "x_internal": float("nan"),
                    "target_conditional_mean": float(mu_i),
                    "target_conditional_sd": float(sd_i),
                    "success_probability": float(p_i),
                })

    # title w/ constraints summary
    def _fmt_c(v):
        if isinstance(v, slice):
            a = "" if v.start is None else f"{v.start:g}"
            b = "" if v.stop is None else f"{v.stop:g}"
            return f"[{a},{b}]"
        if isinstance(v, (list, tuple, np.ndarray)):
            try:
                return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]"
            except Exception:
                return "[" + ",".join(map(str, v)) + "]"
        try:
            return f"{float(v):g}"
        except Exception:
            return str(v)

    parts = [f"1D partial dependence of expected {tgt_col}"]
    if kw_num_raw:
        parts.append(", ".join(f"{k}={_fmt_c(v)}" for k, v in kw_num_raw.items()))
    if cat_fixed:
        parts.append(", ".join(f"{b}={lab}" for b, lab in cat_fixed.items()))
    if cat_allowed:
        parts.append(", ".join(f"{b}∈{{{', '.join(v)}}}" for b, v in cat_allowed.items()))
    title = " — ".join(parts) if len(parts) > 1 else parts[0]

    width = width if (width and width > 0) else 1200
    height = height if (height and height > 0) else figure_height_per_row_px * len(panels)

    fig.update_layout(
        height=height,
        width=width,
        template="simple_white",
        title=title,
        legend_title_text=""
    )

    if output:
        output = Path(output)
        output.parent.mkdir(parents=True, exist_ok=True)
        fig.write_html(str(output), include_plotlyjs="cdn")
    if csv_out:
        csv_out = Path(csv_out)
        csv_out.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(tidy_rows).to_csv(str(csv_out), index=False)
    if show:
        fig.show("browser")

    return fig
1455# =============================================================================
1456# Helpers: dataset → predictors & featurization
1457# =============================================================================
def _build_predictors(ds: xr.Dataset):
    """Reconstruct fast GP predictors from the artifact using shared helpers.

    Reads the training matrices and the MAP parameters of both heads from
    ``ds``, precomputes one Cholesky factor per head, and returns the pair
    ``(predict_success_probability, predict_conditional_target)``.
    """

    def _array(name: str) -> np.ndarray:
        return ds[name].values.astype(float)

    def _scalar(name: str) -> float:
        return float(ds[name].values)

    # Training matrices / targets
    X_all = _array("Xn_train")            # (N, p)
    success_obs = _array("y_success")     # (N,)
    X_ok = _array("Xn_success_only")      # (Ns, p)
    loss_obs_centered = _array("y_loss_centered")

    # Compatibility: conditional_loss_mean may be a var or an attr
    if "conditional_loss_mean" in ds:
        cond_mean = float(ds["conditional_loss_mean"].values)
    else:
        cond_mean = float(ds.attrs.get("conditional_loss_mean", 0.0))

    # Success head MAP params
    ell_s = _array("map_success_ell")     # (p,)
    eta_s = _scalar("map_success_eta")
    sigma_s = _scalar("map_success_sigma")
    beta0_s = _scalar("map_success_beta0")

    # Loss head MAP params
    ell_l = _array("map_loss_ell")        # (p,)
    eta_l = _scalar("map_loss_eta")
    sigma_l = _scalar("map_loss_sigma")
    mean_c = _scalar("map_loss_mean_const")

    def _factorize(X: np.ndarray, ell: np.ndarray, eta: float, sigma: float, residual: np.ndarray):
        """Cholesky factor of the noisy training kernel and the solved weights."""
        gram = kernel_m52_ard(X, X, ell, eta) + (sigma**2) * np.eye(X.shape[0])
        chol = np.linalg.cholesky(add_jitter(gram))
        return chol, solve_chol(chol, residual)

    # Only the weights are needed for the success head; the loss head also
    # keeps its factor for the predictive variance.
    _, alpha_s = _factorize(X_all, ell_s, eta_s, sigma_s, success_obs - beta0_s)
    L_l, alpha_l = _factorize(X_ok, ell_l, eta_l, sigma_l, loss_obs_centered - mean_c)

    def predict_success_probability(Xn: np.ndarray) -> np.ndarray:
        cross = kernel_m52_ard(Xn, X_all, ell_s, eta_s)
        return np.clip(beta0_s + cross @ alpha_s, 0.0, 1.0)

    def predict_conditional_target(
        Xn: np.ndarray,
        include_observation_noise: bool = True
    ):
        cross = kernel_m52_ard(Xn, X_ok, ell_l, eta_l)
        mu = (mean_c + cross @ alpha_l) + cond_mean

        # diag predictive variance, floored at 1e-12 before the optional noise term
        half = solve_lower(L_l, cross.T)  # (Ns, Nt)
        var = np.maximum(kernel_diag_m52(Xn, ell_l, eta_l) - np.sum(half * half, axis=0), 1e-12)
        if include_observation_noise:
            var = var + sigma_l**2
        return mu, np.sqrt(var)

    return predict_success_probability, predict_conditional_target
def _raw_dataframe_from_dataset(ds: xr.Dataset) -> pd.DataFrame:
    """Collect raw columns from the artifact into a DataFrame for plotting.

    Only data variables whose single dimension is ``row`` (with the dataset's
    row count) are kept. A ``trial_id`` column of 0..n-1 is added when the
    artifact does not provide one, so hover text always has an id to show.
    """

    def _is_row_aligned(da) -> bool:
        return tuple(da.dims) == ("row",) and da.sizes["row"] == ds.sizes["row"]

    columns = {
        var_name: ds[var_name].values
        for var_name in ds.data_vars
        if _is_row_aligned(ds[var_name])
    }
    if "trial_id" not in columns:
        columns["trial_id"] = np.arange(ds.sizes["row"], dtype=int)
    return pd.DataFrame(columns)
def _apply_fixed_to_base(
    base_std: np.ndarray,
    fixed: dict[str, float],
    feature_names: list[str],
    transforms: list[str],
    X_mean: np.ndarray,
    X_std: np.ndarray,
) -> np.ndarray:
    """Override base point in standardized space with fixed ORIGINAL values.

    Returns a copy of ``base_std``; the input array is not modified.
    Raises ``KeyError`` for a key that is not in ``feature_names``.
    """
    column_of = {feat: pos for pos, feat in enumerate(feature_names)}
    result = base_std.copy()
    for key, value in fixed.items():
        col = column_of.get(key)
        if col is None:
            raise KeyError(f"Fixed variable '{key}' is not a model feature.")
        # original units -> model transform -> z-score
        transformed = _forward_transform(transforms[col], float(value))
        result[col] = (transformed - X_mean[col]) / X_std[col]
    return result
def _denormalize_then_inverse_transform(j: int, x_std: np.ndarray, transforms, X_mean, X_std) -> np.ndarray:
    """Map standardized values of feature ``j`` back through its inverse transform.

    The standardization is undone first (``x_std * X_std[j] + X_mean[j]``);
    the result is then handed to ``_inverse_transform`` together with
    ``transforms[j]``.
    """
    scale, shift = X_std[j], X_mean[j]
    return _inverse_transform(transforms[j], x_std * scale + shift)
1559def _forward_transform(tr: str, x: float | np.ndarray) -> np.ndarray:
1560 if tr == "log10":
1561 x = np.asarray(x, dtype=float)
1562 return np.log10(np.maximum(x, 1e-12))
1563 return np.asarray(x, dtype=float)
def _inverse_transform(tr: str, x: np.ndarray) -> np.ndarray:
    """Undo the named transform: ``10 ** x`` for ``"log10"``, otherwise ``x`` as-is."""
    return 10.0 ** x if tr == "log10" else x
def _maybe_log_axis(fig: go.Figure, row: int, col: int, name: str, axis: str = "x", transforms: list[str] | None = None, j: int | None = None):
    """Switch one subplot axis to a log scale when the feature calls for it.

    When both ``transforms`` and ``j`` are supplied, the axis becomes
    logarithmic exactly when ``transforms[j] == "log10"``.  Otherwise a name
    heuristic decides: the lower-cased ``name`` contains ``"learning_rate"``
    or equals ``"lr"``.  ``axis == "x"`` updates the x axis; any other value
    updates the y axis.  Nothing is changed when a log scale is not wanted.
    """
    if transforms is None or j is None:
        lowered = name.lower()
        wants_log = "learning_rate" in lowered or lowered == "lr"
    else:
        wants_log = transforms[j] == "log10"

    if not wants_log:
        return

    update_axis = fig.update_xaxes if axis == "x" else fig.update_yaxes
    update_axis(type="log", row=row, col=col)
def _rgb_string_to_tuple(s: str) -> tuple[int, int, int]:
    """Parse a color string into an ``(r, g, b)`` tuple of ints.

    Accepted forms:

    * ``"rgb(r,g,b)"`` / ``"rgba(r,g,b,a)"`` -- the first three
      comma-separated fields between the parentheses are read as floats and
      truncated to ints; a fourth (alpha) field is ignored.
    * ``"#rgb"`` / ``"#rrggbb"`` hex notation (case-insensitive).  Previously
      these always raised ``ValueError``, which made ``_rgb_to_rgba`` fall
      back to its default color for any hex ``line_color``.

    Raises:
        ValueError: if the string cannot be parsed in either form.
    """
    text = s.strip()
    if text.startswith("#"):
        digits = text[1:]
        if len(digits) == 3:
            # shorthand "#abc" expands to "#aabbcc"
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) == 6 and all(ch in "0123456789abcdefABCDEF" for ch in digits):
            return int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16)
        # anything else starting with "#" falls through to the original parser

    # Original behaviour, unchanged: take whatever sits between "(" and ")".
    vals = s[s.find("(") + 1 : s.find(")")].split(",")
    r, g, b = [int(float(v)) for v in vals[:3]]
    return r, g, b
def _rgb_to_rgba(rgb: str, alpha: float) -> str:
    """Return ``"rgba(r,g,b,a)"`` for the given color string and opacity.

    ``rgb`` is parsed with ``_rgb_string_to_tuple``.  This is best-effort: if
    parsing fails for any reason the color ``(31, 119, 180)`` is used instead
    of raising.  ``alpha`` is rendered with three decimals.
    """
    fallback = (31, 119, 180)
    try:
        red, green, blue = _rgb_string_to_tuple(rgb)
    except Exception:
        # deliberate: an unparsable color must not break plotting
        red, green, blue = fallback
    return "rgba({},{},{},{:.3f})".format(red, green, blue, alpha)
def _add_low_success_shading_1d(fig: go.Figure, row_idx: int, x_vals: np.ndarray, p: np.ndarray, y0: float, y1: float):
    """Grey out x-intervals of a 1D panel where the success probability is low.

    Two passes add rectangles spanning ``[y0, y1]`` below the traces: first
    for contiguous runs where ``p < 0.5`` (alpha 0.25), then for runs where
    ``p < 0.8`` (alpha 0.40).  Runs below 0.5 therefore receive both
    rectangles.  Axis references are ``x``/``y`` for ``row_idx == 1`` and
    ``x{row_idx}``/``y{row_idx}`` otherwise.
    """
    suffix = "" if row_idx == 1 else str(row_idx)
    xref, yref = f"x{suffix}", f"y{suffix}"

    def _runs(vals: np.ndarray, mask: np.ndarray):
        # Pad with zeros so runs touching either end still produce an edge.
        flags = mask.astype(int)
        steps = np.diff(np.concatenate([[0], flags, [0]]))
        first = np.flatnonzero(steps == 1)
        last = np.flatnonzero(steps == -1) - 1
        return [(vals[a], vals[b]) for a, b in zip(first, last)]

    passes = (
        (0.5, "rgba(128,128,128,0.25)"),
        (0.8, "rgba(128,128,128,0.40)"),
    )
    for threshold, fill in passes:
        for x0, x1 in _runs(x_vals, p < threshold):
            fig.add_shape(type="rect", x0=x0, x1=x1, y0=y0, y1=y1, xref=xref, yref=yref,
                          line=dict(width=0), fillcolor=fill, layer="below")
def _set_yaxis_range(fig, *, row: int, col: int, y0: float, y1: float, log: bool, eps: float = 1e-12):
    """Set the y-axis type and range of one subplot to ``[y0, y1]``.

    With ``log=True`` the lower bound is floored at ``eps``, the upper bound
    is forced to be strictly above the lower one, and the range is passed in
    log10 units with axis type ``"log"``.  Otherwise the axis type is ``"-"``
    and the range is passed through unchanged.
    """
    if log:
        lower = max(y0, eps)
        upper = max(y1, lower * (1.0 + 1e-6))
        axis_type = "log"
        axis_range = [np.log10(lower), np.log10(upper)]
    else:
        axis_type = "-"
        axis_range = [y0, y1]
    fig.update_yaxes(type=axis_type, range=axis_range, row=row, col=col)
def plot1d_at_optimum(
    model: xr.Dataset | Path | str,
    output: Path | None = None,
    csv_out: Path | None = None,
    grid_size: int = 300,
    line_color: str = "rgb(31,119,180)",
    band_alpha: float = 0.25,
    figure_height_per_row_px: int = 320,
    show: bool = False,
    use_log_scale_for_target_y: bool = True,
    log_y_epsilon: float = 1e-9,
    suggest: int = 0,  # optional overlay
    width: int | None = None,
    height: int | None = None,
    seed: int | None = 42,
    **kwargs,  # constraints in ORIGINAL units
) -> go.Figure:
    """Plot 1D partial-dependence panels anchored at the optimal setting.

    The anchor point x* is the single row returned by
    ``opt.optimal(model, count=1, seed=seed, **kwargs)``.  One panel is drawn
    per free feature: that feature is swept while every other feature stays
    at x*.  Numeric features are swept over a grid; one-hot categorical bases
    are swept over their labels.

    Args:
        model: An ``xr.Dataset`` artifact, or a path loaded with
            ``xr.load_dataset``.
        output: If given, the figure is written there as HTML (parent
            directories are created).
        csv_out: If given, a tidy CSV of the swept predictions is written
            there (parent directories are created).
        grid_size: Number of grid points for each numeric sweep.
        line_color: Color of the mean line; the band fill is derived from it
            through ``_rgb_to_rgba`` with ``band_alpha``.
        band_alpha: Opacity of the +/-2 sd band fill.
        figure_height_per_row_px: Per-panel height used when ``height`` is
            not given.
        show: If true, open the figure with ``fig.show("browser")``.
        use_log_scale_for_target_y: Draw the target axis on a log scale;
            plotted values are then floored at ``log_y_epsilon``.
        log_y_epsilon: Floor applied to plotted target values on a log axis.
        suggest: If > 0, overlay that many rows from ``opt.suggest``.
        width: Figure width in px (1200 when missing or non-positive).
        height: Figure height in px (``figure_height_per_row_px`` times the
            number of panels when missing or non-positive).
        seed: Forwarded to ``opt.optimal`` and ``opt.suggest``.
        **kwargs: Constraints in ORIGINAL units.  For a numeric feature: a
            scalar fixes it (no panel), a ``slice`` limits the sweep window
            (``None`` on either side leaves that side open), a
            list/tuple/array restricts the sweep to those values.  For a
            categorical base: a string fixes the label, a list/tuple/set of
            strings limits the labels shown.

    Returns:
        The assembled ``go.Figure``.

    Raises:
        ValueError: for an unknown or malformed categorical constraint, or
            when every feature is fixed so there is nothing to plot.
    """
    # Imported locally, as `_canon_key_set` does; hoisted out of the kwargs loop.
    import re

    ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
    pred_success, pred_loss = _build_predictors(ds)

    # --- metadata ---
    feature_names = [str(n) for n in ds["feature"].values.tolist()]
    transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
    X_mean = ds["feature_mean"].values.astype(float)
    X_std = ds["feature_std"].values.astype(float)

    df_raw = _raw_dataframe_from_dataset(ds)
    Xn_train = ds["Xn_train"].values.astype(float)
    _, p = Xn_train.shape

    # --- one-hot categorical groups ---
    groups = opt._onehot_groups(feature_names)
    bases = set(groups.keys())
    name_to_idx = {name: j for j, name in enumerate(feature_names)}

    # --- canonicalize kwargs: numeric vs categorical (base) ---
    # Keys matching neither a base nor a feature are not used here; they are
    # still forwarded to opt.optimal / opt.suggest below.
    idx_map = _canon_key_set(ds)
    kw_num_raw: dict[str, object] = {}
    kw_cat_raw: dict[str, object] = {}
    for k, v in kwargs.items():
        if k in bases:
            kw_cat_raw[k] = v
        elif k in idx_map:
            kw_num_raw[idx_map[k]] = v
        else:
            nk = re.sub(r"[^a-z0-9]+", "", str(k).lower())
            if nk in idx_map:
                kw_num_raw[idx_map[nk]] = v

    # --- resolve categorical constraints: fixed vs allowed subset ---
    cat_fixed: dict[str, str] = {}
    cat_allowed: dict[str, list[str]] = {}
    for base, val in kw_cat_raw.items():
        labels = groups[base]["labels"]
        if isinstance(val, str):
            if val not in labels:
                raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}")
            cat_fixed[base] = val
        elif isinstance(val, (list, tuple, set)):
            chosen = [x for x in val if isinstance(x, str) and x in labels]
            if not chosen:
                raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}")
            # de-duplicate while keeping first-seen order
            cat_allowed[base] = list(dict.fromkeys(chosen))
        else:
            raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.")

    # ---------- 1) Find the *optimal* base point (original units) ----------
    # This row is both the anchor for every sweep and the "optimal" overlay.
    opt_df = opt.optimal(model, count=1, seed=seed, **kwargs)
    x_opt_std = np.zeros(p, dtype=float)

    def _to_std_single(j: int, x_orig: float) -> float:
        # original units -> internal (log10 where applicable) -> standardized
        xi = x_orig
        if transforms[j] == "log10":
            xi = np.log10(np.maximum(x_orig, 1e-300))
        return float((xi - X_mean[j]) / X_std[j])

    # Names of one-hot member columns; these are driven by their base below.
    onehot_members: set[str] = set()
    for g in groups.values():
        onehot_members.update(g["members"])

    # numeric features (skip one-hot members)
    for j, nm in enumerate(feature_names):
        if nm in onehot_members:
            continue
        if nm in opt_df.columns:
            x_opt_std[j] = _to_std_single(j, float(opt_df.iloc[0][nm]))
        else:
            # fall back to the training median when the optimum lacks this column
            x_opt_std[j] = float(np.median(Xn_train[:, j]))

    # Categorical bases: set the one-hot block to a single label.
    for base, g in groups.items():
        # priority: fixed via kwargs -> optimal row -> most frequent in data -> first label
        if base in cat_fixed:
            label = cat_fixed[base]
        elif base in opt_df.columns:
            label = str(opt_df.iloc[0][base])
        elif base in df_raw.columns:
            label = str(df_raw[base].astype("string").mode(dropna=True).iloc[0])
        else:
            label = g["labels"][0]

        for lab in g["labels"]:
            j2 = name_to_idx[g["name_by_label"][lab]]
            raw = 1.0 if lab == label else 0.0
            # raw 0/1 indicator -> standardized using the artifact stats
            x_opt_std[j2] = (raw - X_mean[j2]) / X_std[j2]

    # ---------- 2) Numeric constraints in STANDARDIZED space ----------
    def _orig_to_std(j: int, x, transforms, mu, sd):
        x = np.asarray(x, dtype=float)
        if transforms[j] == "log10":
            # non-positive values have no log; mark them NaN
            x = np.where(x <= 0, np.nan, x)
            x = np.log10(x)
        return (x - mu[j]) / sd[j]

    def _window_bound(j: int, bound, unbounded: float) -> float:
        # `None` (open-ended slice) or a bound with no standardized image
        # (NaN, e.g. <= 0 on a log10 feature) leaves that side unconstrained.
        if bound is None:
            return unbounded
        std_val = float(_orig_to_std(j, float(bound), transforms, X_mean, X_std))
        return unbounded if np.isnan(std_val) else std_val

    fixed_scalars: dict[int, float] = {}
    range_windows: dict[int, tuple[float, float]] = {}
    choice_values: dict[int, np.ndarray] = {}

    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if isinstance(val, slice):
            lo = _window_bound(j, val.start, -np.inf)
            hi = _window_bound(j, val.stop, np.inf)
            range_windows[j] = (min(lo, hi), max(lo, hi))
        elif isinstance(val, (list, tuple, np.ndarray)):
            arr = _orig_to_std(j, np.asarray(val, dtype=float), transforms, X_mean, X_std)
            choice_values[j] = np.asarray(arr, dtype=float)
        else:
            fixed_scalars[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std))

    # numeric fixed overrides win over the optimum on the base point
    for j, vstd in fixed_scalars.items():
        x_opt_std[j] = vstd

    # ---------- 3) Panels: sweep ONE var at a time around x* ----------
    # numeric free = not a one-hot member and not fixed via kwargs
    free_numeric_idx = [
        j for j, nm in enumerate(feature_names)
        if (nm not in onehot_members) and (j not in fixed_scalars)
    ]
    # Iterate `groups` (not the `bases` set) so panel order is reproducible.
    free_cat_bases = [b for b in groups if b not in cat_fixed]

    panels: list[tuple[str, object]] = (
        [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases]
    )
    if not panels:
        raise ValueError("All features are fixed at the optimum (or categoricals fixed); nothing to plot.")

    # empirical 1-99% per feature (default sweep range, standardized space)
    Xn_p01 = np.percentile(Xn_train, 1, axis=0)
    Xn_p99 = np.percentile(Xn_train, 99, axis=0)

    def _grid_1d(j: int, n: int) -> np.ndarray:
        lo, hi = float(Xn_p01[j]), float(Xn_p99[j])
        if j in range_windows:
            lo = max(lo, range_windows[j][0])
            hi = min(hi, range_windows[j][1])
        if j in choice_values:
            vals = np.asarray(choice_values[j], dtype=float)
            vals = vals[(vals >= lo) & (vals <= hi)]  # NaNs drop out here too
            # np.unique returns sorted values
            return np.unique(vals) if vals.size else np.array([x_opt_std[j]], dtype=float)
        if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
            # degenerate window: sweep +/-1 sd around the anchor instead
            lo, hi = x_opt_std[j] - 1.0, x_opt_std[j] + 1.0
        return np.linspace(lo, hi, n)

    def _band_for_plot(mu: np.ndarray, sd: np.ndarray):
        # mean and +/-2 sd envelope, floored for a log axis when requested
        lo_band = mu - 2.0 * sd
        hi_band = mu + 2.0 * sd
        if use_log_scale_for_target_y:
            return (
                np.maximum(mu, log_y_epsilon),
                np.maximum(lo_band, log_y_epsilon),
                np.maximum(hi_band, log_y_epsilon),
            )
        return mu, lo_band, hi_band

    def _y_limits(lo_band: np.ndarray, hi_band: np.ndarray, losses_plot: np.ndarray):
        # Returns (axis bottom, y level for failed-trial markers, axis top).
        y_arrays = [lo_band, hi_band] + ([losses_plot] if losses_plot.size else [])
        y_low = float(np.nanmin([np.nanmin(a) for a in y_arrays]))
        y_high = float(np.nanmax([np.nanmax(a) for a in y_arrays]))
        spread = y_high - y_low + 1e-12
        if use_log_scale_for_target_y:
            y_bottom = max(y_low / 1.5, log_y_epsilon)
            y_failed = y_high * 1.2 + spread * 0.3
            if y_failed <= log_y_epsilon:
                y_failed = max(10.0 * log_y_epsilon, y_high * 2.0)
            y_top = y_failed + 0.05 * spread
        else:
            pad_y = 0.05 * spread
            y_bottom = y_low - pad_y
            y_failed = (y_high + pad_y) + spread * 0.08
            y_top = y_failed + 0.02 * spread
        return y_bottom, y_failed, y_top

    # figure scaffold
    subplot_titles = [feature_names[int(k)] if t == "num" else str(k) for t, k in panels]
    fig = make_subplots(rows=len(panels), cols=1, shared_xaxes=False, subplot_titles=subplot_titles)

    # A NaN target marks a failed trial.
    tgt_col = str(ds.attrs["target"])
    success_mask = ~pd.isna(df_raw[tgt_col]).to_numpy()
    fail_mask = ~success_mask
    losses_success = df_raw.loc[success_mask, tgt_col].to_numpy().astype(float)
    trial_ids = df_raw.get("trial_id", pd.Series(np.arange(len(df_raw)))).to_numpy()
    trial_ids_success = trial_ids[success_mask]
    trial_ids_fail = trial_ids[fail_mask]
    band_fill_rgba = _rgb_to_rgba(line_color, band_alpha)

    # Observed targets as plotted; identical for every panel.
    if use_log_scale_for_target_y and losses_success.size:
        losses_s_plot = np.maximum(losses_success, log_y_epsilon)
    else:
        losses_s_plot = losses_success

    # optional overlay
    suggest_df = opt.suggest(model, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None

    tidy_rows: list[dict] = []
    for row_pos, (kind, key) in enumerate(panels, start=1):
        show_legend = (row_pos == 1)  # one legend entry per trace family

        if kind == "num":
            j = int(key)
            feat = feature_names[j]
            grid = _grid_1d(j, grid_size)
            Xn_grid = np.repeat(x_opt_std[None, :], len(grid), axis=0)
            Xn_grid[:, j] = grid

            p_grid = pred_success(Xn_grid)
            mu_grid, sd_grid = pred_loss(Xn_grid, include_observation_noise=True)

            x_internal = grid * X_std[j] + X_mean[j]
            x_display = _inverse_transform(transforms[j], x_internal)

            mu_plot, lo_plot, hi_plot = _band_for_plot(mu_grid, sd_grid)
            y0_plot, y_failed_band, y1_plot = _y_limits(lo_plot, hi_plot, losses_s_plot)

            _add_low_success_shading_1d(fig, row_pos, x_display, p_grid, y0_plot, y1_plot)

            # +/-2 sd band: invisible lower edge, then upper edge filled down to it
            fig.add_trace(go.Scatter(x=x_display, y=lo_plot, mode="lines",
                                     line=dict(width=0, color=line_color),
                                     name="±2σ", legendgroup="band", showlegend=False, hoverinfo="skip"),
                          row=row_pos, col=1)
            fig.add_trace(go.Scatter(x=x_display, y=hi_plot, mode="lines", fill="tonexty",
                                     line=dict(width=0, color=line_color), fillcolor=band_fill_rgba,
                                     name="±2σ", legendgroup="band", showlegend=show_legend,
                                     hovertemplate="E[target|success]: %{y:.3f}<extra>±2σ</extra>"),
                          row=row_pos, col=1)
            # mean
            fig.add_trace(go.Scatter(x=x_display, y=mu_plot, mode="lines",
                                     line=dict(width=2, color=line_color),
                                     name="E[target|success]", legendgroup="mean", showlegend=show_legend,
                                     hovertemplate=f"{feat}: %{{x:.6g}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"),
                          row=row_pos, col=1)

            # experimental points
            if feat in df_raw.columns:
                x_data_all = df_raw[feat].to_numpy().astype(float)
            else:
                x_data_all = feature_raw_from_artifact_or_reconstruct(ds, j, feat, transforms[j]).astype(float)

            x_succ = x_data_all[success_mask]
            if x_succ.size:
                fig.add_trace(go.Scattergl(
                    x=x_succ, y=losses_s_plot, mode="markers",
                    marker=dict(size=5, color="black", line=dict(width=0)),
                    name="data (success)", legendgroup="data_s", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{feat}: %{{x:.6g}}<br>"
                                   f"{tgt_col}: %{{y:.4f}}<extra></extra>"),
                    customdata=trial_ids_success
                ), row=row_pos, col=1)

            x_fail = x_data_all[fail_mask]
            if x_fail.size:
                # failed trials have no target; park them on a line above the data
                y_fail_plot = np.full_like(x_fail, y_failed_band, dtype=float)
                fig.add_trace(go.Scattergl(
                    x=x_fail, y=y_fail_plot, mode="markers",
                    marker=dict(size=6, color="red", line=dict(color="black", width=0.8)),
                    name="data (failed)", legendgroup="data_f", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{feat}: %{{x:.6g}}<br>"
                                   "status: failed (NaN target)<extra></extra>"),
                    customdata=trial_ids_fail
                ), row=row_pos, col=1)

            # overlays: optimal (single point) and suggested (optional many)
            if feat in opt_df.columns:
                x_opt_disp = float(opt_df.iloc[0][feat])
                y_opt = float(opt_df.iloc[0]["pred_target_mean"])
                y_opt_sd = float(opt_df.iloc[0].get("pred_target_sd", np.nan))
                fig.add_trace(go.Scattergl(
                    x=[x_opt_disp], y=[y_opt], mode="markers",
                    marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                    name="optimal", legendgroup="optimal", showlegend=show_legend,
                    hovertemplate=(f"predicted: %{{y:.3g}}"
                                   + ("" if np.isnan(y_opt_sd) else f" ± {y_opt_sd:.3g}")
                                   + f"<br>{feat}: %{{x:.6g}}<extra></extra>")
                ), row=row_pos, col=1)

            if (suggest_df is not None) and (feat in suggest_df.columns):
                xs = suggest_df[feat].values.astype(float)
                ys = suggest_df["pred_target_mean"].values.astype(float)
                ysd = suggest_df.get("pred_target_sd", pd.Series([np.nan] * len(suggest_df))).values
                fig.add_trace(go.Scattergl(
                    x=xs, y=ys, mode="markers",
                    marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                    name="suggested", legendgroup="suggested", showlegend=show_legend,
                    hovertemplate=("predicted: %{y:.3g}"
                                   + (" ± %{customdata:.3g}" if not np.isnan(ysd).all() else "")
                                   + f"<br>{feat}: %{{x:.6g}}<extra>suggested</extra>"),
                    customdata=ysd
                ), row=row_pos, col=1)

            # axes + ranges
            _maybe_log_axis(fig, row_pos, 1, feat, axis="x", transforms=transforms, j=j)
            fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1)
            _set_yaxis_range(fig, row=row_pos, col=1,
                             y0=y0_plot, y1=y1_plot,
                             log=use_log_scale_for_target_y, eps=log_y_epsilon)
            fig.update_xaxes(title_text=feat, row=row_pos, col=1)

            # If a constraint limited the sweep, respect it on the displayed axis.
            # The range is built in INTERNAL units: for a log10 feature the axis
            # is logarithmic and Plotly expects its range in log10 units, which
            # is exactly the internal scale; for other features internal units
            # equal original units.
            if j in range_windows:
                lo_std, hi_std = range_windows[j]
                # an open side has no bound to show; use the swept grid's edge
                if not np.isfinite(lo_std):
                    lo_std = float(np.min(grid))
                if not np.isfinite(hi_std):
                    hi_std = float(np.max(grid))
                lo_int = float(lo_std * X_std[j] + X_mean[j])
                hi_int = float(hi_std * X_std[j] + X_mean[j])
                x_min_override, x_max_override = min(lo_int, hi_int), max(lo_int, hi_int)
                span = (x_max_override - x_min_override) or 1.0
                pad = 0.02 * span
                fig.update_xaxes(range=[x_min_override - pad, x_max_override + pad], row=row_pos, col=1)
            elif j in choice_values and choice_values[j].size:
                ints = choice_values[j] * X_std[j] + X_mean[j]
                ints = ints[np.isfinite(ints)]  # drop choices with no valid image
                if ints.size:
                    x_min_override, x_max_override = float(np.min(ints)), float(np.max(ints))
                    span = (x_max_override - x_min_override) or 1.0
                    pad = 0.05 * span
                    fig.update_xaxes(range=[x_min_override - pad, x_max_override + pad], row=row_pos, col=1)

            # tidy rows
            for xd, xi, mu_i, sd_i, p_i in zip(x_display, x_internal, mu_grid, sd_grid, p_grid):
                tidy_rows.append({
                    "feature": feat,
                    "x_display": float(xd),
                    "x_internal": float(xi),
                    "target_conditional_mean": float(mu_i),
                    "target_conditional_sd": float(sd_i),
                    "success_probability": float(p_i),
                })

        else:
            base = str(key)
            labels_all = groups[base]["labels"]
            labels = cat_allowed.get(base, labels_all)

            # Evaluate each label with numerics and other bases fixed at x_opt_std.
            Xn_grid = np.repeat(x_opt_std[None, :], len(labels), axis=0)
            for r, lab in enumerate(labels):
                for lab2 in labels_all:
                    j2 = name_to_idx[groups[base]["name_by_label"][lab2]]
                    raw_val = 1.0 if (lab2 == lab) else 0.0
                    Xn_grid[r, j2] = (raw_val - X_mean[j2]) / X_std[j2]

            p_vec = pred_success(Xn_grid)
            mu_vec, sd_vec = pred_loss(Xn_grid, include_observation_noise=True)

            mu_plot, lo_plot, hi_plot = _band_for_plot(mu_vec, sd_vec)
            y0_plot, _, y1_plot = _y_limits(lo_plot, hi_plot, losses_s_plot)

            # x = 0..K-1 with tick labels
            x_pos = np.arange(len(labels), dtype=float)
            axis_suffix = "" if row_pos == 1 else str(row_pos)

            # grey out labels with low success probability (two stacked passes)
            for thr, alpha in ((0.8, 0.40), (0.5, 0.25)):
                for k_i, p_i in enumerate(p_vec):
                    if p_i < thr:
                        fig.add_shape(
                            type="rect",
                            xref=f"x{axis_suffix}",
                            yref=f"y{axis_suffix}",
                            x0=k_i - 0.5, x1=k_i + 0.5,
                            y0=y0_plot, y1=y1_plot,
                            line=dict(width=0),
                            fillcolor=f"rgba(128,128,128,{alpha})",
                            layer="below",
                            row=row_pos, col=1
                        )

            fig.add_trace(go.Scatter(
                x=x_pos, y=mu_plot, mode="lines+markers",
                line=dict(width=2, color=line_color),
                marker=dict(size=7, color=line_color),
                error_y=dict(type="data", array=(hi_plot - mu_plot), arrayminus=(mu_plot - lo_plot), visible=True),
                name="E[target|success]", legendgroup="mean", showlegend=show_legend,
                hovertemplate=(f"{base}: %{{text}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"),
                text=labels
            ), row=row_pos, col=1)

            # overlay optimal point for this base (single label at x*)
            if base in opt_df.columns:
                lab_opt = str(opt_df.iloc[0][base])
                if lab_opt in labels:
                    xi = float(labels.index(lab_opt))
                    y_opt = float(opt_df.iloc[0]["pred_target_mean"])
                    y_opt_sd = float(opt_df.iloc[0].get("pred_target_sd", np.nan))
                    fig.add_trace(go.Scattergl(
                        x=[xi], y=[y_opt], mode="markers",
                        marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                        name="optimal", legendgroup="optimal", showlegend=show_legend,
                        hovertemplate=(f"predicted: %{{y:.3g}}"
                                       + ("" if np.isnan(y_opt_sd) else f" ± {y_opt_sd:.3g}")
                                       + f"<br>{base}: {lab_opt}<extra></extra>")
                    ), row=row_pos, col=1)

            # overlay suggestions (optional)
            if (suggest_df is not None) and (base in suggest_df.columns):
                labs_sug = suggest_df[base].astype(str).tolist()
                keep_mask = [lab in labels for lab in labs_sug]
                xs = [labels.index(lab) for lab, keep in zip(labs_sug, keep_mask) if keep]
                if xs:
                    y_sug = suggest_df.loc[keep_mask, "pred_target_mean"].values
                    fig.add_trace(go.Scattergl(
                        x=np.array(xs, dtype=float), y=y_sug, mode="markers",
                        marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                        name="suggested", legendgroup="suggested", showlegend=show_legend,
                        # The second literal is NOT an f-string, so Plotly
                        # placeholders use single braces there.
                        hovertemplate=(f"{base}: %{{text}}<br>"
                                       "predicted: %{y:.3g}<extra>suggested</extra>"),
                        text=[labels[int(i)] for i in xs]
                    ), row=row_pos, col=1)

            fig.update_xaxes(
                tickmode="array",
                tickvals=x_pos.tolist(),
                ticktext=labels,
                title_text=base,
                row=row_pos, col=1
            )
            fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1)
            _set_yaxis_range(fig, row=row_pos, col=1,
                             y0=y0_plot, y1=y1_plot,
                             log=use_log_scale_for_target_y, eps=log_y_epsilon)

            # tidy rows
            for lab, mu_i, sd_i, p_i in zip(labels, mu_vec, sd_vec, p_vec):
                tidy_rows.append({
                    "feature": base,
                    "x_display": str(lab),
                    "x_internal": float("nan"),
                    "target_conditional_mean": float(mu_i),
                    "target_conditional_sd": float(sd_i),
                    "success_probability": float(p_i),
                })

    # ---- layout & IO ----
    parts = [f"1D PD at optimal setting of all other hyperparameters ({ds.attrs.get('target', 'target')})"]
    if kw_num_raw:
        def _fmt_c(v):
            # compact rendering of one numeric constraint for the title
            if isinstance(v, slice):
                a = "" if v.start is None else f"{v.start:g}"
                b = "" if v.stop is None else f"{v.stop:g}"
                return f"[{a},{b}]"
            if isinstance(v, (list, tuple, np.ndarray)):
                try:
                    return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]"
                except Exception:
                    return "[" + ",".join(map(str, v)) + "]"
            try:
                return f"{float(v):g}"
            except Exception:
                return str(v)
        parts.append(", ".join(f"{k}={_fmt_c(v)}" for k, v in kw_num_raw.items()))
    if cat_fixed:
        parts.append(", ".join(f"{b}={lab}" for b, lab in cat_fixed.items()))
    title = " — ".join(parts)

    width = width if (width and width > 0) else 1200
    height = height if (height and height > 0) else figure_height_per_row_px * len(panels)
    fig.update_layout(height=height, width=width, template="simple_white", title=title, legend_title_text="")

    if output:
        output = Path(output)
        output.parent.mkdir(parents=True, exist_ok=True)
        fig.write_html(str(output), include_plotlyjs="cdn")
    if csv_out:
        csv_out = Path(csv_out)
        csv_out.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(tidy_rows).to_csv(str(csv_out), index=False)
    if show:
        fig.show("browser")
    return fig