Coverage for psyop/viz.py: 29.13%
1651 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-10-29 03:44 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-10-29 03:44 +0000
1# -*- coding: utf-8 -*-
2from pathlib import Path
3import numpy as np
4import pandas as pd
5import xarray as xr
7import plotly.graph_objects as go
8from plotly.subplots import make_subplots
9from plotly.colors import get_colorscale, sample_colorscale
11from .model import (
12 kernel_diag_m52,
13 kernel_m52_ard,
14 add_jitter,
15 solve_chol,
16 solve_lower,
17 feature_raw_from_artifact_or_reconstruct,
18)
19from . import opt
22def _canon_key_set(ds) -> dict[str, str]:
23 feats = [str(x) for x in ds["feature"].values.tolist()]
24 def _norm(s: str) -> str:
25 import re
26 return re.sub(r"[^a-z0-9]+", "", s.lower())
27 return {**{f: f for f in feats}, **{_norm(f): f for f in feats}}
30def _edges_from_centers(vals: np.ndarray, is_log: bool) -> tuple[float, float]:
31 """Return (min_edge, max_edge) that tightly bound a heatmap with given center coords."""
32 v = np.asarray(vals, float)
33 v = v[np.isfinite(v)]
34 if v.size == 0:
35 return (0.0, 1.0)
36 if v.size == 1:
37 # tiny pad
38 if is_log:
39 lo = max(v[0] / 1.5, 1e-12)
40 hi = v[0] * 1.5
41 else:
42 span = max(abs(v[0]) * 0.5, 1e-9)
43 lo, hi = v[0] - span, v[0] + span
44 return float(lo), float(hi)
46 if is_log:
47 lv = np.log10(v)
48 l0 = lv[0] - 0.5 * (lv[1] - lv[0])
49 lN = lv[-1] + 0.5 * (lv[-1] - lv[-2])
50 lo = 10.0 ** l0
51 hi = 10.0 ** lN
52 lo = max(lo, 1e-12)
53 hi = max(hi, lo * 1.0000001)
54 return float(lo), float(hi)
55 else:
56 d0 = v[1] - v[0]
57 dN = v[-1] - v[-2]
58 lo = v[0] - 0.5 * d0
59 hi = v[-1] + 0.5 * dN
60 if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
61 lo, hi = float(v.min()), float(v.max())
62 return float(lo), float(hi)
def _update_axis_type_and_range(
    fig, *, row: int, col: int, axis: str, centers: np.ndarray, is_log: bool
):
    """Pin one subplot axis to the outer edges of its heatmap tiles.

    The edges come from ``_edges_from_centers(centers, is_log)``. For a log
    axis the axis type is set to ``"log"`` and the range is passed as log10
    of the edges; for a linear axis only the range is set.

    Args:
        fig: Figure exposing ``update_xaxes`` / ``update_yaxes``.
        row: Subplot row forwarded to the axis update call.
        col: Subplot column forwarded to the axis update call.
        axis: ``"x"`` selects the x axis; any other value selects the y axis.
        centers: Tile center coordinates along that axis.
        is_log: True when the axis should be logarithmic.
    """
    edge_lo, edge_hi = _edges_from_centers(centers, is_log)

    # Choose the updater once; the x and y cases differ only in this call.
    apply_update = fig.update_xaxes if axis == "x" else fig.update_yaxes

    if not is_log:
        apply_update(range=[edge_lo, edge_hi], row=row, col=col)
        return

    apply_update(
        type="log",
        range=[np.log10(edge_lo), np.log10(edge_hi)],
        row=row,
        col=col,
    )
81def plot2d(
82 model: xr.Dataset | Path | str,
83 output: Path | None = None,
84 grid_size: int = 70,
85 use_log_scale_for_target: bool = False, # log10 colors for heatmaps
86 log_shift_epsilon: float = 1e-9,
87 colorscale: str = "RdBu",
88 show: bool = False,
89 n_contours: int = 12,
90 optimal: bool = True,
91 suggest: int = 0,
92 seed: int|None = 42,
93 width:int|None = None,
94 height:int|None = None,
95 **kwargs,
96) -> go.Figure:
97 """
98 2D Partial Dependence of E[target|success] (pairwise features), including
99 categorical variables as single axes (one row/column per base).
101 Conditioning via kwargs (original units):
102 - numeric: fixed scalar, slice(lo,hi), list/tuple choices
103 - categorical base (e.g. language="Linear A" or language=("Linear A","Linear B")):
104 * single string: fixed to that label (axis removed and clamped)
105 * list/tuple of labels: restrict the categorical axis to those labels
107 Notes:
108 * one-hot member features (e.g. language=Linear A) never appear as axes.
109 * when a categorical axis is present, we render a heatmap over category index
110 with tick labels set to the category names; data overlays use jitter.
111 """
112 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
113 pred_success, pred_loss = _build_predictors(ds)
115 # --- features & transforms
116 feature_names = [str(x) for x in ds["feature"].values.tolist()]
117 transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
118 X_mean = ds["feature_mean"].values.astype(float)
119 X_std = ds["feature_std"].values.astype(float)
120 name_to_idx = {nm: i for i, nm in enumerate(feature_names)}
122 # one-hot groups
123 groups = opt._onehot_groups(feature_names) # { base: {"labels":[...], "name_by_label":{label->member}, "members":[...] } }
124 bases = set(groups.keys())
125 onehot_member_names = {m for g in groups.values() for m in g["members"]}
127 # raw df + train design
128 df_raw = _raw_dataframe_from_dataset(ds)
129 Xn_train = ds["Xn_train"].values.astype(float)
130 n_rows = Xn_train.shape[0]
132 # --- split kwargs into numeric vs categorical (keys are canonical already when coming from CLI)
133 kw_num: dict[str, object] = {}
134 kw_cat: dict[str, object] = {}
135 for k, v in (kwargs or {}).items():
136 if k in bases:
137 kw_cat[k] = v
138 elif k in name_to_idx:
139 kw_num[k] = v
140 else:
141 # unknown/ignored
142 pass
144 # --- resolve categorical constraints:
145 # - cat_fixed[base] -> fixed single label (axis removed and clamped)
146 # - cat_allowed[base] -> labels that are allowed on that axis (if not fixed)
147 cat_fixed: dict[str, str] = {}
148 cat_allowed: dict[str, list[str]] = {}
149 for base in bases:
150 labels = list(groups[base]["labels"])
151 if base not in kw_cat:
152 cat_allowed[base] = labels # unrestricted axis (if not fixed numerically later)
153 continue
154 val = kw_cat[base]
155 if isinstance(val, str):
156 if val not in labels:
157 raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}")
158 cat_fixed[base] = val
159 elif isinstance(val, (list, tuple, set)):
160 chosen = [x for x in val if isinstance(x, str) and x in labels]
161 if not chosen:
162 raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}")
163 # if only one remains, treat as fixed; else allowed list
164 if len(chosen) == 1:
165 cat_fixed[base] = chosen[0]
166 else:
167 cat_allowed[base] = chosen
168 else:
169 raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.")
171 # --- filter rows to categorical *fixed* selections for medians/percentiles & overlays
172 row_mask = np.ones(n_rows, dtype=bool)
173 for base, label in cat_fixed.items():
174 if base in df_raw.columns:
175 row_mask &= (df_raw[base].astype("string") == pd.Series([label]*len(df_raw), dtype="string")).to_numpy()
176 else:
177 member = groups[base]["name_by_label"][label]
178 j = name_to_idx[member]
179 raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member, transforms[j]).astype(float)
180 row_mask &= (raw_j >= 0.5)
182 # --- numeric constraints (standardized)
183 def _orig_to_std(j: int, x, transforms, mu, sd):
184 x = np.asarray(x, dtype=float)
185 if transforms[j] == "log10":
186 x = np.where(x <= 0, np.nan, x)
187 x = np.log10(x)
188 return (x - mu[j]) / sd[j]
190 fixed_scalars_std: dict[int, float] = {}
191 range_windows_std: dict[int, tuple[float, float]] = {}
192 choice_values_std: dict[int, np.ndarray] = {}
194 for name, val in kw_num.items():
195 j = name_to_idx[name]
196 if isinstance(val, slice):
197 lo = _orig_to_std(j, float(val.start), transforms, X_mean, X_std)
198 hi = _orig_to_std(j, float(val.stop), transforms, X_mean, X_std)
199 lo, hi = float(min(lo, hi)), float(max(lo, hi))
200 range_windows_std[j] = (lo, hi)
201 elif isinstance(val, (list, tuple, np.ndarray)):
202 arr = _orig_to_std(j, np.asarray(val, dtype=float), transforms, X_mean, X_std)
203 choice_values_std[j] = np.asarray(arr, dtype=float)
204 else:
205 fixed_scalars_std[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std))
207 # --- apply categorical *fixed* selections as standardized 0/1 on their member features
208 for base, label in cat_fixed.items():
209 labels = groups[base]["labels"]
210 for lab in labels:
211 member = groups[base]["name_by_label"][lab]
212 j = name_to_idx[member]
213 raw_val = 1.0 if (lab == label) else 0.0
214 fixed_scalars_std[j] = float(_orig_to_std(j, raw_val, transforms, X_mean, X_std))
216 # --- enforce row-level filters so overlays/points respect constraints ---
217 for base, allowed in cat_allowed.items():
218 if (base not in kw_cat) or (base in cat_fixed):
219 continue
220 allowed_labels = [str(x) for x in allowed]
221 if base in df_raw.columns:
222 series = df_raw[base].astype("string").fillna("<NA>")
223 if not allowed_labels:
224 row_mask &= False
225 else:
226 allowed_mask = series.isin(set(allowed_labels)).fillna(False).to_numpy()
227 row_mask &= allowed_mask
228 else:
229 allowed_masks: list[np.ndarray] = []
230 for label in allowed_labels:
231 member = groups[base]["name_by_label"].get(label)
232 if member is None:
233 continue
234 j = name_to_idx[member]
235 raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member, transforms[j]).astype(float)
236 allowed_masks.append(raw_j >= 0.5)
237 if allowed_masks:
238 row_mask &= np.logical_or.reduce(allowed_masks)
239 else:
240 row_mask &= False
242 for name, val in kw_num.items():
243 if name not in name_to_idx:
244 continue
245 j = name_to_idx[name]
246 if name in df_raw.columns:
247 raw_vals = pd.to_numeric(df_raw[name], errors="coerce").to_numpy(dtype=float)
248 else:
249 raw_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float)
251 mask = np.isfinite(raw_vals)
252 if isinstance(val, slice):
253 lo_raw = -np.inf if val.start is None else float(val.start)
254 hi_raw = np.inf if val.stop is None else float(val.stop)
255 if hi_raw < lo_raw:
256 lo_raw, hi_raw = hi_raw, lo_raw
257 mask &= (raw_vals >= lo_raw) & (raw_vals <= hi_raw)
258 elif isinstance(val, (list, tuple, set, np.ndarray)):
259 arr = np.asarray(list(val) if not isinstance(val, np.ndarray) else val, dtype=float)
260 arr = arr[np.isfinite(arr)]
261 if arr.size == 0:
262 mask &= False
263 else:
264 close_mask = np.any(np.isclose(raw_vals[:, None], arr[None, :], rtol=1e-6, atol=1e-9), axis=1)
265 mask &= close_mask
266 else:
267 target = float(val)
268 mask &= np.isclose(raw_vals, target, rtol=1e-6, atol=1e-9)
270 row_mask &= mask
272 if not np.any(row_mask):
273 raise ValueError("No experiments match the provided constraints; cannot plot data points.")
275 row_mask_active = not bool(np.all(row_mask))
276 df_raw_f = df_raw.loc[row_mask].reset_index(drop=True) if row_mask_active else df_raw
277 Xn_train_f = Xn_train[row_mask, :] if row_mask_active else Xn_train
279 # --- free axes = numeric features not scalar-fixed & not one-hot members, plus categorical bases not fixed
280 free_numeric_idx = [
281 j for j, nm in enumerate(feature_names)
282 if (j not in fixed_scalars_std) and (nm not in onehot_member_names)
283 ]
284 free_cat_bases = [b for b in bases if b not in cat_fixed] # we already filtered by allowed above
286 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases]
287 if not panels:
288 raise ValueError("All features are fixed (or only single-category categoricals remain); nothing to plot.")
290 # --- base point (median in standardized space of filtered rows), then apply scalar fixes
291 base_std = np.median(Xn_train_f, axis=0)
292 for j, vstd in fixed_scalars_std.items():
293 base_std[j] = vstd
295 # --- per-feature grids (numeric) over filtered 1–99% + respecting ranges/choices
296 p01p99 = [np.percentile(Xn_train_f[:, j], [1, 99]) for j in range(len(feature_names))]
297 def _grid_std_num(j: int) -> np.ndarray:
298 p01, p99 = p01p99[j]
299 if j in choice_values_std:
300 vals = np.asarray(choice_values_std[j], dtype=float)
301 vals = vals[(vals >= p01) & (vals <= p99)]
302 return np.unique(np.sort(vals)) if vals.size else np.array([np.median(Xn_train_f[:, j])])
303 lo, hi = p01, p99
304 if j in range_windows_std:
305 rlo, rhi = range_windows_std[j]
306 lo, hi = max(lo, rlo), min(hi, rhi)
307 if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
308 hi = lo + 1e-9
309 return np.linspace(lo, hi, grid_size)
311 grids_std_num = {j: _grid_std_num(j) for j in free_numeric_idx}
313 # --- helpers for categorical evaluation ---------------------------------
314 def _std_for_member(member_name: str, raw01: float) -> float:
315 j = name_to_idx[member_name]
316 return float(_orig_to_std(j, raw01, transforms, X_mean, X_std))
318 def _apply_onehot_for_base(Xn_block: np.ndarray, base: str, label: str) -> None:
319 # set the whole block's rows to the 0/1 standardized values for this label
320 for lab in groups[base]["labels"]:
321 member = groups[base]["name_by_label"][lab]
322 j = name_to_idx[member]
323 Xn_block[:, j] = _std_for_member(member, 1.0 if lab == label else 0.0)
325 def _denorm_inv(j: int, std_vals: np.ndarray) -> np.ndarray:
326 internal = std_vals * X_std[j] + X_mean[j]
327 return _inverse_transform(transforms[j], internal)
329 # 1) Robustly detect one-hot member columns.
330 # Use both the detector output AND a fallback "base=" prefix scan,
331 # so any columns like "language=Linear A" are guaranteed to be excluded.
332 onehot_member_names: set[str] = set()
333 for base, g in groups.items():
334 # detector-known members
335 onehot_member_names.update(g["members"])
336 # prefix fallback
337 prefix = f"{base}="
338 onehot_member_names.update([nm for nm in feature_names if nm.startswith(prefix)])
340 # 2) Build panel list: keep numeric features that are not scalar-fixed AND
341 # are not one-hot members; plus categorical bases that are not fixed.
342 free_numeric_idx = [
343 j for j, nm in enumerate(feature_names)
344 if (j not in fixed_scalars_std) and (nm not in onehot_member_names)
345 ]
346 free_cat_bases = [b for b in bases if b not in cat_fixed]
348 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases]
349 if not panels:
350 raise ValueError("All features are fixed (or only single-category categoricals remain); nothing to plot.")
352 # 3) Sanity check: no one-hot member should survive as a numeric panel.
353 assert all(
354 (feature_names[key] not in onehot_member_names) if kind == "num" else True
355 for kind, key in panels
356 ), "internal: one-hot member leaked into numeric panels"
358 # 4) Subplot scaffold (matrix layout k x k) with clear titles.
359 def _panel_title(kind: str, key: object) -> str:
360 return feature_names[int(key)] if kind == "num" else str(key)
362 k = len(panels)
363 fig = make_subplots(
364 rows=k,
365 cols=k,
366 shared_xaxes=False,
367 shared_yaxes=False,
368 horizontal_spacing=0.03,
369 vertical_spacing=0.03,
370 subplot_titles=[_panel_title(kind, key) for kind, key in panels],
371 )
373 # (Keep the rest of your cell-evaluation and rendering logic unchanged.
374 # Because we filtered `onehot_member_names`, rows/columns like
375 # "language=Linear A" / "language=Linear B" will no longer appear.
376 # Categorical bases (e.g., "language") will show as a single axis.)
379 # overlays prepared under the SAME constraints (pass original kwargs straight through)
380 optimal_df = opt.optimal(ds, count=1, seed=seed, **kwargs) if optimal else None
381 suggest_df = opt.suggest(ds, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None
383 # masks for data overlays (already filtered if cat_fixed)
384 tgt_col = str(ds.attrs["target"])
385 success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy()
386 fail_mask = ~success_mask
388 # collect Z blocks for global color bounds
389 all_blocks: list[np.ndarray] = []
390 cell_payload: dict[tuple[int,int], dict] = {}
392 # --- build each cell payload (numeric/num, cat/num, num/cat, cat/cat)
393 for r, (kind_r, key_r) in enumerate(panels):
394 for c, (kind_c, key_c) in enumerate(panels):
395 # X axis = column; Y axis = row
396 if kind_r == "num" and kind_c == "num":
397 i = int(key_r); j = int(key_c)
398 xg = grids_std_num[j]; yg = grids_std_num[i]
399 if i == j:
400 grid = grids_std_num[j]
401 Xn_1d = np.repeat(base_std[None, :], len(grid), axis=0)
402 Xn_1d[:, j] = grid
403 mu_1d, _ = pred_loss(Xn_1d, include_observation_noise=True)
404 p_1d = pred_success(Xn_1d)
405 Zmu = 0.5 * (mu_1d[:, None] + mu_1d[None, :])
406 Zp = np.minimum(p_1d[:, None], p_1d[None, :])
407 x_orig = _denorm_inv(j, grid)
408 y_orig = x_orig
409 else:
410 XX, YY = np.meshgrid(xg, yg)
411 Xn_grid = np.repeat(base_std[None, :], XX.size, axis=0)
412 Xn_grid[:, j] = XX.ravel()
413 Xn_grid[:, i] = YY.ravel()
414 mu_flat, _ = pred_loss(Xn_grid, include_observation_noise=True)
415 p_flat = pred_success(Xn_grid)
416 Zmu = mu_flat.reshape(YY.shape)
417 Zp = p_flat.reshape(YY.shape)
418 x_orig = _denorm_inv(j, xg)
419 y_orig = _denorm_inv(i, yg)
420 cell_payload[(r, c)] = dict(kind=("num","num"), i=i, j=j, x=x_orig, y=y_orig, Zmu=Zmu, Zp=Zp)
422 elif kind_r == "cat" and kind_c == "num":
423 base = str(key_r); j = int(key_c)
424 labels = list(cat_allowed.get(base, groups[base]["labels"]))
425 xg = grids_std_num[j]
426 # build rows per label
427 Zmu_rows = []; Zp_rows = []
428 for lab in labels:
429 Xn_grid = np.repeat(base_std[None, :], len(xg), axis=0)
430 Xn_grid[:, j] = xg
431 _apply_onehot_for_base(Xn_grid, base, lab)
432 mu_row, _ = pred_loss(Xn_grid, include_observation_noise=True)
433 p_row = pred_success(Xn_grid)
434 Zmu_rows.append(mu_row[None, :])
435 Zp_rows.append(p_row[None, :])
436 Zmu = np.concatenate(Zmu_rows, axis=0) # (n_labels, n_x)
437 Zp = np.concatenate(Zp_rows, axis=0)
438 x_orig = _denorm_inv(j, xg)
439 y_cats = labels # categorical ticks
440 cell_payload[(r,c)] = dict(kind=("cat","num"), base=base, j=j, x=x_orig, y=y_cats, Zmu=Zmu, Zp=Zp)
442 elif kind_r == "num" and kind_c == "cat":
443 i = int(key_r); base = str(key_c)
444 labels = list(cat_allowed.get(base, groups[base]["labels"]))
445 yg = grids_std_num[i]
446 # columns per label
447 Zmu_cols = []; Zp_cols = []
448 for lab in labels:
449 Xn_grid = np.repeat(base_std[None, :], len(yg), axis=0)
450 Xn_grid[:, i] = yg
451 _apply_onehot_for_base(Xn_grid, base, lab)
452 mu_col, _ = pred_loss(Xn_grid, include_observation_noise=True)
453 p_col = pred_success(Xn_grid)
454 Zmu_cols.append(mu_col[:, None])
455 Zp_cols.append(p_col[:, None])
456 Zmu = np.concatenate(Zmu_cols, axis=1) # (n_y, n_labels)
457 Zp = np.concatenate(Zp_cols, axis=1)
458 x_cats = labels
459 y_orig = _denorm_inv(i, yg)
460 cell_payload[(r,c)] = dict(kind=("num","cat"), i=i, base=base, x=x_cats, y=y_orig, Zmu=Zmu, Zp=Zp)
462 else: # kind_r == "cat" and kind_c == "cat"
463 base_r = str(key_r); base_c = str(key_c)
464 labels_r = list(cat_allowed.get(base_r, groups[base_r]["labels"]))
465 labels_c = list(cat_allowed.get(base_c, groups[base_c]["labels"]))
466 Z = np.zeros((len(labels_r), len(labels_c)), dtype=float)
467 P = np.zeros_like(Z)
468 # evaluate each pair
469 for rr, lab_r in enumerate(labels_r):
470 for cc, lab_c in enumerate(labels_c):
471 Xn_grid = base_std[None, :].copy()
472 _apply_onehot_for_base(Xn_grid, base_r, lab_r)
473 _apply_onehot_for_base(Xn_grid, base_c, lab_c)
474 mu_val, _ = pred_loss(Xn_grid, include_observation_noise=True)
475 p_val = pred_success(Xn_grid)
476 Z[rr, cc] = float(mu_val[0])
477 P[rr, cc] = float(p_val[0])
478 cell_payload[(r,c)] = dict(kind=("cat","cat"), x=labels_c, y=labels_r, Zmu=Z, Zp=P)
480 all_blocks.append(cell_payload[(r,c)]["Zmu"].ravel())
482 # --- color transform bounds
483 def _color_xform(z_raw: np.ndarray) -> tuple[np.ndarray, float]:
484 if not use_log_scale_for_target:
485 return z_raw, 0.0
486 zmin = float(np.nanmin(z_raw))
487 shift = 0.0 if zmin > 0 else -zmin + float(log_shift_epsilon)
488 return np.log10(np.maximum(z_raw + shift, log_shift_epsilon)), shift
490 z_all = np.concatenate(all_blocks) if all_blocks else np.array([0.0, 1.0])
491 z_all_t, global_shift = _color_xform(z_all)
492 cmin_t = float(np.nanmin(z_all_t))
493 cmax_t = float(np.nanmax(z_all_t))
494 cs = get_colorscale(colorscale)
496 def _contour_line_color(level_raw: float) -> str:
497 zt = np.log10(max(level_raw + global_shift, log_shift_epsilon)) if use_log_scale_for_target else level_raw
498 t = 0.5 if cmax_t == cmin_t else (zt - cmin_t) / (cmax_t - cmin_t)
499 rgb = sample_colorscale(cs, [float(np.clip(t, 0.0, 1.0))])[0]
500 r, g, b = _rgb_string_to_tuple(rgb)
501 lum = (0.2126*r + 0.7152*g + 0.0722*b)/255.0
502 grey = int(round((1.0 - lum) * 255))
503 return f"rgba({grey},{grey},{grey},0.9)"
505 # --- render cells
506 def _is_log_feature(j: int) -> bool: return (transforms[j] == "log10")
508 for (r, c), PAY in cell_payload.items():
509 kind = PAY["kind"]; Zmu_raw = PAY["Zmu"]; Zp = PAY["Zp"]
510 Z_t, _ = _color_xform(Zmu_raw)
512 # axes values (numeric arrays or category indices)
513 if kind == ("num","num"):
514 x_vals = PAY["x"]; y_vals = PAY["y"]
515 fig.add_trace(go.Heatmap(
516 x=x_vals, y=y_vals, z=Z_t,
517 coloraxis="coloraxis", zsmooth=False, showscale=False,
518 hovertemplate=(f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
519 f"{feature_names[PAY['i']]}: %{{y:.6g}}"
520 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"),
521 customdata=Zmu_raw
522 ), row=r+1, col=c+1)
524 # p(success) shading + contours
525 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)):
526 mask = np.where(Zp < thr, 1.0, np.nan)
527 fig.add_trace(go.Heatmap(
528 x=x_vals, y=y_vals, z=mask, zmin=0, zmax=1,
529 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]],
530 showscale=False, hoverinfo="skip"
531 ), row=r+1, col=c+1)
533 # contour lines
534 zmin_r, zmax_r = float(np.nanmin(Zmu_raw)), float(np.nanmax(Zmu_raw))
535 levels = np.linspace(zmin_r, zmax_r, max(n_contours, 2))
536 for lev in levels:
537 color = _contour_line_color(lev)
538 fig.add_trace(go.Contour(
539 x=x_vals, y=y_vals, z=Zmu_raw,
540 autocontour=False,
541 contours=dict(coloring="lines", showlabels=False, start=lev, end=lev, size=1e-9),
542 line=dict(width=1),
543 colorscale=[[0, color], [1, color]],
544 showscale=False, hoverinfo="skip"
545 ), row=r+1, col=c+1)
547 # data overlays (success/fail)
548 def _data_vals_for_feature(j_full: int) -> np.ndarray:
549 nm = feature_names[j_full]
550 if nm in df_raw_f.columns:
551 return df_raw_f[nm].to_numpy(dtype=float)
552 vals = feature_raw_from_artifact_or_reconstruct(ds, j_full, nm, transforms[j_full]).astype(float)
553 return vals[row_mask] if row_mask_active else vals
555 xd = _data_vals_for_feature(PAY["j"])
556 yd = _data_vals_for_feature(PAY["i"])
557 show_leg = (r == 0 and c == 0)
558 fig.add_trace(go.Scattergl(
559 x=xd[success_mask], y=yd[success_mask], mode="markers",
560 marker=dict(size=4, color="black", line=dict(width=0)),
561 name="data (success)", legendgroup="data_succ", showlegend=show_leg,
562 hovertemplate=("trial_id: %{customdata[0]}<br>"
563 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
564 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>"
565 f"{tgt_col}: %{{customdata[1]:.4f}}<extra></extra>"),
566 customdata=np.column_stack([
567 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask],
568 df_raw_f[tgt_col].to_numpy()[success_mask],
569 ])
570 ), row=r+1, col=c+1)
571 fig.add_trace(go.Scattergl(
572 x=xd[fail_mask], y=yd[fail_mask], mode="markers",
573 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)),
574 name="data (failed)", legendgroup="data_fail", showlegend=show_leg,
575 hovertemplate=("trial_id: %{customdata}<br>"
576 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
577 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>"
578 "status: failed (NaN target)<extra></extra>"),
579 customdata=df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask]
580 ), row=r+1, col=c+1)
582 # overlays (optimal/suggest) on numeric axes only
583 if optimal and (optimal_df is not None):
584 if feature_names[PAY["j"]] in optimal_df.columns and feature_names[PAY["i"]] in optimal_df.columns:
585 ox = np.asarray(optimal_df[feature_names[PAY["j"]]].values, dtype=float)
586 oy = np.asarray(optimal_df[feature_names[PAY["i"]]].values, dtype=float)
587 if np.isfinite(ox).all() and np.isfinite(oy).all():
588 pmu = float(optimal_df["pred_target_mean"].values[0])
589 psd = float(optimal_df["pred_target_sd"].values[0])
590 fig.add_trace(go.Scattergl(
591 x=ox, y=oy, mode="markers",
592 marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
593 name="optimal", legendgroup="optimal", showlegend=(r == 0 and c == 0),
594 hovertemplate=(f"predicted: {pmu:.2g} ± {psd:.2g}<br>"
595 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
596 f"{feature_names[PAY['i']]}: %{{y:.6g}}<extra></extra>")
597 ), row=r+1, col=c+1)
598 if suggest and (suggest_df is not None):
599 have = (feature_names[PAY["j"]] in suggest_df.columns) and (feature_names[PAY["i"]] in suggest_df.columns)
600 if have:
601 sx = np.asarray(suggest_df[feature_names[PAY["j"]]].values, dtype=float)
602 sy = np.asarray(suggest_df[feature_names[PAY["i"]]].values, dtype=float)
603 keep_s = np.isfinite(sx) & np.isfinite(sy)
604 if keep_s.any():
605 sx, sy = sx[keep_s], sy[keep_s]
606 mu_s = suggest_df.loc[keep_s, "pred_target_mean"].values if "pred_target_mean" in suggest_df else None
607 sd_s = suggest_df.loc[keep_s, "pred_target_sd"].values if "pred_target_sd" in suggest_df else None
608 ps_s = suggest_df.loc[keep_s, "pred_p_success"].values if "pred_p_success" in suggest_df else None
609 if (mu_s is not None) and (sd_s is not None) and (ps_s is not None):
610 custom_s = np.column_stack([mu_s, sd_s, ps_s])
611 hover_s = (
612 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
613 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>"
614 "pred: %{customdata[0]:.3g} ± %{customdata[1]:.3g}<br>"
615 "p(success): %{customdata[2]:.2f}<extra>suggested</extra>"
616 )
617 else:
618 custom_s = None
619 hover_s = (
620 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
621 f"{feature_names[PAY['i']]}: %{{y:.6g}}<extra>suggested</extra>"
622 )
623 fig.add_trace(go.Scattergl(
624 x=sx, y=sy, mode="markers",
625 marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
626 name="suggested", legendgroup="suggested",
627 showlegend=(r == 0 and c == 0),
628 customdata=custom_s, hovertemplate=hover_s
629 ), row=r+1, col=c+1)
631 # axis types/ranges
632 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="x", centers=x_vals, is_log=_is_log_feature(PAY["j"]))
633 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="y", centers=y_vals, is_log=_is_log_feature(PAY["i"]))
635 elif kind == ("cat","num"):
636 base = PAY["base"]; x_vals = PAY["x"]; labels = PAY["y"]
637 nlab = len(labels)
638 # heatmap (categories on Y)
639 fig.add_trace(go.Heatmap(
640 x=x_vals, y=np.arange(nlab), z=Z_t,
641 coloraxis="coloraxis", zsmooth=False, showscale=False,
642 hovertemplate=(f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
643 f"{base}: %{{text}}"
644 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"),
645 text=np.array(labels)[:, None].repeat(len(x_vals), axis=1),
646 customdata=Zmu_raw
647 ), row=r+1, col=c+1)
648 # p(success) shading
649 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)):
650 mask = np.where(Zp < thr, 1.0, np.nan)
651 fig.add_trace(go.Heatmap(
652 x=x_vals, y=np.arange(nlab), z=mask, zmin=0, zmax=1,
653 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]],
654 showscale=False, hoverinfo="skip"
655 ), row=r+1, col=c+1)
656 # categorical ticks
657 fig.update_yaxes(tickmode="array", tickvals=list(range(nlab)), ticktext=labels, row=r+1, col=c+1)
658 # data overlays: numeric vs categorical with jitter on Y
659 if base in df_raw_f.columns and feature_names[PAY["j"]] in df_raw_f.columns:
660 cat_series = df_raw_f[base].astype("string")
661 cat_to_idx = {lab: i for i, lab in enumerate(labels)}
662 y_map = cat_series.map(cat_to_idx)
663 ok = y_map.notna().to_numpy()
664 y_idx = y_map.to_numpy(dtype=float)
665 jitter = 0.10 * (np.random.default_rng(0).standard_normal(size=len(y_idx)))
666 yj = y_idx + jitter
667 xd = df_raw_f[feature_names[PAY["j"]]].to_numpy(dtype=float)
668 show_leg = (r == 0 and c == 0)
669 fig.add_trace(go.Scattergl(
670 x=xd[success_mask & ok], y=yj[success_mask & ok], mode="markers",
671 marker=dict(size=4, color="black", line=dict(width=0)),
672 name="data (success)", legendgroup="data_succ", showlegend=show_leg,
673 hovertemplate=("trial_id: %{customdata[0]}<br>"
674 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
675 f"{base}: %{{customdata[1]}}<br>"
676 f"{tgt_col}: %{{customdata[2]:.4f}}<extra></extra>"),
677 customdata=np.column_stack([
678 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask & ok],
679 cat_series.to_numpy()[success_mask & ok],
680 df_raw_f[tgt_col].to_numpy()[success_mask & ok],
681 ])
682 ), row=r+1, col=c+1)
683 fig.add_trace(go.Scattergl(
684 x=xd[fail_mask & ok], y=yj[fail_mask & ok], mode="markers",
685 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)),
686 name="data (failed)", legendgroup="data_fail", showlegend=show_leg,
687 hovertemplate=("trial_id: %{customdata[0]}<br>"
688 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>"
689 f"{base}: %{{customdata[1]}}<br>"
690 "status: failed (NaN target)<extra></extra>"),
691 customdata=np.column_stack([
692 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask & ok],
693 cat_series.to_numpy()[fail_mask & ok],
694 ])
695 ), row=r+1, col=c+1)
696 # axes: x numeric; y categorical range
697 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="x", centers=x_vals, is_log=_is_log_feature(PAY["j"]))
698 fig.update_yaxes(range=[-0.5, nlab - 0.5], row=r+1, col=c+1)
700 elif kind == ("num","cat"):
701 base = PAY["base"]; y_vals = PAY["y"]; labels = PAY["x"]
702 nlab = len(labels)
703 # heatmap (categories on X)
704 fig.add_trace(go.Heatmap(
705 x=np.arange(nlab), y=y_vals, z=Z_t,
706 coloraxis="coloraxis", zsmooth=False, showscale=False,
707 hovertemplate=(f"{base}: %{{text}}<br>"
708 f"{feature_names[PAY['i']]}: %{{y:.6g}}"
709 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"),
710 text=np.array(labels)[None, :].repeat(len(y_vals), axis=0),
711 customdata=Zmu_raw
712 ), row=r+1, col=c+1)
713 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)):
714 mask = np.where(Zp < thr, 1.0, np.nan)
715 fig.add_trace(go.Heatmap(
716 x=np.arange(nlab), y=y_vals, z=mask, zmin=0, zmax=1,
717 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]],
718 showscale=False, hoverinfo="skip"
719 ), row=r+1, col=c+1)
720 fig.update_xaxes(tickmode="array", tickvals=list(range(nlab)), ticktext=labels, row=r+1, col=c+1)
721 # data overlays with jitter on X
722 if base in df_raw_f.columns and feature_names[PAY["i"]] in df_raw_f.columns:
723 cat_series = df_raw_f[base].astype("string")
724 cat_to_idx = {lab: i for i, lab in enumerate(labels)}
725 x_map = cat_series.map(cat_to_idx)
726 ok = x_map.notna().to_numpy()
727 x_idx = x_map.to_numpy(dtype=float)
728 jitter = 0.10 * (np.random.default_rng(0).standard_normal(size=len(x_idx)))
729 xj = x_idx + jitter
730 yd = df_raw_f[feature_names[PAY["i"]]].to_numpy(dtype=float)
731 show_leg = (r == 0 and c == 0)
732 fig.add_trace(go.Scattergl(
733 x=xj[success_mask & ok], y=yd[success_mask & ok], mode="markers",
734 marker=dict(size=4, color="black", line=dict(width=0)),
735 name="data (success)", legendgroup="data_succ", showlegend=show_leg,
736 hovertemplate=("trial_id: %{customdata[0]}<br>"
737 f"{base}: %{{customdata[1]}}<br>"
738 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>"
739 f"{tgt_col}: %{{customdata[2]:.4f}}<extra></extra>"),
740 customdata=np.column_stack([
741 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask & ok],
742 cat_series.to_numpy()[success_mask & ok],
743 df_raw_f[tgt_col].to_numpy()[success_mask & ok],
744 ])
745 ), row=r+1, col=c+1)
746 fig.add_trace(go.Scattergl(
747 x=xj[fail_mask & ok], y=yd[fail_mask & ok], mode="markers",
748 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)),
749 name="data (failed)", legendgroup="data_fail", showlegend=show_leg,
750 hovertemplate=("trial_id: %{customdata[0]}<br>"
751 f"{base}: %{{customdata[1]}}<br>"
752 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>"
753 "status: failed (NaN target)<extra></extra>"),
754 customdata=np.column_stack([
755 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask & ok],
756 cat_series.to_numpy()[fail_mask & ok],
757 ])
758 ), row=r+1, col=c+1)
759 # axes: x categorical; y numeric
760 fig.update_xaxes(range=[-0.5, nlab - 0.5], row=r+1, col=c+1)
761 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="y", centers=y_vals, is_log=_is_log_feature(PAY["i"]))
763 elif kind == ("cat","cat"):
764 labels_y = PAY["y"]
765 labels_x = PAY["x"]
766 ny, nx = len(labels_y), len(labels_x)
768 # Build customdata carrying (row_label, col_label) for hovertemplate.
769 custom = np.dstack((
770 np.array(labels_y, dtype=object)[:, None].repeat(nx, axis=1),
771 np.array(labels_x, dtype=object)[None, :].repeat(ny, axis=0),
772 ))
774 # Heatmap over categorical indices
775 fig.add_trace(go.Heatmap(
776 x=np.arange(nx),
777 y=np.arange(ny),
778 z=Z_t,
779 coloraxis="coloraxis",
780 zsmooth=False,
781 showscale=False,
782 hovertemplate=(
783 "row: %{customdata[0]}<br>"
784 "col: %{customdata[1]}<br>"
785 "E[target|success]: %{z:.3f}<extra></extra>"
786 ),
787 customdata=custom,
788 ), row=r+1, col=c+1)
790 # p(success) shading overlays
791 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)):
792 mask = np.where(Zp < thr, 1.0, np.nan)
793 fig.add_trace(go.Heatmap(
794 x=np.arange(nx),
795 y=np.arange(ny),
796 z=mask,
797 zmin=0,
798 zmax=1,
799 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]],
800 showscale=False,
801 hoverinfo="skip",
802 ), row=r+1, col=c+1)
804 # Categorical tick labels on both axes
805 fig.update_xaxes(
806 tickmode="array",
807 tickvals=list(range(nx)),
808 ticktext=labels_x,
809 range=[-0.5, nx - 0.5],
810 row=r+1,
811 col=c+1,
812 )
813 fig.update_yaxes(
814 tickmode="array",
815 tickvals=list(range(ny)),
816 ticktext=labels_y,
817 range=[-0.5, ny - 0.5],
818 row=r+1,
819 col=c+1,
820 )
822 # --- outer axis labels
823 def _panel_title(kind: str, key: object) -> str:
824 return feature_names[int(key)] if kind == "num" else str(key)
826 for c, (_, key_c) in enumerate(panels):
827 fig.update_xaxes(title_text=_panel_title(panels[c][0], key_c), row=k, col=c+1)
828 for r, (kind_r, key_r) in enumerate(panels):
829 fig.update_yaxes(title_text=_panel_title(kind_r, key_r), row=r+1, col=1)
831 # --- title
832 def _fmt_c(v):
833 if isinstance(v, slice):
834 a = f"{v.start:g}" if v.start is not None else ""
835 b = f"{v.stop:g}" if v.stop is not None else ""
836 return f"[{a},{b}]"
837 if isinstance(v, (list, tuple, np.ndarray)):
838 try:
839 return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]"
840 except Exception:
841 return "[" + ",".join(map(str, v)) + "]"
842 return str(v)
844 title_parts = [f"2D partial dependence of expected {tgt_col}"]
846 # numeric constraints shown
847 for name, val in kw_num.items():
848 title_parts.append(f"{name}={_fmt_c(val)}")
849 # categorical constraints: fixed shown as base=Label; allowed ranges omitted in title
850 for base, lab in cat_fixed.items():
851 title_parts.append(f"{base}={lab}")
852 title = " — ".join([title_parts[0], ", ".join(title_parts[1:])]) if len(title_parts) > 1 else title_parts[0]
854 # --- layout
855 cell = 250
856 z_title = "E[target|success]" + (" (log10)" if use_log_scale_for_target else "")
857 if use_log_scale_for_target and global_shift > 0:
858 z_title += f" (shift Δ={global_shift:.3g})"
860 width = width if (width and width > 0) else cell * k
861 width = max(width, 400)
862 height = height if (height and height > 0) else cell * k
863 height = max(height, 400)
865 fig.update_layout(
866 template="simple_white",
867 width=width,
868 height=height,
869 title=title,
870 legend_title_text="",
871 coloraxis=dict(
872 colorscale=colorscale,
873 cmin=cmin_t, cmax=cmax_t,
874 colorbar=dict(
875 title=z_title,
876 thickness=10, # thinner bar
877 len=0.55, # shorter bar (fraction of plot height)
878 lenmode="fraction",
879 x=1.02, y=0.5, # just right of plot, vertically centered
880 xanchor="left", yanchor="middle",
881 ),
882 ),
883 legend=dict(
884 orientation="v",
885 x=1.02, xanchor="left", # to the right of the colorbar
886 y=1.0, yanchor="top",
887 bgcolor="rgba(255,255,255,0.85)"
888 ),
889 margin=dict(t=90, r=100), # room for title + legend + colorbar
890 )
892 if output:
893 write_image(fig, output)
894 if show:
895 fig.show("browser")
896 return fig
def plot1d(
    model: xr.Dataset | Path | str,
    output: Path | None = None,
    csv_out: Path | None = None,
    grid_size: int = 300,
    line_color: str = "rgb(31,119,180)",
    band_alpha: float = 0.25,
    show: bool = False,
    use_log_scale_for_target_y: bool = True,  # log-y for target
    log_y_epsilon: float = 1e-9,
    optimal: bool = True,
    suggest: int = 0,
    width: int | None = None,
    height: int | None = None,
    seed: int | None = 42,
    **kwargs,
) -> go.Figure:
    """
    Vertical 1D PD panels of E[target|success] vs each *free* feature.

    Scalars (fix & hide), slices (restrict sweep & x-range), lists/tuples (discrete grids).
    Categorical bases (e.g. language) are plotted as a single categorical subplot
    when not fixed; passing --language "Linear A" fixes that base and removes it
    from the plotted axes.

    Args:
        model: An already-loaded dataset, or a path handed to ``xr.load_dataset``.
        output: If truthy, the figure is passed to ``write_image``.
        csv_out: If truthy, the swept grid (one row per grid point / category)
            is written to this CSV path; parent directories are created.
        grid_size: Number of points in each numeric sweep.
        line_color: Colour of the mean line; the band reuses it with ``band_alpha``.
        band_alpha: Opacity of the +/- 2 sd band.
        show: If true, open the figure in the browser renderer.
        use_log_scale_for_target_y: Clamp plotted y values at ``log_y_epsilon``
            and request a log y-axis from ``_set_yaxis_range``.
        log_y_epsilon: Lower clamp used when the log y-scale is active.
        optimal: If true, overlay the point returned by ``opt.optimal``.
        suggest: If > 0, overlay that many points from ``opt.suggest``.
        width, height: Figure size in pixels; each defaults to 1200.
        seed: Forwarded to ``opt.optimal`` / ``opt.suggest``.
        **kwargs: Constraints keyed by feature name (numeric) or one-hot base
            name (categorical). Slices may be open-ended (``None`` bound).

    Returns:
        The assembled plotly figure.

    Raises:
        ValueError: Unknown/invalid categorical constraint, no training rows
            matching the constraints, or no free feature left to plot.
    """
    import re  # function-scope: `re` is not imported at module level

    ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
    pred_success, pred_loss = _build_predictors(ds)

    feature_names = [str(n) for n in ds["feature"].values.tolist()]
    transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
    X_mean = ds["feature_mean"].values.astype(float)
    X_std = ds["feature_std"].values.astype(float)

    df_raw = _raw_dataframe_from_dataset(ds)
    Xn_train = ds["Xn_train"].values.astype(float)
    n_rows, p = Xn_train.shape

    # --- one-hot categorical groups ---
    # { base: {"labels":[...], "name_by_label":{label:member}, "members":[...]} }
    groups = opt._onehot_groups(feature_names)
    bases = set(groups.keys())
    name_to_idx = {name: j for j, name in enumerate(feature_names)}

    # --- canonicalize kwargs: numeric vs categorical (base) ---
    idx_map = _canon_key_set(ds)
    kw_num_raw: dict[str, object] = {}
    kw_cat_raw: dict[str, object] = {}
    for k, v in kwargs.items():
        if k in bases:
            kw_cat_raw[k] = v
        elif k in idx_map:
            kw_num_raw[idx_map[k]] = v
        else:
            # Fall back to a punctuation/case-insensitive match; unknown keys are ignored.
            nk = re.sub(r"[^a-z0-9]+", "", str(k).lower())
            if nk in idx_map:
                kw_num_raw[idx_map[nk]] = v

    # --- resolve categorical constraints: fixed (single) vs allowed (multiple) ---
    cat_fixed: dict[str, str] = {}
    cat_allowed: dict[str, list[str]] = {}
    for base, val in kw_cat_raw.items():
        labels = groups[base]["labels"]
        if isinstance(val, str):
            if val not in labels:
                raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}")
            cat_fixed[base] = val
        elif isinstance(val, (list, tuple, set)):
            chosen = [x for x in val if isinstance(x, str) and x in labels]
            if not chosen:
                raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}")
            # multiple -> treat as allowed subset (NOT fixed); de-duplicate, keep order
            cat_allowed[base] = list(dict.fromkeys(chosen))
        else:
            raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.")

    # --- filter rows by fixed categoricals (affects medians/percentiles & overlays) ---
    row_mask = np.ones(n_rows, dtype=bool)
    for base, label in cat_fixed.items():
        if base in df_raw.columns:
            row_mask &= (
                df_raw[base].astype("string") == pd.Series([label] * len(df_raw), dtype="string")
            ).to_numpy()
        else:
            member_name = groups[base]["name_by_label"][label]
            j = name_to_idx[member_name]
            raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float)
            row_mask &= (raw_j >= 0.5)

    # --- helper: original units -> standardized units for feature j ---
    def _orig_to_std(j: int, x, transforms, mu, sd):
        x = np.asarray(x, dtype=float)
        if transforms[j] == "log10":
            x = np.where(x <= 0, np.nan, x)  # non-positive values have no log10
            x = np.log10(x)
        return (x - mu[j]) / sd[j]

    # --- numeric constraint split (STANDARDIZED) ---
    fixed_scalars: dict[int, float] = {}
    range_windows: dict[int, tuple[float, float]] = {}
    choice_values: dict[int, np.ndarray] = {}
    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if isinstance(val, slice):
            # Open-ended slices (None bound) are unbounded on that side, matching
            # the row-level filter below and the title formatter.
            lo = -np.inf if val.start is None else float(
                _orig_to_std(j, float(val.start), transforms, X_mean, X_std)
            )
            hi = np.inf if val.stop is None else float(
                _orig_to_std(j, float(val.stop), transforms, X_mean, X_std)
            )
            lo, hi = float(min(lo, hi)), float(max(lo, hi))
            range_windows[j] = (lo, hi)
        elif isinstance(val, (list, tuple, set, np.ndarray)):
            raw_choices = val if isinstance(val, np.ndarray) else list(val)
            arr = _orig_to_std(j, np.asarray(raw_choices, dtype=float), transforms, X_mean, X_std)
            choice_values[j] = np.asarray(arr, dtype=float)
        else:
            fixed_scalars[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std))

    # --- apply categorical fixed as standardized scalar fixes on each one-hot member ---
    for base, label in cat_fixed.items():
        for lab in groups[base]["labels"]:
            member_name = groups[base]["name_by_label"][lab]
            j = name_to_idx[member_name]
            raw_val = 1.0 if lab == label else 0.0
            fixed_scalars[j] = float(_orig_to_std(j, raw_val, transforms, X_mean, X_std))

    # --- enforce row-level filters for categorical allowed sets and numeric constraints ---
    for base, allowed in cat_allowed.items():
        if base in df_raw.columns:
            series = df_raw[base].astype("string").fillna("<NA>")
            allowed_set = {str(x) for x in allowed}
            allowed_mask = series.isin(allowed_set).fillna(False).to_numpy()
            row_mask &= allowed_mask
        else:
            allowed_masks = []
            for label in allowed:
                member_name = groups[base]["name_by_label"].get(label)
                if member_name is None:
                    continue
                j = name_to_idx[member_name]
                raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float)
                allowed_masks.append(raw_j >= 0.5)
            if allowed_masks:
                row_mask &= np.logical_or.reduce(allowed_masks)
            else:
                row_mask &= False

    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if name in df_raw.columns:
            raw_vals = pd.to_numeric(df_raw[name], errors="coerce").to_numpy(dtype=float)
        else:
            raw_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float)

        mask = np.ones_like(row_mask, dtype=bool)
        if isinstance(val, slice):
            lo_raw = -np.inf if val.start is None else float(val.start)
            hi_raw = np.inf if val.stop is None else float(val.stop)
            if hi_raw < lo_raw:
                lo_raw, hi_raw = hi_raw, lo_raw
            mask &= (raw_vals >= lo_raw) & (raw_vals <= hi_raw)
        elif isinstance(val, (list, tuple, set, np.ndarray)):
            arr = np.asarray(list(val) if not isinstance(val, np.ndarray) else val, dtype=float)
            arr = arr[np.isfinite(arr)]
            if arr.size == 0:
                mask &= False
            else:
                mask &= np.any(np.isclose(raw_vals[:, None], arr[None, :], rtol=1e-6, atol=1e-9), axis=1)
        else:
            target = float(val)
            mask &= np.isclose(raw_vals, target, rtol=1e-6, atol=1e-9)

        row_mask &= mask

    if not np.any(row_mask):
        raise ValueError("No experiments match the provided constraints; cannot plot data points.")

    df_raw_f = df_raw.loc[row_mask].reset_index(drop=True)
    Xn_train_f = Xn_train[row_mask, :]

    # --- overlays conditioned on the same kwargs (numeric + categorical) ---
    optimal_df = opt.optimal(model, count=1, seed=seed, **kwargs) if optimal else None
    suggest_df = opt.suggest(model, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None

    # --- base standardized point (median over filtered rows), then apply scalar fixes ---
    base_std = np.median(Xn_train_f, axis=0)
    for j, vstd in fixed_scalars.items():
        base_std[j] = vstd

    # --- one-hot member names (robust) ---
    onehot_member_names: set[str] = set()
    for base, g in groups.items():
        # names recorded by the detector
        onehot_member_names.update(g["members"])
        # fallback pattern match in case detector missed anything
        prefix = f"{base}="
        onehot_member_names.update(nm for nm in feature_names if nm.startswith(prefix))

    # --- build panel list: numeric free features + categorical bases (not fixed) ---
    free_numeric_idx = [
        j for j, nm in enumerate(feature_names)
        if (j not in fixed_scalars) and (nm not in onehot_member_names)
    ]
    # Iterate the dict (insertion order) rather than the `bases` set so the
    # panel order is deterministic from run to run.
    free_cat_bases = [b for b in groups if b not in cat_fixed]

    panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases]
    if not panels:
        raise ValueError("All features are fixed (or categorical only with single category chosen); nothing to plot.")

    # sanity: ensure we didn't accidentally keep any one-hot member columns
    assert all(
        (feature_names[key] not in onehot_member_names) if kind == "num" else True
        for kind, key in panels
    ), "internal: one-hot member leaked into numeric panels"

    # --- empirical 1–99% from filtered rows for numeric bounds ---
    p01p99 = [np.percentile(Xn_train_f[:, j], [1, 99]) for j in range(p)]

    def _grid_1d(j: int, n: int) -> np.ndarray:
        """Standardized sweep for feature j, honouring choice/range constraints."""
        p01, p99 = p01p99[j]
        if j in choice_values:
            vals = np.asarray(choice_values[j], dtype=float)
            vals = vals[(vals >= p01) & (vals <= p99)]
            return np.unique(np.sort(vals)) if vals.size else np.array([np.median(Xn_train_f[:, j])], dtype=float)
        lo, hi = p01, p99
        if j in range_windows:
            rlo, rhi = range_windows[j]
            lo, hi = max(lo, rlo), min(hi, rhi)
        if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
            lo, hi = p01, max(p01 + 1e-9, p99)
        return np.linspace(lo, hi, n)

    def _to_plot_scale(mu, sd):
        """Return (mean, lower, upper) of the ±2σ band, clamped when log-y is on."""
        lower = mu - 2.0 * sd
        upper = mu + 2.0 * sd
        if use_log_scale_for_target_y:
            return (
                np.maximum(mu, log_y_epsilon),
                np.maximum(lower, log_y_epsilon),
                np.maximum(upper, log_y_epsilon),
            )
        return mu, lower, upper

    def _y_layout(lo_arr, hi_arr, observed):
        """Return (y0, y1, y_failed): axis limits and the row where failures are drawn."""
        y_arrays = [lo_arr, hi_arr] + ([observed] if observed.size else [])
        y_low = float(np.nanmin([np.nanmin(a) for a in y_arrays]))
        y_high = float(np.nanmax([np.nanmax(a) for a in y_arrays]))
        span = y_high - y_low + 1e-12
        if use_log_scale_for_target_y:
            y0 = max(y_low / 1.5, log_y_epsilon)
            y_failed = y_high * 1.2 + span * 0.3
            if y_failed <= log_y_epsilon:
                y_failed = max(10.0 * log_y_epsilon, y_high * 2.0)
            y1 = y_failed + 0.05 * span
        else:
            pad = 0.05 * span
            y0 = y_low - pad
            y_failed = (y_high + pad) + span * 0.08
            y1 = y_failed + 0.02 * span
        return y0, y1, y_failed

    fig = make_subplots(
        rows=len(panels),
        cols=1,
        shared_xaxes=False,
    )

    # --- masks/data from filtered rows ---
    tgt_col = str(ds.attrs["target"])
    success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy()
    fail_mask = ~success_mask
    losses_success = df_raw_f.loc[success_mask, tgt_col].to_numpy().astype(float)
    trial_ids_all = df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()
    trial_ids_success = trial_ids_all[success_mask]
    trial_ids_fail = trial_ids_all[fail_mask]
    band_fill_rgba = _rgb_to_rgba(line_color, band_alpha)

    # Observed targets on the plotted scale; identical for every panel.
    if use_log_scale_for_target_y and losses_success.size:
        losses_s_plot = np.maximum(losses_success, log_y_epsilon)
    else:
        losses_s_plot = losses_success

    tidy_rows: list[dict] = []

    for row_pos, (kind, key) in enumerate(panels, start=1):
        show_legend = (row_pos == 1)  # legend entries only once, from the first panel

        if kind == "num":
            j = key
            if feature_names[j] in df_raw.columns:
                series_full = pd.to_numeric(df_raw[feature_names[j]], errors="coerce")
                x_full_raw = series_full.to_numpy(dtype=float)
            else:
                x_full_raw = feature_raw_from_artifact_or_reconstruct(
                    ds, j, feature_names[j], transforms[j]
                ).astype(float)
            x_data_all = x_full_raw[row_mask]

            finite_raw = x_full_raw[np.isfinite(x_full_raw)]
            if transforms[j] == "log10":
                finite_raw = finite_raw[finite_raw > 0]

            grid = _grid_1d(j, grid_size)
            if (j not in range_windows) and (j not in choice_values) and finite_raw.size:
                # Unconstrained feature: widen the sweep to cover all observed data.
                finite_std = _orig_to_std(j, finite_raw, transforms, X_mean, X_std)
                grid_min = float(np.nanmin(np.concatenate([grid, finite_std])))
                grid_max = float(np.nanmax(np.concatenate([grid, finite_std])))
                if grid_max > grid_min:
                    grid = np.linspace(grid_min, grid_max, grid_size)
                else:
                    grid = np.array([grid_min], dtype=float)

            Xn_grid = np.repeat(base_std[None, :], len(grid), axis=0)
            Xn_grid[:, j] = grid

            p_grid = pred_success(Xn_grid)
            mu_grid, sd_grid = pred_loss(Xn_grid, include_observation_noise=True)

            x_internal = grid * X_std[j] + X_mean[j]
            x_display = _inverse_transform(transforms[j], x_internal)

            mu_plot, lo_plot, hi_plot = _to_plot_scale(mu_grid, sd_grid)
            y0_plot, y1_plot, y_failed_band = _y_layout(lo_plot, hi_plot, losses_s_plot)

            _add_low_success_shading_1d(fig, row_pos, x_display, p_grid, y0_plot, y1_plot)

            fig.add_trace(go.Scatter(x=x_display, y=lo_plot, mode="lines",
                                     line=dict(width=0, color=line_color),
                                     name="±2σ", legendgroup="band", showlegend=False, hoverinfo="skip"),
                          row=row_pos, col=1)
            fig.add_trace(go.Scatter(x=x_display, y=hi_plot, mode="lines", fill="tonexty",
                                     line=dict(width=0, color=line_color), fillcolor=band_fill_rgba,
                                     name="±2σ", legendgroup="band", showlegend=show_legend,
                                     hovertemplate="E[target|success]: %{y:.3f}<extra>±2σ</extra>"),
                          row=row_pos, col=1)
            fig.add_trace(go.Scatter(x=x_display, y=mu_plot, mode="lines",
                                     line=dict(width=2, color=line_color),
                                     name="E[target|success]", legendgroup="mean", showlegend=show_legend,
                                     hovertemplate=f"{feature_names[j]}: %{{x:.6g}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"),
                          row=row_pos, col=1)

            # experimental points
            x_succ = x_data_all[success_mask]
            if x_succ.size:
                fig.add_trace(go.Scattergl(
                    x=x_succ, y=losses_s_plot, mode="markers",
                    marker=dict(size=5, color="black", line=dict(width=0)),
                    name="data (success)", legendgroup="data_s", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{feature_names[j]}: %{{x:.6g}}<br>"
                                   f"{tgt_col}: %{{y:.4f}}<extra></extra>"),
                    customdata=trial_ids_success
                ), row=row_pos, col=1)

            x_fail = x_data_all[fail_mask]
            if x_fail.size:
                y_fail_plot = np.full_like(x_fail, y_failed_band, dtype=float)
                fig.add_trace(go.Scattergl(
                    x=x_fail, y=y_fail_plot, mode="markers",
                    marker=dict(size=6, color="red", line=dict(color="black", width=0.8)),
                    name="data (failed)", legendgroup="data_f", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{feature_names[j]}: %{{x:.6g}}<br>"
                                   "status: failed (NaN target)<extra></extra>"),
                    customdata=trial_ids_fail
                ), row=row_pos, col=1)

            # overlays (per-point sd travels as customdata so plotly formats it)
            if optimal_df is not None and feature_names[j] in optimal_df.columns:
                x_opt = optimal_df[feature_names[j]].values
                y_opt = optimal_df["pred_target_mean"].values
                y_sd = optimal_df["pred_target_sd"].values
                fig.add_trace(go.Scattergl(
                    x=x_opt, y=y_opt, mode="markers",
                    marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                    name="optimal", legendgroup="optimal", showlegend=show_legend,
                    hovertemplate=("predicted: %{y:.3g} ± %{customdata:.3g}<br>"
                                   f"{feature_names[j]}: %{{x:.6g}}<extra></extra>"),
                    customdata=y_sd
                ), row=row_pos, col=1)

            if suggest_df is not None and feature_names[j] in suggest_df.columns:
                x_sug = suggest_df[feature_names[j]].values
                y_sug = suggest_df["pred_target_mean"].values
                y_sd = suggest_df["pred_target_sd"].values
                fig.add_trace(go.Scattergl(
                    x=x_sug, y=y_sug, mode="markers",
                    marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                    name="suggested", legendgroup="suggested", showlegend=show_legend,
                    hovertemplate=("predicted: %{y:.3g} ± %{customdata:.3g}<br>"
                                   f"{feature_names[j]}: %{{x:.6g}}<extra></extra>"),
                    customdata=y_sd
                ), row=row_pos, col=1)

            # axes
            _maybe_log_axis(fig, row_pos, 1, feature_names[j], axis="x", transforms=transforms, j=j)
            fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1)
            _set_yaxis_range(fig, row=row_pos, col=1,
                             y0=y0_plot, y1=y1_plot,
                             log=use_log_scale_for_target_y, eps=log_y_epsilon)

            # restrict x-range if constrained
            is_log_x = (transforms[j] == "log10")

            def _std_to_orig(val_std: float) -> float:
                vi = val_std * X_std[j] + X_mean[j]
                return float(_inverse_transform(transforms[j], np.array([vi]))[0])

            x_min_override = x_max_override = None
            if j in range_windows:
                lo_std, hi_std = range_windows[j]
                # An open-ended slice has an infinite bound; use the swept grid's
                # extent on that side so the axis range stays finite.
                shown = np.asarray(x_display, dtype=float)
                lo_orig = _std_to_orig(lo_std) if np.isfinite(lo_std) else float(np.nanmin(shown))
                hi_orig = _std_to_orig(hi_std) if np.isfinite(hi_std) else float(np.nanmax(shown))
                x_min_override = min(lo_orig, hi_orig)
                x_max_override = max(lo_orig, hi_orig)
            elif j in choice_values:
                ints = choice_values[j] * X_std[j] + X_mean[j]
                origs = np.asarray(_inverse_transform(transforms[j], ints), dtype=float)
                origs = origs[np.isfinite(origs)]
                if origs.size:
                    x_min_override = float(np.min(origs))
                    x_max_override = float(np.max(origs))
            elif finite_raw.size:
                x_min_override = float(np.min(finite_raw))
                x_max_override = float(np.max(finite_raw))

            if (x_min_override is not None) and (x_max_override is not None):
                if is_log_x:
                    x0 = max(x_min_override, 1e-12)
                    x1 = max(x_max_override, x0 * (1 + 1e-9))
                    pad = (x1 / x0) ** 0.03  # multiplicative padding in log space
                    fig.update_xaxes(type="log",
                                     range=[np.log10(x0 / pad), np.log10(x1 * pad)],
                                     row=row_pos, col=1)
                else:
                    span = (x_max_override - x_min_override) or 1.0
                    pad = 0.02 * span
                    fig.update_xaxes(range=[x_min_override - pad, x_max_override + pad],
                                     row=row_pos, col=1)

            fig.update_xaxes(title_text=feature_names[j], row=row_pos, col=1)

            # tidy rows
            for xd, xi, mu_i, sd_i, p_i in zip(x_display, x_internal, mu_grid, sd_grid, p_grid):
                tidy_rows.append({
                    "feature": feature_names[j],
                    "x_display": float(xd),
                    "x_internal": float(xi),
                    "target_conditional_mean": float(mu_i),
                    "target_conditional_sd": float(sd_i),
                    "success_probability": float(p_i),
                })

        else:
            base = key  # categorical base
            labels_all = groups[base]["labels"]
            labels = cat_allowed.get(base, labels_all)

            # Build standardized design for each label at the base point:
            # row r has member `lab` set to 1 and every other member of the base set to 0.
            Xn_grid = np.repeat(base_std[None, :], len(labels), axis=0)
            for r, lab in enumerate(labels):
                for lab2 in labels_all:
                    member_name = groups[base]["name_by_label"][lab2]
                    j2 = name_to_idx[member_name]
                    raw_val = 1.0 if (lab2 == lab) else 0.0
                    Xn_grid[r, j2] = (raw_val - X_mean[j2]) / X_std[j2]

            p_vec = pred_success(Xn_grid)
            mu_vec, sd_vec = pred_loss(Xn_grid, include_observation_noise=True)

            mu_plot, lo_plot, hi_plot = _to_plot_scale(mu_vec, sd_vec)
            y0_plot, y1_plot, y_failed_band = _y_layout(lo_plot, hi_plot, losses_s_plot)

            # x positions are 0..K-1 with tick labels = category names
            x_pos = np.arange(len(labels), dtype=float)

            # shading per-category threshold regions using shapes
            for thr, alpha in ((0.8, 0.40), (0.5, 0.25)):
                for k_i, p_i in enumerate(p_vec):
                    if p_i < thr:
                        fig.add_shape(
                            type="rect",
                            xref=f"x{'' if row_pos == 1 else row_pos}",
                            yref=f"y{'' if row_pos == 1 else row_pos}",
                            x0=k_i - 0.5, x1=k_i + 0.5,
                            y0=y0_plot, y1=y1_plot,
                            line=dict(width=0),
                            fillcolor=f"rgba(128,128,128,{alpha})",
                            layer="below",
                            row=row_pos, col=1
                        )

            # mean with error bars (±2σ)
            fig.add_trace(go.Scatter(
                x=x_pos, y=mu_plot, mode="lines+markers",
                line=dict(width=2, color=line_color),
                marker=dict(size=7, color=line_color),
                error_y=dict(type="data", array=(hi_plot - mu_plot), arrayminus=(mu_plot - lo_plot), visible=True),
                name="E[target|success]", legendgroup="mean", showlegend=show_legend,
                hovertemplate=(f"{base}: %{{text}}<br>E[target|success]: %{{y:.3f}}"
                               "<br>±2σ shown as error bar<extra></extra>"),
                text=labels
            ), row=row_pos, col=1)

            # experimental points: map each row's label to index
            if base in df_raw_f.columns:
                lab_series = df_raw_f[base].astype("string")
            else:
                # reconstruct from one-hot members
                member_cols = [groups[base]["name_by_label"][lab] for lab in labels_all]
                if all(m in df_raw_f.columns for m in member_cols):
                    member_vals = df_raw_f[member_cols].to_numpy()
                else:
                    # Member columns are not in the raw frame: rebuild them from the
                    # artifact, the same way the row filters above do.
                    member_vals = np.column_stack([
                        feature_raw_from_artifact_or_reconstruct(
                            ds, name_to_idx[m], m, transforms[name_to_idx[m]]
                        ).astype(float)[row_mask]
                        for m in member_cols
                    ])
                idx_max = member_vals.argmax(axis=1)
                lab_series = pd.Series([labels_all[i] for i in idx_max], dtype="string")

            label_to_idx = {lab: i for i, lab in enumerate(labels)}
            x_idx_all = lab_series.map(lambda s: label_to_idx.get(str(s), np.nan)).to_numpy(dtype=float)
            x_idx_succ = x_idx_all[success_mask]
            x_idx_fail = x_idx_all[fail_mask]

            # jitter for visibility (fixed seed keeps the figure reproducible)
            rng = np.random.default_rng(0)

            def _jitter(n: int) -> np.ndarray:
                return (rng.random(n) - 0.5) * 0.15

            def _labels_for(idx_values) -> list[str]:
                return [
                    labels[int(i)] if np.isfinite(i) and int(i) < len(labels) else "?"
                    for i in idx_values
                ]

            if x_idx_succ.size:
                fig.add_trace(go.Scattergl(
                    x=x_idx_succ + _jitter(x_idx_succ.size),
                    y=losses_s_plot,
                    mode="markers",
                    marker=dict(size=5, color="black", line=dict(width=0)),
                    name="data (success)", legendgroup="data_s", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{base}: %{{text}}<br>"
                                   f"{tgt_col}: %{{y:.4f}}<extra></extra>"),
                    text=_labels_for(x_idx_succ),
                    customdata=trial_ids_success
                ), row=row_pos, col=1)

            if x_idx_fail.size:
                y_fail_plot = np.full_like(x_idx_fail, y_failed_band, dtype=float)
                fig.add_trace(go.Scattergl(
                    x=x_idx_fail + _jitter(x_idx_fail.size), y=y_fail_plot, mode="markers",
                    marker=dict(size=6, color="red", line=dict(color="black", width=0.8)),
                    name="data (failed)", legendgroup="data_f", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{base}: %{{text}}<br>"
                                   "status: failed (NaN target)<extra></extra>"),
                    text=_labels_for(x_idx_fail),
                    customdata=trial_ids_fail
                ), row=row_pos, col=1)

            # overlays for categorical base: map label to x index
            if optimal_df is not None and (base in optimal_df.columns):
                lab_opt = str(optimal_df[base].values[0])
                if lab_opt in label_to_idx:
                    x_opt = [float(label_to_idx[lab_opt])]
                    y_opt = optimal_df["pred_target_mean"].values
                    y_sd = optimal_df["pred_target_sd"].values
                    fig.add_trace(go.Scattergl(
                        x=x_opt, y=y_opt, mode="markers",
                        marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                        name="optimal", legendgroup="optimal", showlegend=show_legend,
                        hovertemplate=("predicted: %{y:.3g} ± %{customdata:.3g}<br>"
                                       f"{base}: {lab_opt}<extra></extra>"),
                        customdata=y_sd
                    ), row=row_pos, col=1)

            if suggest_df is not None and (base in suggest_df.columns):
                labs_sug = suggest_df[base].astype(str).tolist()
                keep_mask = [lab in label_to_idx for lab in labs_sug]
                xs = [label_to_idx[lab] for lab, keep in zip(labs_sug, keep_mask) if keep]
                if xs:
                    y_sug = suggest_df.loc[keep_mask, "pred_target_mean"].values
                    fig.add_trace(go.Scattergl(
                        x=np.array(xs, dtype=float), y=y_sug, mode="markers",
                        marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                        name="suggested", legendgroup="suggested", showlegend=show_legend,
                        hovertemplate=(f"{base}: %{{text}}<br>"
                                       "predicted: %{y:.3g}<extra>suggested</extra>"),
                        text=[labels[int(i)] for i in xs]
                    ), row=row_pos, col=1)

            # axes: categorical ticks
            fig.update_xaxes(
                tickmode="array",
                tickvals=x_pos.tolist(),
                ticktext=labels,
                row=row_pos, col=1
            )
            fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1)
            _set_yaxis_range(fig, row=row_pos, col=1,
                             y0=y0_plot, y1=y1_plot,
                             log=use_log_scale_for_target_y, eps=log_y_epsilon)
            fig.update_xaxes(title_text=base, row=row_pos, col=1)

            # tidy rows
            for lab, mu_i, sd_i, p_i in zip(labels, mu_vec, sd_vec, p_vec):
                tidy_rows.append({
                    "feature": base,
                    "x_display": str(lab),
                    "x_internal": float("nan"),
                    "target_conditional_mean": float(mu_i),
                    "target_conditional_sd": float(sd_i),
                    "success_probability": float(p_i),
                })

    # title w/ constraints summary
    def _fmt_c(v):
        if isinstance(v, slice):
            a = "" if v.start is None else f"{v.start:g}"
            b = "" if v.stop is None else f"{v.stop:g}"
            return f"[{a},{b}]"
        if isinstance(v, (list, tuple, np.ndarray)):
            try:
                return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]"
            except Exception:
                # non-numeric entries: fall back to their plain string form
                return "[" + ",".join(map(str, v)) + "]"
        try:
            return f"{float(v):g}"
        except Exception:
            return str(v)

    parts = [f"1D partial dependence of expected {tgt_col}"]
    if kw_num_raw:
        parts.append(", ".join(f"{k}={_fmt_c(v)}" for k, v in kw_num_raw.items()))
    if cat_fixed:
        parts.append(", ".join(f"{b}={lab}" for b, lab in cat_fixed.items()))
    if cat_allowed:
        parts.append(", ".join(f"{b}∈{{{', '.join(v)}}}" for b, v in cat_allowed.items()))
    title = " — ".join(parts) if len(parts) > 1 else parts[0]

    width = width if (width and width > 0) else 1200
    height = height if (height and height > 0) else 1200

    fig.update_layout(
        height=height,
        width=width,
        template="simple_white",
        title=title,
        legend_title_text=""
    )

    if output:
        write_image(fig, output)
    if csv_out:
        csv_out = Path(csv_out)
        csv_out.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(tidy_rows).to_csv(str(csv_out), index=False)
    if show:
        fig.show("browser")

    return fig
1582# =============================================================================
1583# Helpers: dataset → predictors & featurization
1584# =============================================================================
def _build_predictors(ds: xr.Dataset):
    """Rebuild the success-head and loss-head GP predictors stored in an artifact.

    The training matrices and MAP hyper-parameters are read from ``ds``, the two
    regularised Gram matrices are Cholesky-factorised once, and two closures
    that reuse those factorisations are returned.

    Returns:
        ``(predict_success_probability, predict_conditional_target)``.
        The first maps standardized inputs to a mean clipped into ``[0, 1]``;
        the second returns ``(mean, sd)`` of the target given success.
    """

    def _as_float_array(key: str) -> np.ndarray:
        return ds[key].values.astype(float)

    def _as_float(key: str) -> float:
        return float(ds[key].values)

    # Training inputs / targets (shape notes carried over from the original comments).
    X_train = _as_float_array("Xn_train")           # (N, p)
    success_obs = _as_float_array("y_success")      # (N,)
    X_success = _as_float_array("Xn_success_only")  # (Ns, p)
    loss_obs = _as_float_array("y_loss_centered")

    # The conditional mean may be stored as a variable or as an attribute.
    if "conditional_loss_mean" in ds:
        loss_offset = _as_float("conditional_loss_mean")
    else:
        loss_offset = float(ds.attrs.get("conditional_loss_mean", 0.0))

    # MAP parameters of the success head.
    ls_success = _as_float_array("map_success_ell")  # (p,)
    amp_success = _as_float("map_success_eta")
    noise_success = _as_float("map_success_sigma")
    bias_success = _as_float("map_success_beta0")

    # MAP parameters of the loss head.
    ls_loss = _as_float_array("map_loss_ell")  # (p,)
    amp_loss = _as_float("map_loss_eta")
    noise_loss = _as_float("map_loss_sigma")
    const_loss = _as_float("map_loss_mean_const")

    # Success head: factorise K + sigma^2 I once and cache the solve.
    gram_success = kernel_m52_ard(X_train, X_train, ls_success, amp_success)
    gram_success = gram_success + (noise_success**2) * np.eye(X_train.shape[0])
    chol_success = np.linalg.cholesky(add_jitter(gram_success))
    weights_success = solve_chol(chol_success, (success_obs - bias_success))

    # Loss head (fitted on successful rows only): same precomputation.
    gram_loss = kernel_m52_ard(X_success, X_success, ls_loss, amp_loss)
    gram_loss = gram_loss + (noise_loss**2) * np.eye(X_success.shape[0])
    chol_loss = np.linalg.cholesky(add_jitter(gram_loss))
    weights_loss = solve_chol(chol_loss, (loss_obs - const_loss))

    def predict_success_probability(Xn: np.ndarray) -> np.ndarray:
        cross = kernel_m52_ard(Xn, X_train, ls_success, amp_success)
        return np.clip(bias_success + cross @ weights_success, 0.0, 1.0)

    def predict_conditional_target(
        Xn: np.ndarray,
        include_observation_noise: bool = True
    ):
        cross = kernel_m52_ard(Xn, X_success, ls_loss, amp_loss)
        mu = (const_loss + cross @ weights_loss) + loss_offset

        # Diagonal of the predictive variance, floored to stay positive.
        half = solve_lower(chol_loss, cross.T)  # (Ns, Nt)
        var = np.maximum(
            kernel_diag_m52(Xn, ls_loss, amp_loss) - np.sum(half * half, axis=0),
            1e-12,
        )
        if include_observation_noise:
            var = var + noise_loss**2
        return mu, np.sqrt(var)

    return predict_success_probability, predict_conditional_target
def _raw_dataframe_from_dataset(ds: xr.Dataset) -> pd.DataFrame:
    """Gather every 1-D, ``row``-indexed data variable of ``ds`` into a DataFrame.

    A ``trial_id`` column (0..n_rows-1) is synthesised when the artifact does
    not carry one, so hover labels always have an identifier to show.
    """
    n_rows = ds.sizes["row"]
    columns = {}
    for var_name in ds.data_vars:
        variable = ds[var_name]
        # Keep only arrays whose single dimension is the row axis.
        if tuple(variable.dims) != ("row",):
            continue
        if variable.sizes["row"] != n_rows:
            continue
        columns[var_name] = variable.values
    columns.setdefault("trial_id", np.arange(n_rows, dtype=int))
    return pd.DataFrame(columns)
def _apply_fixed_to_base(
    base_std: np.ndarray,
    fixed: dict[str, float],
    feature_names: list[str],
    transforms: list[str],
    X_mean: np.ndarray,
    X_std: np.ndarray,
) -> np.ndarray:
    """Return a copy of ``base_std`` with the ``fixed`` features overridden.

    ``fixed`` holds values in ORIGINAL units; each one is passed through the
    feature's forward transform and then standardized with ``X_mean``/``X_std``.

    Raises:
        KeyError: If a key of ``fixed`` is not in ``feature_names``.
    """
    result = base_std.copy()
    column_of = {feat: pos for pos, feat in enumerate(feature_names)}
    for feat, original_value in fixed.items():
        col = column_of.get(feat)
        if col is None:
            raise KeyError(f"Fixed variable '{feat}' is not a model feature.")
        internal = _forward_transform(transforms[col], float(original_value))
        result[col] = (internal - X_mean[col]) / X_std[col]
    return result
def _denormalize_then_inverse_transform(j: int, x_std: np.ndarray, transforms, X_mean, X_std) -> np.ndarray:
    """Map standardized values of feature ``j`` back to original units.

    Undoes the standardization first, then the feature's transform.
    """
    return _inverse_transform(transforms[j], x_std * X_std[j] + X_mean[j])
1686def _forward_transform(tr: str, x: float | np.ndarray) -> np.ndarray:
1687 if tr == "log10":
1688 x = np.asarray(x, dtype=float)
1689 return np.log10(np.maximum(x, 1e-12))
1690 return np.asarray(x, dtype=float)
def _inverse_transform(tr: str, x: np.ndarray) -> np.ndarray:
    """Convert internal-scale values back to original units.

    ``"log10"`` exponentiates with base 10; any other transform name returns
    ``x`` itself, unchanged.
    """
    return 10.0 ** x if tr == "log10" else x
def _maybe_log_axis(fig: go.Figure, row: int, col: int, name: str, axis: str = "x", transforms: list[str] | None = None, j: int | None = None):
    """Switch one subplot axis to log scale when the feature calls for it.

    When both ``transforms`` and ``j`` are supplied, the decision is
    ``transforms[j] == "log10"``. Otherwise a name heuristic is used: the name
    contains ``learning_rate`` or equals ``lr`` (case-insensitive).
    ``axis == "x"`` targets the x axis; any other value targets the y axis.
    """
    if transforms is not None and j is not None:
        wants_log = transforms[j] == "log10"
    else:
        lowered = name.lower()
        wants_log = "learning_rate" in lowered or lowered == "lr"
    if not wants_log:
        return
    update_axis = fig.update_xaxes if axis == "x" else fig.update_yaxes
    update_axis(type="log", row=row, col=col)
def _rgb_string_to_tuple(s: str) -> tuple[int, int, int]:
    """Parse a CSS-style colour string into an ``(r, g, b)`` integer triple.

    Accepted forms:
      * ``"rgb(r,g,b)"`` / ``"rgba(r,g,b,a)"`` - channels may be written as
        floats and are truncated to int; the alpha channel is ignored.
      * ``"#rrggbb"`` / ``"#rrggbbaa"`` and the short forms ``"#rgb"`` /
        ``"#rgba"``. These previously always raised, which made callers with a
        fallback silently substitute a different colour.

    Raises:
        ValueError: If the string matches none of the forms above (for example
            a named colour such as ``"red"``).
    """
    text = s.strip()
    if text.startswith("#"):
        digits = text[1:]
        if len(digits) in (3, 4):
            # Short form: every nibble is doubled ("#abc" -> "#aabbcc").
            digits = "".join(ch * 2 for ch in digits)
        valid = len(digits) in (6, 8) and all(ch in "0123456789abcdefABCDEF" for ch in digits)
        if not valid:
            raise ValueError(f"Unrecognised hex colour: {s!r}")
        r, g, b = (int(digits[k:k + 2], 16) for k in (0, 2, 4))
        return r, g, b

    # Functional notation: take whatever sits between the parentheses.
    vals = s[s.find("(") + 1 : s.find(")")].split(",")
    r, g, b = [int(float(v)) for v in vals[:3]]
    return r, g, b
def _rgb_to_rgba(rgb: str, alpha: float) -> str:
    """Return ``rgb`` re-expressed as an ``rgba(r,g,b,a)`` string.

    ``alpha`` is formatted with three decimals. If the colour string cannot be
    parsed, the channels fall back to ``(31, 119, 180)`` rather than raising.
    """
    # Input is expected to look like "rgb(r,g,b)" or "rgba(r,g,b,a)".
    try:
        red, green, blue = _rgb_string_to_tuple(rgb)
    except Exception:
        red, green, blue = 31, 119, 180
    return f"rgba({red},{green},{blue},{alpha:.3f})"
def _add_low_success_shading_1d(fig: go.Figure, row_idx: int, x_vals: np.ndarray, p: np.ndarray, y0: float, y1: float):
    """Shade the x-intervals of one subplot where the success probability is low.

    Grey rectangles spanning ``[y0, y1]`` are added below the traces for every
    contiguous run of grid points with ``p < 0.5`` and again for every run with
    ``p < 0.8``, so the two layers stack where both apply. Axis references are
    derived from ``row_idx`` alone.
    """
    suffix = "" if row_idx == 1 else str(row_idx)
    x_axis_ref = f"x{suffix}"
    y_axis_ref = f"y{suffix}"

    def _runs(mask: np.ndarray):
        # Pad with zeros so runs touching either end still yield a start/end edge.
        edges = np.diff(np.concatenate([[0], mask.astype(int), [0]]))
        first = np.where(edges == 1)[0]
        last = np.where(edges == -1)[0] - 1
        return [(x_vals[a], x_vals[b]) for a, b in zip(first, last)]

    layers = (
        (0.5, "rgba(128,128,128,0.25)"),
        (0.8, "rgba(128,128,128,0.40)"),
    )
    for threshold, fill in layers:
        for left, right in _runs(p < threshold):
            fig.add_shape(type="rect", x0=left, x1=right, y0=y0, y1=y1,
                          xref=x_axis_ref, yref=y_axis_ref,
                          line=dict(width=0), fillcolor=fill, layer="below")
def _set_yaxis_range(fig, *, row: int, col: int, y0: float, y1: float, log: bool, eps: float = 1e-12):
    """Set the Y range of one subplot to ``[y0, y1]``.

    On a log axis the bounds are first made valid (``y0`` floored at ``eps``,
    ``y1`` forced strictly above ``y0``) and then passed as log10 values.
    On a linear axis the type is set to ``"-"`` and the bounds are used as given.
    """
    if not log:
        fig.update_yaxes(type="-", range=[y0, y1], row=row, col=col)
        return
    lower = max(y0, eps)
    upper = max(y1, lower * (1.0 + 1e-6))
    fig.update_yaxes(type="log", range=[np.log10(lower), np.log10(upper)], row=row, col=col)
def optimum_plot1d(
    model: xr.Dataset | Path | str,
    output: Path | None = None,
    csv_out: Path | None = None,
    grid_size: int = 300,
    line_color: str = "rgb(31,119,180)",
    band_alpha: float = 0.25,
    show: bool = False,
    use_log_scale_for_target_y: bool = True,
    log_y_epsilon: float = 1e-9,
    optimal: bool = True,
    suggest: int = 0,  # optional overlay
    width: int | None = None,
    height: int | None = None,
    seed: int | None = 42,
    **kwargs,  # constraints in ORIGINAL units (as in your plot1d)
) -> go.Figure:
    """
    1D partial-dependence panels anchored at the *optimal* hyperparameter setting.

    - Compute x* from ``opt.optimal(...)``.
    - For each free feature, sweep that feature; keep all *other* features fixed at x*.

    Supports numeric constraints (scalars / slices / choices) and categorical bases.

    Args:
        model: An ``xr.Dataset`` artifact, or a path handed to ``xr.load_dataset``.
        output: If given, the figure is passed to ``write_image``.
        csv_out: If given, the tidy per-grid-point table is written there as CSV
            (parent directories are created).
        grid_size: Number of grid points for each numeric sweep.
        line_color: Colour of the mean line; the ±2σ band reuses it with ``band_alpha``.
        band_alpha: Opacity of the ±2σ band.
        show: If true, call ``fig.show("browser")``.
        use_log_scale_for_target_y: Plot the target on a log axis; plotted values
            are floored at ``log_y_epsilon``.
        log_y_epsilon: Floor used for log-scale plotting.
        optimal: Overlay the optimal point on each panel.
        suggest: Number of points requested from ``opt.suggest`` for an overlay
            (0 disables it).
        width: Figure width; falsy or non-positive values fall back to 1200.
        height: Figure height; falsy or non-positive values fall back to 1200.
        seed: Forwarded to ``opt.optimal`` and ``opt.suggest``.
        **kwargs: Constraints in ORIGINAL units, keyed by feature name (also
            matched after lower-casing and stripping non-alphanumerics) or by
            categorical base name. Numeric values may be a scalar (feature is
            fixed and gets no panel), a ``slice`` (range; ``None`` on either
            side means unbounded), or a list/tuple/set/array of choices.
            Categorical values are a label (fixed) or a list/tuple/set of
            labels (allowed subset). Keys that match nothing are ignored here
            but are still forwarded to ``opt.optimal`` / ``opt.suggest``.

    Returns:
        The assembled figure.

    Raises:
        ValueError: On an unknown category or malformed categorical constraint,
            when no experiment rows satisfy the constraints, or when every
            feature is fixed so there is nothing to sweep.
    """
    import re  # local import, same pattern as _canon_key_set

    ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
    pred_success, pred_loss = _build_predictors(ds)

    # --- metadata ---
    feature_names = [str(n) for n in ds["feature"].values.tolist()]
    transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
    X_mean = ds["feature_mean"].values.astype(float)
    X_std = ds["feature_std"].values.astype(float)

    df_raw = _raw_dataframe_from_dataset(ds)
    Xn_train = ds["Xn_train"].values.astype(float)
    n_rows, p = Xn_train.shape

    # --- one-hot categorical groups ---
    groups = opt._onehot_groups(feature_names)
    bases = set(groups.keys())
    name_to_idx = {name: j for j, name in enumerate(feature_names)}

    # --- canonicalize kwargs: numeric vs categorical (base) ---
    idx_map = _canon_key_set(ds)  # maps exact and normalized names -> exact feature name
    kw_num_raw: dict[str, object] = {}
    kw_cat_raw: dict[str, object] = {}
    for k, v in kwargs.items():
        if k in bases:
            kw_cat_raw[k] = v
        elif k in idx_map:
            kw_num_raw[idx_map[k]] = v
        else:
            nk = re.sub(r"[^a-z0-9]+", "", str(k).lower())
            if nk in idx_map:
                kw_num_raw[idx_map[nk]] = v

    # --- resolve categorical constraints: fixed vs allowed subset ---
    cat_fixed: dict[str, str] = {}
    cat_allowed: dict[str, list[str]] = {}
    for base, val in kw_cat_raw.items():
        labels = groups[base]["labels"]
        if isinstance(val, str):
            if val not in labels:
                raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}")
            cat_fixed[base] = val
        elif isinstance(val, (list, tuple, set)):
            chosen = [x for x in val if isinstance(x, str) and x in labels]
            if not chosen:
                raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}")
            cat_allowed[base] = list(dict.fromkeys(chosen))
        else:
            raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.")

    # --- row mask for experimental points (matches constraints) ---
    row_mask = np.ones(n_rows, dtype=bool)
    for base, label in cat_fixed.items():
        if base in df_raw.columns:
            series = df_raw[base].astype("string")
            row_mask &= series.eq(label).fillna(False).to_numpy()
        else:
            member_name = groups[base]["name_by_label"][label]
            j = name_to_idx[member_name]
            raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float)
            row_mask &= (raw_j >= 0.5)

    for base, allowed in cat_allowed.items():
        if base in cat_fixed:
            continue  # already fixed
        allowed_labels = [str(x) for x in allowed]
        if base in df_raw.columns:
            series = df_raw[base].astype("string").fillna("<NA>")
            allowed_mask = series.isin(set(allowed_labels)).fillna(False).to_numpy()
            row_mask &= allowed_mask
        else:
            allowed_masks: list[np.ndarray] = []
            for label in allowed_labels:
                member_name = groups[base]["name_by_label"].get(label)
                if member_name is None:
                    continue
                j = name_to_idx[member_name]
                raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float)
                allowed_masks.append(raw_j >= 0.5)
            if allowed_masks:
                row_mask &= np.logical_or.reduce(allowed_masks)
            else:
                row_mask &= False

    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if name in df_raw.columns:
            raw_vals = pd.to_numeric(df_raw[name], errors="coerce").to_numpy(dtype=float)
        else:
            raw_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float)

        mask = np.isfinite(raw_vals)
        if isinstance(val, slice):
            lo_raw = -np.inf if val.start is None else float(val.start)
            hi_raw = np.inf if val.stop is None else float(val.stop)
            if hi_raw < lo_raw:
                lo_raw, hi_raw = hi_raw, lo_raw
            mask &= (raw_vals >= lo_raw) & (raw_vals <= hi_raw)
        elif isinstance(val, (list, tuple, set, np.ndarray)):
            arr = np.asarray(list(val) if not isinstance(val, np.ndarray) else val, dtype=float)
            arr = arr[np.isfinite(arr)]
            if arr.size == 0:
                mask &= False
            else:
                mask &= np.any(np.isclose(raw_vals[:, None], arr[None, :], rtol=1e-6, atol=1e-9), axis=1)
        else:
            target = float(val)
            mask &= np.isclose(raw_vals, target, rtol=1e-6, atol=1e-9)

        row_mask &= mask

    if not np.any(row_mask):
        raise ValueError("No experiments match the provided constraints; cannot plot data points.")

    df_raw_f = df_raw.loc[row_mask].reset_index(drop=True)

    # ---------- 1) Find the *optimal* base point (original units) ----------
    # This row is used both as the anchor point and for the "optimal" overlay.
    opt_df = opt.optimal(model, count=1, seed=seed, **kwargs)
    x_opt_std = np.zeros(p, dtype=float)

    def _to_std_single(j: int, x_orig: float) -> float:
        """Original units -> internal scale -> standardized, for one scalar."""
        xi = x_orig
        if transforms[j] == "log10":
            xi = np.log10(np.maximum(x_orig, 1e-300))
        return float((xi - X_mean[j]) / X_std[j])

    # Names of all one-hot member columns (handled per categorical base below).
    onehot_members: set[str] = set()
    for g in groups.values():
        onehot_members.update(g["members"])

    # Numeric features (skip one-hot members).
    for j, nm in enumerate(feature_names):
        if nm in onehot_members:
            continue
        if nm in opt_df.columns:
            x_opt_std[j] = _to_std_single(j, float(opt_df.iloc[0][nm]))
        else:
            # Fall back to the training median if the optimum lacks this column.
            x_opt_std[j] = float(np.median(Xn_train[:, j]))

    # Categorical bases: set the one-hot block to the chosen label.
    for base, g in groups.items():
        # Priority: fixed in kwargs -> optimal row -> most frequent label -> first label.
        if base in cat_fixed:
            label = cat_fixed[base]
        elif base in opt_df.columns:
            label = str(opt_df.iloc[0][base])
        elif base in df_raw.columns:
            label = str(df_raw[base].astype("string").mode(dropna=True).iloc[0])
        else:
            label = g["labels"][0]

        for lab in g["labels"]:
            member_name = g["name_by_label"][lab]
            j2 = name_to_idx[member_name]
            raw = 1.0 if lab == label else 0.0
            # raw (0/1) -> standardized using the artifact stats
            x_opt_std[j2] = (raw - X_mean[j2]) / X_std[j2]

    # ---------- 2) Numeric constraints in STANDARDIZED space ----------
    def _orig_to_std(j: int, x, transforms, mu, sd):
        x = np.asarray(x, dtype=float)
        if transforms[j] == "log10":
            x = np.where(x <= 0, np.nan, x)
            x = np.log10(x)
        return (x - mu[j]) / sd[j]

    def _bound_to_std(j: int, bound_raw: float, open_value: float) -> float:
        """Standardize one slice bound; unbounded or untransformable -> ``open_value``."""
        if not np.isfinite(bound_raw):
            return open_value
        b = float(_orig_to_std(j, bound_raw, transforms, X_mean, X_std))
        # NaN arises for a non-positive bound on a log10 feature: treat as open.
        return b if np.isfinite(b) else open_value

    def _std_to_orig(j: int, val_std: float) -> float:
        vi = val_std * X_std[j] + X_mean[j]
        return float(_inverse_transform(transforms[j], np.array([vi]))[0])

    fixed_scalars: dict[int, float] = {}
    range_windows: dict[int, tuple[float, float]] = {}
    choice_values: dict[int, np.ndarray] = {}

    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if isinstance(val, slice):
            # Mirror the row-mask handling: None means unbounded, reversed bounds are swapped.
            lo_raw = -np.inf if val.start is None else float(val.start)
            hi_raw = np.inf if val.stop is None else float(val.stop)
            if hi_raw < lo_raw:
                lo_raw, hi_raw = hi_raw, lo_raw
            range_windows[j] = (
                _bound_to_std(j, lo_raw, -np.inf),
                _bound_to_std(j, hi_raw, np.inf),
            )
        elif isinstance(val, (list, tuple, set, np.ndarray)):
            # `set` is accepted here too, matching the row-mask branch above.
            raw_choices = np.asarray(list(val) if not isinstance(val, np.ndarray) else val, dtype=float)
            arr = _orig_to_std(j, raw_choices, transforms, X_mean, X_std)
            choice_values[j] = np.asarray(arr, dtype=float)
        else:
            fixed_scalars[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std))

    # Apply numeric fixed overrides on the base point.
    for j, vstd in fixed_scalars.items():
        x_opt_std[j] = vstd

    # ---------- 3) Panels: sweep ONE var at a time around x* ----------
    # Numeric free = not a one-hot member and not fixed via kwargs.
    free_numeric_idx = [
        j for j, nm in enumerate(feature_names)
        if (nm not in onehot_members) and (j not in fixed_scalars)
    ]
    # Categorical bases: sweep if not fixed. Iterate `groups` (not the `bases`
    # set) so that the panel order is the same on every run.
    free_cat_bases = [b for b in groups if b not in cat_fixed]

    panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases]
    if not panels:
        raise ValueError("All features are fixed at the optimum (or categoricals fixed); nothing to plot.")

    # Empirical 1-99% per feature (default sweep range, standardized space).
    Xn_p01 = np.percentile(Xn_train, 1, axis=0)
    Xn_p99 = np.percentile(Xn_train, 99, axis=0)

    def _grid_1d(j: int, n: int) -> np.ndarray:
        lo, hi = float(Xn_p01[j]), float(Xn_p99[j])
        if j in range_windows:
            lo = max(lo, range_windows[j][0])
            hi = min(hi, range_windows[j][1])
        if j in choice_values:
            vals = np.asarray(choice_values[j], dtype=float)
            vals = vals[(vals >= lo) & (vals <= hi)]
            return np.unique(np.sort(vals)) if vals.size else np.array([x_opt_std[j]], dtype=float)
        if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
            lo, hi = x_opt_std[j] - 1.0, x_opt_std[j] + 1.0
        return np.linspace(lo, hi, n)

    # Figure scaffold.
    subplot_titles = [feature_names[int(k)] if t == "num" else str(k) for t, k in panels]
    fig = make_subplots(rows=len(panels), cols=1, shared_xaxes=False, subplot_titles=subplot_titles)

    tgt_col = str(ds.attrs["target"])
    success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy()
    fail_mask = ~success_mask
    losses_success = df_raw_f.loc[success_mask, tgt_col].to_numpy().astype(float)
    trial_ids = df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()
    trial_ids_success = trial_ids[success_mask]
    trial_ids_fail = trial_ids[fail_mask]
    band_fill_rgba = _rgb_to_rgba(line_color, band_alpha)

    # Observed targets as plotted (floored on a log axis); identical for every panel.
    if use_log_scale_for_target_y and losses_success.size:
        losses_s_plot = np.maximum(losses_success, log_y_epsilon)
    else:
        losses_s_plot = losses_success

    def _plot_band(mu: np.ndarray, sd: np.ndarray):
        """Return (mean, lower, upper) of the ±2σ band as plotted."""
        lo = mu - 2.0 * sd
        hi = mu + 2.0 * sd
        if use_log_scale_for_target_y:
            return (
                np.maximum(mu, log_y_epsilon),
                np.maximum(lo, log_y_epsilon),
                np.maximum(hi, log_y_epsilon),
            )
        return mu, lo, hi

    def _y_extents(lo_plot: np.ndarray, hi_plot: np.ndarray):
        """Return (y0, y_failed_band, y1): axis bottom, row for failed trials, axis top."""
        y_arrays = [lo_plot, hi_plot] + ([losses_s_plot] if losses_s_plot.size else [])
        y_low = float(np.nanmin([np.nanmin(a) for a in y_arrays]))
        y_high = float(np.nanmax([np.nanmax(a) for a in y_arrays]))
        spread = y_high - y_low + 1e-12
        if use_log_scale_for_target_y:
            y0 = max(y_low / 1.5, log_y_epsilon)
            y_band = y_high * 1.2 + spread * 0.3
            if y_band <= log_y_epsilon:
                y_band = max(10.0 * log_y_epsilon, y_high * 2.0)
            y1 = y_band + 0.05 * spread
        else:
            pad = 0.05 * spread
            y0 = y_low - pad
            y_band = (y_high + pad) + spread * 0.08
            y1 = y_band + 0.02 * spread
        return y0, y_band, y1

    def _set_x_range(row: int, lo: float, hi: float, pad_frac: float, is_log: bool) -> None:
        """Set a padded x range; Plotly wants log10 units on a log axis."""
        if is_log:
            lo = float(np.log10(max(lo, 1e-12)))
            hi = float(np.log10(max(hi, 1e-12)))
        span = (hi - lo) or 1.0
        pad = pad_frac * span
        fig.update_xaxes(range=[lo - pad, hi + pad], row=row, col=1)

    # Optional overlay of suggested points.
    suggest_df = opt.suggest(model, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None

    tidy_rows: list[dict] = []
    for row_pos, (kind, key) in enumerate(panels, start=1):
        show_legend = (row_pos == 1)

        if kind == "num":
            j = int(key)
            fname = feature_names[j]
            is_log_x = (transforms[j] == "log10")

            if fname in df_raw.columns:
                x_full_raw = pd.to_numeric(df_raw[fname], errors="coerce").to_numpy(dtype=float)
            else:
                x_full_raw = feature_raw_from_artifact_or_reconstruct(
                    ds, j, fname, transforms[j]
                ).astype(float)
            x_data_all = x_full_raw[row_mask]
            finite_raw = x_full_raw[np.isfinite(x_full_raw)]
            if is_log_x:
                finite_raw = finite_raw[finite_raw > 0]

            grid = _grid_1d(j, grid_size)
            if (j not in range_windows) and (j not in choice_values) and finite_raw.size:
                # Unconstrained sweep: widen the grid so that it covers all observed data.
                finite_std = _orig_to_std(j, finite_raw, transforms, X_mean, X_std)
                merged = np.concatenate([grid, finite_std])
                grid_min = float(np.nanmin(merged))
                grid_max = float(np.nanmax(merged))
                if grid_max > grid_min:
                    grid = np.linspace(grid_min, grid_max, grid_size)
                else:
                    grid = np.array([grid_min], dtype=float)
            Xn_grid = np.repeat(x_opt_std[None, :], len(grid), axis=0)
            Xn_grid[:, j] = grid

            p_grid = pred_success(Xn_grid)
            mu_grid, sd_grid = pred_loss(Xn_grid, include_observation_noise=True)

            x_internal = grid * X_std[j] + X_mean[j]
            x_display = _inverse_transform(transforms[j], x_internal)

            mu_plot, lo_plot, hi_plot = _plot_band(mu_grid, sd_grid)
            y0_plot, y_failed_band, y1_plot = _y_extents(lo_plot, hi_plot)

            _add_low_success_shading_1d(fig, row_pos, x_display, p_grid, y0_plot, y1_plot)

            # ±2σ band
            fig.add_trace(go.Scatter(x=x_display, y=lo_plot, mode="lines",
                                     line=dict(width=0, color=line_color),
                                     name="±2σ", legendgroup="band", showlegend=False, hoverinfo="skip"),
                          row=row_pos, col=1)
            fig.add_trace(go.Scatter(x=x_display, y=hi_plot, mode="lines", fill="tonexty",
                                     line=dict(width=0, color=line_color), fillcolor=band_fill_rgba,
                                     name="±2σ", legendgroup="band", showlegend=show_legend,
                                     hovertemplate="E[target|success]: %{y:.3f}<extra>±2σ</extra>"),
                          row=row_pos, col=1)
            # mean
            fig.add_trace(go.Scatter(x=x_display, y=mu_plot, mode="lines",
                                     line=dict(width=2, color=line_color),
                                     name="E[target|success]", legendgroup="mean", showlegend=show_legend,
                                     hovertemplate=f"{fname}: %{{x:.6g}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"),
                          row=row_pos, col=1)

            # Experimental points (only rows that satisfy the constraints).
            x_succ = x_data_all[success_mask]
            if x_succ.size:
                fig.add_trace(go.Scattergl(
                    x=x_succ, y=losses_s_plot, mode="markers",
                    marker=dict(size=5, color="black", line=dict(width=0)),
                    name="data (success)", legendgroup="data_s", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{fname}: %{{x:.6g}}<br>"
                                   f"{tgt_col}: %{{y:.4f}}<extra></extra>"),
                    customdata=trial_ids_success
                ), row=row_pos, col=1)

            x_fail = x_data_all[fail_mask]
            if x_fail.size:
                # Failed trials have no target; draw them on a dedicated row near the top.
                y_fail_plot = np.full_like(x_fail, y_failed_band, dtype=float)
                fig.add_trace(go.Scattergl(
                    x=x_fail, y=y_fail_plot, mode="markers",
                    marker=dict(size=6, color="red", line=dict(color="black", width=0.8)),
                    name="data (failed)", legendgroup="data_f", showlegend=show_legend,
                    hovertemplate=("trial_id: %{customdata}<br>"
                                   f"{fname}: %{{x:.6g}}<br>"
                                   "status: failed (NaN target)<extra></extra>"),
                    customdata=trial_ids_fail
                ), row=row_pos, col=1)

            # Overlays: optimal (single point) and suggested (optional, many).
            if optimal and fname in opt_df.columns:
                x_opt_disp = float(opt_df.iloc[0][fname])
                y_opt = float(opt_df.iloc[0]["pred_target_mean"])
                y_opt_sd = float(opt_df.iloc[0].get("pred_target_sd", np.nan))
                fig.add_trace(go.Scattergl(
                    x=[x_opt_disp], y=[y_opt], mode="markers",
                    marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                    name="optimal", legendgroup="optimal", showlegend=show_legend,
                    hovertemplate=(f"predicted: %{{y:.3g}}"
                                   + ("" if np.isnan(y_opt_sd) else f" ± {y_opt_sd:.3g}")
                                   + f"<br>{fname}: %{{x:.6g}}<extra></extra>")
                ), row=row_pos, col=1)

            if suggest and (suggest_df is not None) and (fname in suggest_df.columns):
                xs = suggest_df[fname].values.astype(float)
                ys = suggest_df["pred_target_mean"].values.astype(float)
                ysd = suggest_df.get("pred_target_sd", pd.Series([np.nan] * len(suggest_df))).values
                fig.add_trace(go.Scattergl(
                    x=xs, y=ys, mode="markers",
                    marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                    name="suggested", legendgroup="suggested", showlegend=show_legend,
                    hovertemplate=("predicted: %{y:.3g}"
                                   + (" ± %{customdata:.3g}" if not np.isnan(ysd).all() else "")
                                   + f"<br>{fname}: %{{x:.6g}}<extra>suggested</extra>"),
                    customdata=ysd
                ), row=row_pos, col=1)

            # Axes + ranges.
            _maybe_log_axis(fig, row_pos, 1, fname, axis="x", transforms=transforms, j=j)
            fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1)
            _set_yaxis_range(fig, row=row_pos, col=1,
                             y0=y0_plot, y1=y1_plot,
                             log=use_log_scale_for_target_y, eps=log_y_epsilon)
            fig.update_xaxes(title_text=fname, row=row_pos, col=1)

            # If a constraint limited the sweep, respect it on the displayed axis.
            if j in range_windows:
                lo_std, hi_std = range_windows[j]
                # An open-ended side has no bound to show: use the swept grid's edge.
                if not np.isfinite(lo_std):
                    lo_std = float(np.min(grid))
                if not np.isfinite(hi_std):
                    hi_std = float(np.max(grid))
                edge_a, edge_b = _std_to_orig(j, lo_std), _std_to_orig(j, hi_std)
                _set_x_range(row_pos, min(edge_a, edge_b), max(edge_a, edge_b), 0.02, is_log_x)
            elif j in choice_values and choice_values[j].size:
                ints = choice_values[j] * X_std[j] + X_mean[j]
                origs = np.asarray(_inverse_transform(transforms[j], ints), dtype=float)
                origs = origs[np.isfinite(origs)]  # drop choices that could not be transformed
                if origs.size:
                    _set_x_range(row_pos, float(np.min(origs)), float(np.max(origs)), 0.05, is_log_x)
            elif finite_raw.size:
                if is_log_x:
                    x0 = max(float(np.min(finite_raw)), 1e-12)
                    x1 = max(float(np.max(finite_raw)), x0 * (1 + 1e-9))
                    pad = (x1 / x0) ** 0.03
                    fig.update_xaxes(
                        range=[np.log10(x0 / pad), np.log10(x1 * pad)],
                        row=row_pos,
                        col=1,
                    )
                else:
                    x0 = float(np.min(finite_raw))
                    x1 = float(np.max(finite_raw))
                    span = (x1 - x0) or 1.0
                    pad = 0.02 * span
                    fig.update_xaxes(range=[x0 - pad, x1 + pad], row=row_pos, col=1)

            # Tidy rows.
            for xd, xi, mu_i, sd_i, p_i in zip(x_display, x_internal, mu_grid, sd_grid, p_grid):
                tidy_rows.append({
                    "feature": fname,
                    "x_display": float(xd),
                    "x_internal": float(xi),
                    "target_conditional_mean": float(mu_i),
                    "target_conditional_sd": float(sd_i),
                    "success_probability": float(p_i),
                })

        else:
            base = str(key)
            labels_all = groups[base]["labels"]
            labels = list(cat_allowed.get(base, labels_all))

            # Evaluate each label with numerics and other bases fixed at x_opt_std.
            Xn_grid = np.repeat(x_opt_std[None, :], len(labels), axis=0)
            for r, lab in enumerate(labels):
                for lab2 in labels_all:
                    member_name = groups[base]["name_by_label"][lab2]
                    j2 = name_to_idx[member_name]
                    raw_val = 1.0 if (lab2 == lab) else 0.0
                    Xn_grid[r, j2] = (raw_val - X_mean[j2]) / X_std[j2]

            p_vec = pred_success(Xn_grid)
            mu_vec, sd_vec = pred_loss(Xn_grid, include_observation_noise=True)

            mu_plot, lo_plot, hi_plot = _plot_band(mu_vec, sd_vec)
            y0_plot, y_failed_band, y1_plot = _y_extents(lo_plot, hi_plot)

            # x = 0..K-1 with tick labels
            x_pos = np.arange(len(labels), dtype=float)

            # Grey out labels with low success probability (layers stack below 0.5).
            axis_suffix = "" if row_pos == 1 else row_pos
            for thr, alpha in ((0.8, 0.40), (0.5, 0.25)):
                for k_i, p_i in enumerate(p_vec):
                    if p_i < thr:
                        fig.add_shape(
                            type="rect",
                            xref=f"x{axis_suffix}",
                            yref=f"y{axis_suffix}",
                            x0=k_i - 0.5, x1=k_i + 0.5,
                            y0=y0_plot, y1=y1_plot,
                            line=dict(width=0),
                            fillcolor=f"rgba(128,128,128,{alpha})",
                            layer="below",
                            row=row_pos, col=1
                        )

            fig.add_trace(go.Scatter(
                x=x_pos, y=mu_plot, mode="lines+markers",
                line=dict(width=2, color=line_color),
                marker=dict(size=7, color=line_color),
                error_y=dict(type="data", array=(hi_plot - mu_plot), arrayminus=(mu_plot - lo_plot), visible=True),
                name="E[target|success]", legendgroup="mean", showlegend=show_legend,
                hovertemplate=(f"{base}: %{{text}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"),
                text=labels
            ), row=row_pos, col=1)

            # Overlay the optimal point for this base (single label at x*).
            if optimal and base in opt_df.columns:
                lab_opt = str(opt_df.iloc[0][base])
                if lab_opt in labels:
                    xi = float(labels.index(lab_opt))
                    y_opt = float(opt_df.iloc[0]["pred_target_mean"])
                    y_opt_sd = float(opt_df.iloc[0].get("pred_target_sd", np.nan))
                    fig.add_trace(go.Scattergl(
                        x=[xi], y=[y_opt], mode="markers",
                        marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                        name="optimal", legendgroup="optimal", showlegend=show_legend,
                        hovertemplate=(f"predicted: %{{y:.3g}}"
                                       + ("" if np.isnan(y_opt_sd) else f" ± {y_opt_sd:.3g}")
                                       + f"<br>{base}: {lab_opt}<extra></extra>")
                    ), row=row_pos, col=1)

            # Overlay suggestions (optional).
            if suggest and (suggest_df is not None) and (base in suggest_df.columns):
                labs_sug = suggest_df[base].astype(str).tolist()
                keep_mask = [s_lab in labels for s_lab in labs_sug]
                xs = [labels.index(s_lab) for s_lab, keep in zip(labs_sug, keep_mask) if keep]
                if xs:
                    y_sug = suggest_df.loc[keep_mask, "pred_target_mean"].values
                    fig.add_trace(go.Scattergl(
                        x=np.array(xs, dtype=float), y=y_sug, mode="markers",
                        marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                        name="suggested", legendgroup="suggested", showlegend=show_legend,
                        # The second literal is a plain string, so braces are single here.
                        hovertemplate=(f"{base}: %{{text}}<br>"
                                       "predicted: %{y:.3g}<extra>suggested</extra>"),
                        text=[labels[int(i)] for i in xs]
                    ), row=row_pos, col=1)

            fig.update_xaxes(
                tickmode="array",
                tickvals=x_pos.tolist(),
                ticktext=labels,
                title_text=base,
                row=row_pos, col=1
            )
            fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1)
            _set_yaxis_range(fig, row=row_pos, col=1,
                             y0=y0_plot, y1=y1_plot,
                             log=use_log_scale_for_target_y, eps=log_y_epsilon)

            # Tidy rows.
            for lab, mu_i, sd_i, p_i in zip(labels, mu_vec, sd_vec, p_vec):
                tidy_rows.append({
                    "feature": base,
                    "x_display": str(lab),
                    "x_internal": float("nan"),
                    "target_conditional_mean": float(mu_i),
                    "target_conditional_sd": float(sd_i),
                    "success_probability": float(p_i),
                })

    # ---- layout & IO ----
    def _fmt_c(v):
        """Best-effort compact rendering of one numeric constraint for the title."""
        if isinstance(v, slice):
            a = "" if v.start is None else f"{v.start:g}"
            b = "" if v.stop is None else f"{v.stop:g}"
            return f"[{a},{b}]"
        if isinstance(v, (list, tuple, np.ndarray)):
            try:
                return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]"
            except Exception:
                return "[" + ",".join(map(str, v)) + "]"
        try:
            return f"{float(v):g}"
        except Exception:
            return str(v)

    parts = [f"1D PD at optimal setting of all other hyperparameters ({ds.attrs.get('target', 'target')})"]
    if kw_num_raw:
        parts.append(", ".join(f"{k}={_fmt_c(v)}" for k, v in kw_num_raw.items()))
    if cat_fixed:
        parts.append(", ".join(f"{b}={lab}" for b, lab in cat_fixed.items()))
    if cat_allowed:
        # Allowed-subset constraints are also part of what the plot shows.
        parts.append(", ".join(b + "∈{" + ", ".join(v) + "}" for b, v in cat_allowed.items()))
    title = " — ".join(parts)

    width = width if (width and width > 0) else 1200
    height = height if (height and height > 0) else 1200
    fig.update_layout(height=height, width=width, template="simple_white", title=title, legend_title_text="")

    if output:
        write_image(fig, output)
    if csv_out:
        csv_out = Path(csv_out)
        csv_out.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(tidy_rows).to_csv(str(csv_out), index=False)
    if show:
        fig.show("browser")
    return fig
def optimum_plot2d(
    model: xr.Dataset | Path | str,
    output: Path | None = None,
    grid_size: int = 70,
    use_log_scale_for_target: bool = False,
    log_shift_epsilon: float = 1e-9,
    colorscale: str = "RdBu",
    show: bool = False,
    n_contours: int = 12,
    optimal: bool = True,
    suggest: int = 0,
    width: int | None = None,
    height: int | None = None,
    seed: int | None = 42,
    **kwargs,
) -> go.Figure:
    """2D PD panels anchored at the optimal hyperparameter setting.

    Builds a k x k grid of heatmaps (k = number of non-fixed, non-one-hot
    features). Each off-diagonal panel varies two features over a grid while
    every other feature is held at the point returned by ``opt.optimal``.
    Diagonal panels are synthesised from the 1D sweep of that feature
    (pairwise average of the mean, pairwise minimum of the success
    probability).

    Args:
        model: Fitted artifact as an ``xr.Dataset`` or a path loadable with
            ``xr.load_dataset``.
        output: Optional file to write via ``write_image``.
        grid_size: Number of points per axis for ``np.linspace`` grids.
        use_log_scale_for_target: Colour the heatmaps by log10 of the
            (shifted) predicted mean instead of the raw mean.
        log_shift_epsilon: Floor/offset used to keep the log10 argument
            strictly positive.
        colorscale: Plotly colorscale name for the shared coloraxis.
        show: Open the figure with ``fig.show("browser")``.
        n_contours: Number of contour levels per panel (at least 2 are drawn).
        optimal: Overlay the optimum marker.
        suggest: Number of suggestions to request from ``opt.suggest`` and
            overlay; ``0`` disables the overlay.
        width: Figure width in pixels; falls back to 1100 when falsy or <= 0.
        height: Figure height in pixels; falls back to 1100 when falsy or <= 0.
        seed: Forwarded to ``opt.optimal`` and ``opt.suggest``.
        **kwargs: Constraints. A key equal to a one-hot base name fixes that
            categorical (a single label is required). A key matching a
            feature name (exactly, or after lower-casing and stripping
            non-alphanumerics) constrains that feature with a scalar
            (fixed value), a ``slice`` (range; ``None`` bounds are open), or
            a list/tuple/set/ndarray (discrete choices). All kwargs are also
            forwarded unchanged to ``opt.optimal`` / ``opt.suggest``.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: Unknown or multiple categories for a categorical base,
            no training rows satisfy the constraints, or every plottable
            feature has been fixed to a scalar.
    """
    import re as _re  # local import: only the kwarg-name normalisation needs it

    ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
    pred_success, pred_loss = _build_predictors(ds)

    feature_names = [str(n) for n in ds["feature"].values.tolist()]
    transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
    X_mean = ds["feature_mean"].values.astype(float)
    X_std = ds["feature_std"].values.astype(float)
    name_to_idx = {name: j for j, name in enumerate(feature_names)}

    df_raw = _raw_dataframe_from_dataset(ds)
    Xn_train = ds["Xn_train"].values.astype(float)
    n_rows = Xn_train.shape[0]

    groups = opt._onehot_groups(feature_names)
    bases = set(groups.keys())

    # ---- split kwargs into categorical vs numeric constraints ----
    idx_map = _canon_key_set(ds)
    kw_num_raw: dict[str, object] = {}
    kw_cat_raw: dict[str, object] = {}
    for k, v in kwargs.items():
        if k in bases:
            kw_cat_raw[k] = v
            continue
        if k in idx_map:
            kw_num_raw[idx_map[k]] = v
            continue
        nk = _re.sub(r"[^a-z0-9]+", "", str(k).lower())
        if nk in idx_map:
            kw_num_raw[idx_map[nk]] = v

    # ---- categorical constraints must resolve to exactly one label ----
    cat_fixed: dict[str, str] = {}
    for base, val in kw_cat_raw.items():
        labels = groups[base]["labels"]
        if isinstance(val, str):
            if val not in labels:
                raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}")
            cat_fixed[base] = val
        else:
            candidates = list(val) if isinstance(val, (list, tuple, set)) else [val]
            chosen = [x for x in candidates if isinstance(x, str) and x in labels]
            if not chosen:
                raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}")
            if len(chosen) == 1:
                cat_fixed[base] = chosen[0]
            else:
                raise ValueError("optimum_plot2d currently requires categorical bases to be fixed.")

    # ---- select the training rows that satisfy every constraint ----
    row_mask = np.ones(n_rows, dtype=bool)
    for base, label in cat_fixed.items():
        if base in df_raw.columns:
            series = df_raw[base].astype("string")
            row_mask &= series.eq(label).fillna(False).to_numpy()
        else:
            member_name = groups[base]["name_by_label"][label]
            j = name_to_idx[member_name]
            raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float)
            row_mask &= (raw_j >= 0.5)

    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if name in df_raw.columns:
            raw_vals = pd.to_numeric(df_raw[name], errors="coerce").to_numpy(dtype=float)
        else:
            raw_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float)
        mask = np.isfinite(raw_vals)
        if isinstance(val, slice):
            lo_raw = -np.inf if val.start is None else float(val.start)
            hi_raw = np.inf if val.stop is None else float(val.stop)
            if hi_raw < lo_raw:
                lo_raw, hi_raw = hi_raw, lo_raw
            mask &= (raw_vals >= lo_raw) & (raw_vals <= hi_raw)
        elif isinstance(val, (list, tuple, set, np.ndarray)):
            arr = np.asarray(list(val) if not isinstance(val, np.ndarray) else val, dtype=float)
            arr = arr[np.isfinite(arr)]
            if arr.size == 0:
                mask &= False
            else:
                mask &= np.any(np.isclose(raw_vals[:, None], arr[None, :], rtol=1e-6, atol=1e-9), axis=1)
        else:
            target = float(val)
            mask &= np.isclose(raw_vals, target, rtol=1e-6, atol=1e-9)
        row_mask &= mask

    if not np.any(row_mask):
        raise ValueError("No experiments match the provided constraints; nothing to plot.")

    df_raw_f = df_raw.loc[row_mask].reset_index(drop=True)

    # ---- anchor point: the optimum, expressed in standardized space ----
    opt_df = opt.optimal(model, count=1, seed=seed, **kwargs)
    x_opt_std = np.zeros(len(feature_names), dtype=float)

    def _to_std_single(j: int, x_orig: float) -> float:
        xi = x_orig
        if transforms[j] == "log10":
            xi = np.log10(np.maximum(x_orig, 1e-300))
        return float((xi - X_mean[j]) / X_std[j])

    onehot_members: set[str] = set()
    for g in groups.values():
        onehot_members.update(g["members"])

    for j, name in enumerate(feature_names):
        if name in onehot_members:
            continue
        if name in opt_df.columns:
            x_opt_std[j] = _to_std_single(j, float(opt_df.iloc[0][name]))
        else:
            # feature absent from the optimum table: fall back to the training median
            x_opt_std[j] = float(np.median(Xn_train[:, j]))

    for base, g in groups.items():
        if base in cat_fixed:
            label = cat_fixed[base]
        elif base in opt_df.columns:
            label = str(opt_df.iloc[0][base])
        else:
            if base in df_raw.columns:
                label = str(df_raw[base].astype("string").mode(dropna=True).iloc[0])
            else:
                label = g["labels"][0]
        for lab in g["labels"]:
            member_name = g["name_by_label"][lab]
            j = name_to_idx[member_name]
            raw = 1.0 if lab == label else 0.0
            x_opt_std[j] = (raw - X_mean[j]) / X_std[j]

    def _orig_to_std(j: int, x, transforms, mu, sd):
        x = np.asarray(x, dtype=float)
        if transforms[j] == "log10":
            x = np.where(x <= 0, np.nan, x)
            x = np.log10(x)
        return (x - mu[j]) / sd[j]

    def _bound_to_std(j: int, raw_bound: float, fallback: float) -> float:
        """Standardize one slice bound; NaN (e.g. log10 of a non-positive bound) means unbounded."""
        std_val = float(_orig_to_std(j, raw_bound, transforms, X_mean, X_std))
        return fallback if np.isnan(std_val) else std_val

    # ---- translate numeric constraints into standardized-space grid rules ----
    fixed_scalars_std: dict[int, float] = {}
    range_windows_std: dict[int, tuple[float, float]] = {}
    choice_values_std: dict[int, np.ndarray] = {}

    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if isinstance(val, slice):
            # Mirror the row-filter stage: None bounds are open-ended, and a
            # reversed slice is reordered in raw space before conversion.
            lo_raw = -np.inf if val.start is None else float(val.start)
            hi_raw = np.inf if val.stop is None else float(val.stop)
            if hi_raw < lo_raw:
                lo_raw, hi_raw = hi_raw, lo_raw
            lo = _bound_to_std(j, lo_raw, -np.inf)
            hi = _bound_to_std(j, hi_raw, np.inf)
            range_windows_std[j] = (float(min(lo, hi)), float(max(lo, hi)))
        elif isinstance(val, (list, tuple, set, np.ndarray)):
            # `set` is accepted here too, matching the row-filter stage above.
            raw_choices = np.asarray(list(val) if not isinstance(val, np.ndarray) else val, dtype=float)
            arr = _orig_to_std(j, raw_choices, transforms, X_mean, X_std)
            choice_values_std[j] = np.asarray(arr, dtype=float)
        else:
            fixed_scalars_std[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std))

    # user-fixed scalars override the optimum's coordinate
    for j, v in fixed_scalars_std.items():
        x_opt_std[j] = v

    free_numeric_idx = [
        j for j, name in enumerate(feature_names)
        if (j not in fixed_scalars_std) and (name not in onehot_members)
    ]
    if len(free_numeric_idx) == 0:
        raise ValueError("All numeric features are fixed at the optimum; nothing to plot.")

    # ---- per-feature grids in standardized space ----
    grids_std_num: dict[int, np.ndarray] = {}
    raw_full_cache: dict[int, np.ndarray] = {}
    Xn_p01 = np.percentile(Xn_train, 1, axis=0)
    Xn_p99 = np.percentile(Xn_train, 99, axis=0)

    def _grid_std_num(j: int) -> np.ndarray:
        lo, hi = float(Xn_p01[j]), float(Xn_p99[j])
        if j in range_windows_std:
            lo = max(lo, range_windows_std[j][0])
            hi = min(hi, range_windows_std[j][1])
        if j in choice_values_std:
            vals = np.asarray(choice_values_std[j], dtype=float)
            vals = vals[(vals >= lo) & (vals <= hi)]
            return np.unique(np.sort(vals)) if vals.size else np.array([x_opt_std[j]], dtype=float)
        if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
            lo, hi = x_opt_std[j] - 1.0, x_opt_std[j] + 1.0
        return np.linspace(lo, hi, grid_size)

    for j in free_numeric_idx:
        if feature_names[j] in df_raw.columns:
            raw_vals = pd.to_numeric(df_raw[feature_names[j]], errors="coerce").to_numpy(dtype=float)
        else:
            raw_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float)
        raw_full_cache[j] = raw_vals
        grid = _grid_std_num(j)
        if (j not in range_windows_std) and (j not in choice_values_std):
            # unconstrained feature: widen the grid to cover the full observed data extent
            finite_raw = raw_vals[np.isfinite(raw_vals)]
            if transforms[j] == "log10":
                finite_raw = finite_raw[finite_raw > 0]
            if finite_raw.size:
                finite_std = _orig_to_std(j, finite_raw, transforms, X_mean, X_std)
                combined = np.concatenate([grid, finite_std])
                grid_min = float(np.nanmin(combined))
                grid_max = float(np.nanmax(combined))
                if grid_max > grid_min:
                    grid = np.linspace(grid_min, grid_max, grid_size)
        grids_std_num[j] = grid

    subplot_titles = [feature_names[j] for j in free_numeric_idx]
    k = len(free_numeric_idx)
    fig = make_subplots(
        rows=k,
        cols=k,
        shared_xaxes=True,
        shared_yaxes=True,
        horizontal_spacing=0.01,
        vertical_spacing=0.01,
        subplot_titles=subplot_titles,
    )

    optimal_df = opt_df.copy() if optimal else None
    suggest_df = opt.suggest(ds, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None

    tgt_col = str(ds.attrs.get("target", "target"))
    success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy()
    fail_mask = ~success_mask

    # ---- evaluate the predictors on every panel's grid ----
    all_blocks: list[np.ndarray] = []
    cell_payload: dict[tuple[int, int], dict] = {}
    base_std = x_opt_std.copy()

    def _denorm_inv_opt(j: int, std_vals: np.ndarray) -> np.ndarray:
        internal = std_vals * X_std[j] + X_mean[j]
        return _inverse_transform(transforms[j], internal)

    for row_idx, i in enumerate(free_numeric_idx):
        for col_idx, j in enumerate(free_numeric_idx):
            xg = grids_std_num[j]
            yg = grids_std_num[i]
            if i == j:
                # diagonal: build a symmetric surface from the 1D sweep
                Xn_1d = np.repeat(base_std[None, :], len(xg), axis=0)
                Xn_1d[:, j] = xg
                mu_1d, _ = pred_loss(Xn_1d, include_observation_noise=True)
                p_1d = pred_success(Xn_1d)
                Zmu = 0.5 * (mu_1d[:, None] + mu_1d[None, :])
                Zp = np.minimum(p_1d[:, None], p_1d[None, :])
                x_orig = _denorm_inv_opt(j, xg)
                y_orig = x_orig
            else:
                XX, YY = np.meshgrid(xg, yg)
                Xn_grid = np.repeat(base_std[None, :], XX.size, axis=0)
                Xn_grid[:, j] = XX.ravel()
                Xn_grid[:, i] = YY.ravel()
                mu_flat, _ = pred_loss(Xn_grid, include_observation_noise=True)
                p_flat = pred_success(Xn_grid)
                Zmu = mu_flat.reshape(YY.shape)
                Zp = p_flat.reshape(YY.shape)
                x_orig = _denorm_inv_opt(j, xg)
                y_orig = _denorm_inv_opt(i, yg)
            cell_payload[(row_idx, col_idx)] = dict(i=i, j=j, x=x_orig, y=y_orig, Zmu=Zmu, Zp=Zp)
            all_blocks.append(Zmu.ravel())

    def _color_xform(z_raw: np.ndarray, shift: float | None = None) -> tuple[np.ndarray, float]:
        """Map raw means to colour values; when `shift` is None it is derived from `z_raw`."""
        if not use_log_scale_for_target:
            return z_raw, 0.0
        if shift is None:
            zmin = float(np.nanmin(z_raw))
            shift = 0.0 if zmin > 0 else -zmin + float(log_shift_epsilon)
        return np.log10(np.maximum(z_raw + shift, log_shift_epsilon)), shift

    z_all = np.concatenate(all_blocks) if all_blocks else np.array([0.0, 1.0])
    z_all_t, global_shift = _color_xform(z_all)
    cmin_t = float(np.nanmin(z_all_t))
    cmax_t = float(np.nanmax(z_all_t))
    cs = get_colorscale(colorscale)

    def _contour_line_color(level_raw: float) -> str:
        zt = np.log10(max(level_raw + global_shift, log_shift_epsilon)) if use_log_scale_for_target else level_raw
        t = 0.5 if cmax_t == cmin_t else (zt - cmin_t) / (cmax_t - cmin_t)
        rgb = sample_colorscale(cs, [float(np.clip(t, 0.0, 1.0))])[0]
        r, g, b = _rgb_string_to_tuple(rgb)
        lum = (0.2126*r + 0.7152*g + 0.0722*b)/255.0
        # invert the luminance so the line contrasts with the fill underneath
        grey = int(round((1.0 - lum) * 255))
        return f"rgba({grey},{grey},{grey},0.9)"

    def _data_vals_for_feature(j_full: int) -> np.ndarray:
        name = feature_names[j_full]
        if name in df_raw_f.columns:
            return df_raw_f[name].to_numpy(dtype=float)
        vals = feature_raw_from_artifact_or_reconstruct(ds, j_full, name, transforms[j_full]).astype(float)
        return vals[row_mask]

    def _fit_axis_to_data(axis: str, j_full: int, row: int, col: int) -> None:
        """Clamp one axis of a panel to the padded extent of the observed data for feature `j_full`."""
        if j_full not in raw_full_cache or j_full in range_windows_std or j_full in choice_values_std:
            return
        cached = raw_full_cache[j_full]
        finite_raw = cached[np.isfinite(cached)]
        is_log = transforms[j_full] == "log10"
        if is_log:
            finite_raw = finite_raw[finite_raw > 0]
        if not finite_raw.size:
            return
        v0 = float(np.min(finite_raw))
        v1 = float(np.max(finite_raw))
        if is_log:
            v0 = max(v0, 1e-12)
            v1 = max(v1, v0 * (1 + 1e-9))
            pad = (v1 / v0) ** 0.03
            # log axes take their range in log10 units
            axis_range = [np.log10(v0 / pad), np.log10(v1 * pad)]
        else:
            span = (v1 - v0) or 1.0
            pad = 0.02 * span
            axis_range = [v0 - pad, v1 + pad]
        update = fig.update_xaxes if axis == "x" else fig.update_yaxes
        update(range=axis_range, row=row, col=col)

    # loop-invariant per-row metadata for the data overlays
    trial_ids = df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()
    target_vals = df_raw_f[tgt_col].to_numpy()

    # ---- draw every panel ----
    for (r, c), payload in cell_payload.items():
        Zmu_raw = payload["Zmu"]
        Zp = payload["Zp"]
        # Use the global shift so all panels share one colour mapping with the
        # coloraxis limits, contour colours and colourbar title.
        Z_t, _ = _color_xform(Zmu_raw, shift=global_shift)
        x_vals = payload["x"]
        y_vals = payload["y"]
        x_name = feature_names[payload["j"]]
        y_name = feature_names[payload["i"]]
        if payload["i"] == payload["j"]:
            diag_vals = np.asarray(x_vals, dtype=float)
            x_vals = diag_vals
            y_vals = diag_vals

        fig.add_trace(go.Heatmap(
            x=x_vals, y=y_vals, z=Z_t,
            coloraxis="coloraxis", zsmooth=False, showscale=False,
            hovertemplate=(f"{x_name}: %{{x:.6g}}<br>"
                           f"{y_name}: %{{y:.6g}}"
                           "<br>E[target|success]: %{customdata:.3f}<extra></extra>"),
            customdata=Zmu_raw
        ), row=r+1, col=c+1)

        # grey veils where the predicted success probability is low
        for thr, alpha in ((0.5, 0.25), (0.8, 0.40)):
            mask = np.where(Zp < thr, 1.0, np.nan)
            fig.add_trace(go.Heatmap(
                x=x_vals, y=y_vals, z=mask, zmin=0, zmax=1,
                colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]],
                showscale=False, hoverinfo="skip"
            ), row=r+1, col=c+1)

        # one Contour trace per level so each line can get its own colour
        zmin_r, zmax_r = float(np.nanmin(Zmu_raw)), float(np.nanmax(Zmu_raw))
        levels = np.linspace(zmin_r, zmax_r, max(n_contours, 2))
        for lev in levels:
            color = _contour_line_color(lev)
            fig.add_trace(go.Contour(
                x=x_vals, y=y_vals, z=Zmu_raw,
                autocontour=False,
                contours=dict(coloring="lines", showlabels=False, start=lev, end=lev, size=1e-9),
                line=dict(width=1),
                colorscale=[[0, color], [1, color]],
                showscale=False, hoverinfo="skip"
            ), row=r+1, col=c+1)

        xd = _data_vals_for_feature(payload["j"])
        yd = _data_vals_for_feature(payload["i"])
        show_leg = (r == 0 and c == 0)  # legend entries only once, from the first panel
        fig.add_trace(go.Scattergl(
            x=xd[success_mask], y=yd[success_mask], mode="markers",
            marker=dict(size=4, color="black", line=dict(width=0)),
            name="data (success)", legendgroup="data_succ", showlegend=show_leg,
            hovertemplate=("trial_id: %{customdata[0]}<br>"
                           f"{x_name}: %{{x:.6g}}<br>"
                           f"{y_name}: %{{y:.6g}}<br>"
                           f"{tgt_col}: %{{customdata[1]:.4f}}<extra></extra>"),
            customdata=np.column_stack([
                trial_ids[success_mask],
                target_vals[success_mask],
            ])
        ), row=r+1, col=c+1)
        fig.add_trace(go.Scattergl(
            x=xd[fail_mask], y=yd[fail_mask], mode="markers",
            marker=dict(size=5, color="red", line=dict(color="black", width=0.8)),
            name="data (failed)", legendgroup="data_fail", showlegend=show_leg,
            hovertemplate=("trial_id: %{customdata}<br>"
                           f"{x_name}: %{{x:.6g}}<br>"
                           f"{y_name}: %{{y:.6g}}<br>"
                           "status: failed (NaN target)<extra></extra>"),
            customdata=trial_ids[fail_mask]
        ), row=r+1, col=c+1)

        if (
            optimal
            and optimal_df is not None
            and x_name in optimal_df.columns
            and y_name in optimal_df.columns
        ):
            ox = np.asarray(optimal_df[x_name].values, dtype=float)
            oy = np.asarray(optimal_df[y_name].values, dtype=float)
            pmu = float(optimal_df["pred_target_mean"].values[0])
            psd = float(optimal_df.get("pred_target_sd", pd.Series([np.nan])).values[0])
            fig.add_trace(go.Scattergl(
                x=ox, y=oy, mode="markers",
                marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                name="optimal", legendgroup="optimal", showlegend=show_leg,
                hovertemplate=(f"predicted: {pmu:.3g}"
                               + ("" if np.isnan(psd) else f" ± {psd:.3g}")
                               + f"<br>{x_name}: %{{x:.6g}}"
                               f"<br>{y_name}: %{{y:.6g}}<extra></extra>")
            ), row=r+1, col=c+1)

        if (
            suggest
            and suggest_df is not None
            and x_name in suggest_df.columns
            and y_name in suggest_df.columns
        ):
            xs = suggest_df[x_name].values.astype(float)
            ys = suggest_df[y_name].values.astype(float)
            ymu = suggest_df["pred_target_mean"].values.astype(float)
            ysd = suggest_df.get("pred_target_sd", pd.Series([np.nan]*len(suggest_df))).values
            fig.add_trace(go.Scattergl(
                x=xs, y=ys, mode="markers",
                marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                name="suggested", legendgroup="suggested", showlegend=show_leg,
                hovertemplate=("predicted: %{customdata[0]:.3f}"
                               + (" ± %{customdata[1]:.3f}" if not np.isnan(ysd).all() else "")
                               + f"<br>{x_name}: %{{x:.6g}}"
                               f"<br>{y_name}: %{{y:.6g}}<extra>suggested</extra>"),
                customdata=np.column_stack([ymu, ysd])
            ), row=r+1, col=c+1)

        _maybe_log_axis(fig, row=r+1, col=c+1, name=x_name, axis="x", transforms=transforms, j=payload["j"])
        _maybe_log_axis(fig, row=r+1, col=c+1, name=y_name, axis="y", transforms=transforms, j=payload["i"])
        # axis titles only on the outer edge of the grid
        if r == k - 1:
            fig.update_xaxes(title_text=x_name, row=r+1, col=c+1)
        else:
            fig.update_xaxes(tickmode=None, row=r+1, col=c+1)
        if c == 0:
            fig.update_yaxes(title_text=y_name, row=r+1, col=c+1)
        else:
            fig.update_yaxes(tickmode=None, row=r+1, col=c+1)

        _fit_axis_to_data("x", payload["j"], r+1, c+1)
        _fit_axis_to_data("y", payload["i"], r+1, c+1)

    # ---- layout & IO ----
    z_title = "E[target|success]" + (" (log10)" if use_log_scale_for_target else "")
    if use_log_scale_for_target and global_shift > 0:
        z_title += f" (shift Δ={global_shift:.3g})"

    width = width if (width and width > 0) else 1100
    height = height if (height and height > 0) else 1100
    fig.update_layout(
        height=height,
        width=width,
        template="simple_white",
        coloraxis=dict(
            colorscale=colorscale,
            cmin=cmin_t, cmax=cmax_t,
            colorbar=dict(
                title=z_title,
                thickness=10,            # thinner bar
                len=0.55,                # shorter bar (fraction of plot height)
                lenmode="fraction",
                x=1.02, y=0.5,           # just right of plot, vertically centered
                xanchor="left", yanchor="middle",
            ),
        ),
        legend=dict(
            orientation="v",
            x=1.02, xanchor="left",      # to the right of the colorbar
            y=1.0, yanchor="top",
            bgcolor="rgba(255,255,255,0.85)"
        ),
        title=f"2D PD at optimal setting of all other hyperparameters ({tgt_col})",
        legend_title_text="",
    )

    if output:
        write_image(fig, output)
    if show:
        fig.show("browser")
    return fig
2867def write_image(fig, output:Path|str):
2868 """Write a Plotly figure to an image file (PNG, JPEG, etc). Requires kaleido."""
2869 output = Path(output)
2870 output.parent.mkdir(parents=True, exist_ok=True)
2871 if output.suffix.lower() == ".html":
2872 fig.write_html(str(output))
2873 else:
2874 fig.write_image(str(output))