Coverage for psyop/viz.py: 29.13%

1651 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-10-29 03:44 +0000

1# -*- coding: utf-8 -*- 

2from pathlib import Path 

3import numpy as np 

4import pandas as pd 

5import xarray as xr 

6 

7import plotly.graph_objects as go 

8from plotly.subplots import make_subplots 

9from plotly.colors import get_colorscale, sample_colorscale 

10 

11from .model import ( 

12 kernel_diag_m52, 

13 kernel_m52_ard, 

14 add_jitter, 

15 solve_chol, 

16 solve_lower, 

17 feature_raw_from_artifact_or_reconstruct, 

18) 

19from . import opt 

20 

21 

22def _canon_key_set(ds) -> dict[str, str]: 

23 feats = [str(x) for x in ds["feature"].values.tolist()] 

24 def _norm(s: str) -> str: 

25 import re 

26 return re.sub(r"[^a-z0-9]+", "", s.lower()) 

27 return {**{f: f for f in feats}, **{_norm(f): f for f in feats}} 

28 

29 

30def _edges_from_centers(vals: np.ndarray, is_log: bool) -> tuple[float, float]: 

31 """Return (min_edge, max_edge) that tightly bound a heatmap with given center coords.""" 

32 v = np.asarray(vals, float) 

33 v = v[np.isfinite(v)] 

34 if v.size == 0: 

35 return (0.0, 1.0) 

36 if v.size == 1: 

37 # tiny pad 

38 if is_log: 

39 lo = max(v[0] / 1.5, 1e-12) 

40 hi = v[0] * 1.5 

41 else: 

42 span = max(abs(v[0]) * 0.5, 1e-9) 

43 lo, hi = v[0] - span, v[0] + span 

44 return float(lo), float(hi) 

45 

46 if is_log: 

47 lv = np.log10(v) 

48 l0 = lv[0] - 0.5 * (lv[1] - lv[0]) 

49 lN = lv[-1] + 0.5 * (lv[-1] - lv[-2]) 

50 lo = 10.0 ** l0 

51 hi = 10.0 ** lN 

52 lo = max(lo, 1e-12) 

53 hi = max(hi, lo * 1.0000001) 

54 return float(lo), float(hi) 

55 else: 

56 d0 = v[1] - v[0] 

57 dN = v[-1] - v[-2] 

58 lo = v[0] - 0.5 * d0 

59 hi = v[-1] + 0.5 * dN 

60 if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo: 

61 lo, hi = float(v.min()), float(v.max()) 

62 return float(lo), float(hi) 

63 

def _update_axis_type_and_range(
    fig, *, row: int, col: int, axis: str, centers: np.ndarray, is_log: bool
):
    """Fit one subplot axis to the outer edges of a heatmap's cells.

    The edges are computed by ``_edges_from_centers(centers, is_log)``.
    ``axis == "x"`` updates the x-axis of subplot (row, col); any other value
    updates its y-axis. For a log axis the axis type is set to ``"log"`` and
    the range is passed as log10 of the edges.
    """
    edge_lo, edge_hi = _edges_from_centers(centers, is_log)
    update_axes = fig.update_xaxes if axis == "x" else fig.update_yaxes
    if not is_log:
        update_axes(range=[edge_lo, edge_hi], row=row, col=col)
        return
    update_axes(
        type="log",
        range=[np.log10(edge_lo), np.log10(edge_hi)],
        row=row,
        col=col,
    )

79 

80 

81def plot2d( 

82 model: xr.Dataset | Path | str, 

83 output: Path | None = None, 

84 grid_size: int = 70, 

85 use_log_scale_for_target: bool = False, # log10 colors for heatmaps 

86 log_shift_epsilon: float = 1e-9, 

87 colorscale: str = "RdBu", 

88 show: bool = False, 

89 n_contours: int = 12, 

90 optimal: bool = True, 

91 suggest: int = 0, 

92 seed: int|None = 42, 

93 width:int|None = None, 

94 height:int|None = None, 

95 **kwargs, 

96) -> go.Figure: 

97 """ 

98 2D Partial Dependence of E[target|success] (pairwise features), including 

99 categorical variables as single axes (one row/column per base). 

100 

101 Conditioning via kwargs (original units): 

102 - numeric: fixed scalar, slice(lo,hi), list/tuple choices 

103 - categorical base (e.g. language="Linear A" or language=("Linear A","Linear B")): 

104 * single string: fixed to that label (axis removed and clamped) 

105 * list/tuple of labels: restrict the categorical axis to those labels 

106 

107 Notes: 

108 * one-hot member features (e.g. language=Linear A) never appear as axes. 

109 * when a categorical axis is present, we render a heatmap over category index 

110 with tick labels set to the category names; data overlays use jitter. 

111 """ 

112 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model) 

113 pred_success, pred_loss = _build_predictors(ds) 

114 

115 # --- features & transforms 

116 feature_names = [str(x) for x in ds["feature"].values.tolist()] 

117 transforms = [str(t) for t in ds["feature_transform"].values.tolist()] 

118 X_mean = ds["feature_mean"].values.astype(float) 

119 X_std = ds["feature_std"].values.astype(float) 

120 name_to_idx = {nm: i for i, nm in enumerate(feature_names)} 

121 

122 # one-hot groups 

123 groups = opt._onehot_groups(feature_names) # { base: {"labels":[...], "name_by_label":{label->member}, "members":[...] } } 

124 bases = set(groups.keys()) 

125 onehot_member_names = {m for g in groups.values() for m in g["members"]} 

126 

127 # raw df + train design 

128 df_raw = _raw_dataframe_from_dataset(ds) 

129 Xn_train = ds["Xn_train"].values.astype(float) 

130 n_rows = Xn_train.shape[0] 

131 

132 # --- split kwargs into numeric vs categorical (keys are canonical already when coming from CLI) 

133 kw_num: dict[str, object] = {} 

134 kw_cat: dict[str, object] = {} 

135 for k, v in (kwargs or {}).items(): 

136 if k in bases: 

137 kw_cat[k] = v 

138 elif k in name_to_idx: 

139 kw_num[k] = v 

140 else: 

141 # unknown/ignored 

142 pass 

143 

144 # --- resolve categorical constraints: 

145 # - cat_fixed[base] -> fixed single label (axis removed and clamped) 

146 # - cat_allowed[base] -> labels that are allowed on that axis (if not fixed) 

147 cat_fixed: dict[str, str] = {} 

148 cat_allowed: dict[str, list[str]] = {} 

149 for base in bases: 

150 labels = list(groups[base]["labels"]) 

151 if base not in kw_cat: 

152 cat_allowed[base] = labels # unrestricted axis (if not fixed numerically later) 

153 continue 

154 val = kw_cat[base] 

155 if isinstance(val, str): 

156 if val not in labels: 

157 raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}") 

158 cat_fixed[base] = val 

159 elif isinstance(val, (list, tuple, set)): 

160 chosen = [x for x in val if isinstance(x, str) and x in labels] 

161 if not chosen: 

162 raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}") 

163 # if only one remains, treat as fixed; else allowed list 

164 if len(chosen) == 1: 

165 cat_fixed[base] = chosen[0] 

166 else: 

167 cat_allowed[base] = chosen 

168 else: 

169 raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.") 

170 

171 # --- filter rows to categorical *fixed* selections for medians/percentiles & overlays 

172 row_mask = np.ones(n_rows, dtype=bool) 

173 for base, label in cat_fixed.items(): 

174 if base in df_raw.columns: 

175 row_mask &= (df_raw[base].astype("string") == pd.Series([label]*len(df_raw), dtype="string")).to_numpy() 

176 else: 

177 member = groups[base]["name_by_label"][label] 

178 j = name_to_idx[member] 

179 raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member, transforms[j]).astype(float) 

180 row_mask &= (raw_j >= 0.5) 

181 

182 # --- numeric constraints (standardized) 

183 def _orig_to_std(j: int, x, transforms, mu, sd): 

184 x = np.asarray(x, dtype=float) 

185 if transforms[j] == "log10": 

186 x = np.where(x <= 0, np.nan, x) 

187 x = np.log10(x) 

188 return (x - mu[j]) / sd[j] 

189 

190 fixed_scalars_std: dict[int, float] = {} 

191 range_windows_std: dict[int, tuple[float, float]] = {} 

192 choice_values_std: dict[int, np.ndarray] = {} 

193 

194 for name, val in kw_num.items(): 

195 j = name_to_idx[name] 

196 if isinstance(val, slice): 

197 lo = _orig_to_std(j, float(val.start), transforms, X_mean, X_std) 

198 hi = _orig_to_std(j, float(val.stop), transforms, X_mean, X_std) 

199 lo, hi = float(min(lo, hi)), float(max(lo, hi)) 

200 range_windows_std[j] = (lo, hi) 

201 elif isinstance(val, (list, tuple, np.ndarray)): 

202 arr = _orig_to_std(j, np.asarray(val, dtype=float), transforms, X_mean, X_std) 

203 choice_values_std[j] = np.asarray(arr, dtype=float) 

204 else: 

205 fixed_scalars_std[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std)) 

206 

207 # --- apply categorical *fixed* selections as standardized 0/1 on their member features 

208 for base, label in cat_fixed.items(): 

209 labels = groups[base]["labels"] 

210 for lab in labels: 

211 member = groups[base]["name_by_label"][lab] 

212 j = name_to_idx[member] 

213 raw_val = 1.0 if (lab == label) else 0.0 

214 fixed_scalars_std[j] = float(_orig_to_std(j, raw_val, transforms, X_mean, X_std)) 

215 

216 # --- enforce row-level filters so overlays/points respect constraints --- 

217 for base, allowed in cat_allowed.items(): 

218 if (base not in kw_cat) or (base in cat_fixed): 

219 continue 

220 allowed_labels = [str(x) for x in allowed] 

221 if base in df_raw.columns: 

222 series = df_raw[base].astype("string").fillna("<NA>") 

223 if not allowed_labels: 

224 row_mask &= False 

225 else: 

226 allowed_mask = series.isin(set(allowed_labels)).fillna(False).to_numpy() 

227 row_mask &= allowed_mask 

228 else: 

229 allowed_masks: list[np.ndarray] = [] 

230 for label in allowed_labels: 

231 member = groups[base]["name_by_label"].get(label) 

232 if member is None: 

233 continue 

234 j = name_to_idx[member] 

235 raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member, transforms[j]).astype(float) 

236 allowed_masks.append(raw_j >= 0.5) 

237 if allowed_masks: 

238 row_mask &= np.logical_or.reduce(allowed_masks) 

239 else: 

240 row_mask &= False 

241 

242 for name, val in kw_num.items(): 

243 if name not in name_to_idx: 

244 continue 

245 j = name_to_idx[name] 

246 if name in df_raw.columns: 

247 raw_vals = pd.to_numeric(df_raw[name], errors="coerce").to_numpy(dtype=float) 

248 else: 

249 raw_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float) 

250 

251 mask = np.isfinite(raw_vals) 

252 if isinstance(val, slice): 

253 lo_raw = -np.inf if val.start is None else float(val.start) 

254 hi_raw = np.inf if val.stop is None else float(val.stop) 

255 if hi_raw < lo_raw: 

256 lo_raw, hi_raw = hi_raw, lo_raw 

257 mask &= (raw_vals >= lo_raw) & (raw_vals <= hi_raw) 

258 elif isinstance(val, (list, tuple, set, np.ndarray)): 

259 arr = np.asarray(list(val) if not isinstance(val, np.ndarray) else val, dtype=float) 

260 arr = arr[np.isfinite(arr)] 

261 if arr.size == 0: 

262 mask &= False 

263 else: 

264 close_mask = np.any(np.isclose(raw_vals[:, None], arr[None, :], rtol=1e-6, atol=1e-9), axis=1) 

265 mask &= close_mask 

266 else: 

267 target = float(val) 

268 mask &= np.isclose(raw_vals, target, rtol=1e-6, atol=1e-9) 

269 

270 row_mask &= mask 

271 

272 if not np.any(row_mask): 

273 raise ValueError("No experiments match the provided constraints; cannot plot data points.") 

274 

275 row_mask_active = not bool(np.all(row_mask)) 

276 df_raw_f = df_raw.loc[row_mask].reset_index(drop=True) if row_mask_active else df_raw 

277 Xn_train_f = Xn_train[row_mask, :] if row_mask_active else Xn_train 

278 

279 # --- free axes = numeric features not scalar-fixed & not one-hot members, plus categorical bases not fixed 

280 free_numeric_idx = [ 

281 j for j, nm in enumerate(feature_names) 

282 if (j not in fixed_scalars_std) and (nm not in onehot_member_names) 

283 ] 

284 free_cat_bases = [b for b in bases if b not in cat_fixed] # we already filtered by allowed above 

285 

286 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases] 

287 if not panels: 

288 raise ValueError("All features are fixed (or only single-category categoricals remain); nothing to plot.") 

289 

290 # --- base point (median in standardized space of filtered rows), then apply scalar fixes 

291 base_std = np.median(Xn_train_f, axis=0) 

292 for j, vstd in fixed_scalars_std.items(): 

293 base_std[j] = vstd 

294 

295 # --- per-feature grids (numeric) over filtered 1–99% + respecting ranges/choices 

296 p01p99 = [np.percentile(Xn_train_f[:, j], [1, 99]) for j in range(len(feature_names))] 

297 def _grid_std_num(j: int) -> np.ndarray: 

298 p01, p99 = p01p99[j] 

299 if j in choice_values_std: 

300 vals = np.asarray(choice_values_std[j], dtype=float) 

301 vals = vals[(vals >= p01) & (vals <= p99)] 

302 return np.unique(np.sort(vals)) if vals.size else np.array([np.median(Xn_train_f[:, j])]) 

303 lo, hi = p01, p99 

304 if j in range_windows_std: 

305 rlo, rhi = range_windows_std[j] 

306 lo, hi = max(lo, rlo), min(hi, rhi) 

307 if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo: 

308 hi = lo + 1e-9 

309 return np.linspace(lo, hi, grid_size) 

310 

311 grids_std_num = {j: _grid_std_num(j) for j in free_numeric_idx} 

312 

313 # --- helpers for categorical evaluation --------------------------------- 

314 def _std_for_member(member_name: str, raw01: float) -> float: 

315 j = name_to_idx[member_name] 

316 return float(_orig_to_std(j, raw01, transforms, X_mean, X_std)) 

317 

318 def _apply_onehot_for_base(Xn_block: np.ndarray, base: str, label: str) -> None: 

319 # set the whole block's rows to the 0/1 standardized values for this label 

320 for lab in groups[base]["labels"]: 

321 member = groups[base]["name_by_label"][lab] 

322 j = name_to_idx[member] 

323 Xn_block[:, j] = _std_for_member(member, 1.0 if lab == label else 0.0) 

324 

325 def _denorm_inv(j: int, std_vals: np.ndarray) -> np.ndarray: 

326 internal = std_vals * X_std[j] + X_mean[j] 

327 return _inverse_transform(transforms[j], internal) 

328 

329 # 1) Robustly detect one-hot member columns. 

330 # Use both the detector output AND a fallback "base=" prefix scan, 

331 # so any columns like "language=Linear A" are guaranteed to be excluded. 

332 onehot_member_names: set[str] = set() 

333 for base, g in groups.items(): 

334 # detector-known members 

335 onehot_member_names.update(g["members"]) 

336 # prefix fallback 

337 prefix = f"{base}=" 

338 onehot_member_names.update([nm for nm in feature_names if nm.startswith(prefix)]) 

339 

340 # 2) Build panel list: keep numeric features that are not scalar-fixed AND 

341 # are not one-hot members; plus categorical bases that are not fixed. 

342 free_numeric_idx = [ 

343 j for j, nm in enumerate(feature_names) 

344 if (j not in fixed_scalars_std) and (nm not in onehot_member_names) 

345 ] 

346 free_cat_bases = [b for b in bases if b not in cat_fixed] 

347 

348 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases] 

349 if not panels: 

350 raise ValueError("All features are fixed (or only single-category categoricals remain); nothing to plot.") 

351 

352 # 3) Sanity check: no one-hot member should survive as a numeric panel. 

353 assert all( 

354 (feature_names[key] not in onehot_member_names) if kind == "num" else True 

355 for kind, key in panels 

356 ), "internal: one-hot member leaked into numeric panels" 

357 

358 # 4) Subplot scaffold (matrix layout k x k) with clear titles. 

359 def _panel_title(kind: str, key: object) -> str: 

360 return feature_names[int(key)] if kind == "num" else str(key) 

361 

362 k = len(panels) 

363 fig = make_subplots( 

364 rows=k, 

365 cols=k, 

366 shared_xaxes=False, 

367 shared_yaxes=False, 

368 horizontal_spacing=0.03, 

369 vertical_spacing=0.03, 

370 subplot_titles=[_panel_title(kind, key) for kind, key in panels], 

371 ) 

372 

373 # (Keep the rest of your cell-evaluation and rendering logic unchanged. 

374 # Because we filtered `onehot_member_names`, rows/columns like 

375 # "language=Linear A" / "language=Linear B" will no longer appear. 

376 # Categorical bases (e.g., "language") will show as a single axis.) 

377 

378 

379 # overlays prepared under the SAME constraints (pass original kwargs straight through) 

380 optimal_df = opt.optimal(ds, count=1, seed=seed, **kwargs) if optimal else None 

381 suggest_df = opt.suggest(ds, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None 

382 

383 # masks for data overlays (already filtered if cat_fixed) 

384 tgt_col = str(ds.attrs["target"]) 

385 success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy() 

386 fail_mask = ~success_mask 

387 

388 # collect Z blocks for global color bounds 

389 all_blocks: list[np.ndarray] = [] 

390 cell_payload: dict[tuple[int,int], dict] = {} 

391 

392 # --- build each cell payload (numeric/num, cat/num, num/cat, cat/cat) 

393 for r, (kind_r, key_r) in enumerate(panels): 

394 for c, (kind_c, key_c) in enumerate(panels): 

395 # X axis = column; Y axis = row 

396 if kind_r == "num" and kind_c == "num": 

397 i = int(key_r); j = int(key_c) 

398 xg = grids_std_num[j]; yg = grids_std_num[i] 

399 if i == j: 

400 grid = grids_std_num[j] 

401 Xn_1d = np.repeat(base_std[None, :], len(grid), axis=0) 

402 Xn_1d[:, j] = grid 

403 mu_1d, _ = pred_loss(Xn_1d, include_observation_noise=True) 

404 p_1d = pred_success(Xn_1d) 

405 Zmu = 0.5 * (mu_1d[:, None] + mu_1d[None, :]) 

406 Zp = np.minimum(p_1d[:, None], p_1d[None, :]) 

407 x_orig = _denorm_inv(j, grid) 

408 y_orig = x_orig 

409 else: 

410 XX, YY = np.meshgrid(xg, yg) 

411 Xn_grid = np.repeat(base_std[None, :], XX.size, axis=0) 

412 Xn_grid[:, j] = XX.ravel() 

413 Xn_grid[:, i] = YY.ravel() 

414 mu_flat, _ = pred_loss(Xn_grid, include_observation_noise=True) 

415 p_flat = pred_success(Xn_grid) 

416 Zmu = mu_flat.reshape(YY.shape) 

417 Zp = p_flat.reshape(YY.shape) 

418 x_orig = _denorm_inv(j, xg) 

419 y_orig = _denorm_inv(i, yg) 

420 cell_payload[(r, c)] = dict(kind=("num","num"), i=i, j=j, x=x_orig, y=y_orig, Zmu=Zmu, Zp=Zp) 

421 

422 elif kind_r == "cat" and kind_c == "num": 

423 base = str(key_r); j = int(key_c) 

424 labels = list(cat_allowed.get(base, groups[base]["labels"])) 

425 xg = grids_std_num[j] 

426 # build rows per label 

427 Zmu_rows = []; Zp_rows = [] 

428 for lab in labels: 

429 Xn_grid = np.repeat(base_std[None, :], len(xg), axis=0) 

430 Xn_grid[:, j] = xg 

431 _apply_onehot_for_base(Xn_grid, base, lab) 

432 mu_row, _ = pred_loss(Xn_grid, include_observation_noise=True) 

433 p_row = pred_success(Xn_grid) 

434 Zmu_rows.append(mu_row[None, :]) 

435 Zp_rows.append(p_row[None, :]) 

436 Zmu = np.concatenate(Zmu_rows, axis=0) # (n_labels, n_x) 

437 Zp = np.concatenate(Zp_rows, axis=0) 

438 x_orig = _denorm_inv(j, xg) 

439 y_cats = labels # categorical ticks 

440 cell_payload[(r,c)] = dict(kind=("cat","num"), base=base, j=j, x=x_orig, y=y_cats, Zmu=Zmu, Zp=Zp) 

441 

442 elif kind_r == "num" and kind_c == "cat": 

443 i = int(key_r); base = str(key_c) 

444 labels = list(cat_allowed.get(base, groups[base]["labels"])) 

445 yg = grids_std_num[i] 

446 # columns per label 

447 Zmu_cols = []; Zp_cols = [] 

448 for lab in labels: 

449 Xn_grid = np.repeat(base_std[None, :], len(yg), axis=0) 

450 Xn_grid[:, i] = yg 

451 _apply_onehot_for_base(Xn_grid, base, lab) 

452 mu_col, _ = pred_loss(Xn_grid, include_observation_noise=True) 

453 p_col = pred_success(Xn_grid) 

454 Zmu_cols.append(mu_col[:, None]) 

455 Zp_cols.append(p_col[:, None]) 

456 Zmu = np.concatenate(Zmu_cols, axis=1) # (n_y, n_labels) 

457 Zp = np.concatenate(Zp_cols, axis=1) 

458 x_cats = labels 

459 y_orig = _denorm_inv(i, yg) 

460 cell_payload[(r,c)] = dict(kind=("num","cat"), i=i, base=base, x=x_cats, y=y_orig, Zmu=Zmu, Zp=Zp) 

461 

462 else: # kind_r == "cat" and kind_c == "cat" 

463 base_r = str(key_r); base_c = str(key_c) 

464 labels_r = list(cat_allowed.get(base_r, groups[base_r]["labels"])) 

465 labels_c = list(cat_allowed.get(base_c, groups[base_c]["labels"])) 

466 Z = np.zeros((len(labels_r), len(labels_c)), dtype=float) 

467 P = np.zeros_like(Z) 

468 # evaluate each pair 

469 for rr, lab_r in enumerate(labels_r): 

470 for cc, lab_c in enumerate(labels_c): 

471 Xn_grid = base_std[None, :].copy() 

472 _apply_onehot_for_base(Xn_grid, base_r, lab_r) 

473 _apply_onehot_for_base(Xn_grid, base_c, lab_c) 

474 mu_val, _ = pred_loss(Xn_grid, include_observation_noise=True) 

475 p_val = pred_success(Xn_grid) 

476 Z[rr, cc] = float(mu_val[0]) 

477 P[rr, cc] = float(p_val[0]) 

478 cell_payload[(r,c)] = dict(kind=("cat","cat"), x=labels_c, y=labels_r, Zmu=Z, Zp=P) 

479 

480 all_blocks.append(cell_payload[(r,c)]["Zmu"].ravel()) 

481 

482 # --- color transform bounds 

483 def _color_xform(z_raw: np.ndarray) -> tuple[np.ndarray, float]: 

484 if not use_log_scale_for_target: 

485 return z_raw, 0.0 

486 zmin = float(np.nanmin(z_raw)) 

487 shift = 0.0 if zmin > 0 else -zmin + float(log_shift_epsilon) 

488 return np.log10(np.maximum(z_raw + shift, log_shift_epsilon)), shift 

489 

490 z_all = np.concatenate(all_blocks) if all_blocks else np.array([0.0, 1.0]) 

491 z_all_t, global_shift = _color_xform(z_all) 

492 cmin_t = float(np.nanmin(z_all_t)) 

493 cmax_t = float(np.nanmax(z_all_t)) 

494 cs = get_colorscale(colorscale) 

495 

496 def _contour_line_color(level_raw: float) -> str: 

497 zt = np.log10(max(level_raw + global_shift, log_shift_epsilon)) if use_log_scale_for_target else level_raw 

498 t = 0.5 if cmax_t == cmin_t else (zt - cmin_t) / (cmax_t - cmin_t) 

499 rgb = sample_colorscale(cs, [float(np.clip(t, 0.0, 1.0))])[0] 

500 r, g, b = _rgb_string_to_tuple(rgb) 

501 lum = (0.2126*r + 0.7152*g + 0.0722*b)/255.0 

502 grey = int(round((1.0 - lum) * 255)) 

503 return f"rgba({grey},{grey},{grey},0.9)" 

504 

505 # --- render cells 

506 def _is_log_feature(j: int) -> bool: return (transforms[j] == "log10") 

507 

508 for (r, c), PAY in cell_payload.items(): 

509 kind = PAY["kind"]; Zmu_raw = PAY["Zmu"]; Zp = PAY["Zp"] 

510 Z_t, _ = _color_xform(Zmu_raw) 

511 

512 # axes values (numeric arrays or category indices) 

513 if kind == ("num","num"): 

514 x_vals = PAY["x"]; y_vals = PAY["y"] 

515 fig.add_trace(go.Heatmap( 

516 x=x_vals, y=y_vals, z=Z_t, 

517 coloraxis="coloraxis", zsmooth=False, showscale=False, 

518 hovertemplate=(f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

519 f"{feature_names[PAY['i']]}: %{{y:.6g}}" 

520 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"), 

521 customdata=Zmu_raw 

522 ), row=r+1, col=c+1) 

523 

524 # p(success) shading + contours 

525 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)): 

526 mask = np.where(Zp < thr, 1.0, np.nan) 

527 fig.add_trace(go.Heatmap( 

528 x=x_vals, y=y_vals, z=mask, zmin=0, zmax=1, 

529 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]], 

530 showscale=False, hoverinfo="skip" 

531 ), row=r+1, col=c+1) 

532 

533 # contour lines 

534 zmin_r, zmax_r = float(np.nanmin(Zmu_raw)), float(np.nanmax(Zmu_raw)) 

535 levels = np.linspace(zmin_r, zmax_r, max(n_contours, 2)) 

536 for lev in levels: 

537 color = _contour_line_color(lev) 

538 fig.add_trace(go.Contour( 

539 x=x_vals, y=y_vals, z=Zmu_raw, 

540 autocontour=False, 

541 contours=dict(coloring="lines", showlabels=False, start=lev, end=lev, size=1e-9), 

542 line=dict(width=1), 

543 colorscale=[[0, color], [1, color]], 

544 showscale=False, hoverinfo="skip" 

545 ), row=r+1, col=c+1) 

546 

547 # data overlays (success/fail) 

548 def _data_vals_for_feature(j_full: int) -> np.ndarray: 

549 nm = feature_names[j_full] 

550 if nm in df_raw_f.columns: 

551 return df_raw_f[nm].to_numpy(dtype=float) 

552 vals = feature_raw_from_artifact_or_reconstruct(ds, j_full, nm, transforms[j_full]).astype(float) 

553 return vals[row_mask] if row_mask_active else vals 

554 

555 xd = _data_vals_for_feature(PAY["j"]) 

556 yd = _data_vals_for_feature(PAY["i"]) 

557 show_leg = (r == 0 and c == 0) 

558 fig.add_trace(go.Scattergl( 

559 x=xd[success_mask], y=yd[success_mask], mode="markers", 

560 marker=dict(size=4, color="black", line=dict(width=0)), 

561 name="data (success)", legendgroup="data_succ", showlegend=show_leg, 

562 hovertemplate=("trial_id: %{customdata[0]}<br>" 

563 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

564 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>" 

565 f"{tgt_col}: %{{customdata[1]:.4f}}<extra></extra>"), 

566 customdata=np.column_stack([ 

567 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask], 

568 df_raw_f[tgt_col].to_numpy()[success_mask], 

569 ]) 

570 ), row=r+1, col=c+1) 

571 fig.add_trace(go.Scattergl( 

572 x=xd[fail_mask], y=yd[fail_mask], mode="markers", 

573 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)), 

574 name="data (failed)", legendgroup="data_fail", showlegend=show_leg, 

575 hovertemplate=("trial_id: %{customdata}<br>" 

576 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

577 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>" 

578 "status: failed (NaN target)<extra></extra>"), 

579 customdata=df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask] 

580 ), row=r+1, col=c+1) 

581 

582 # overlays (optimal/suggest) on numeric axes only 

583 if optimal and (optimal_df is not None): 

584 if feature_names[PAY["j"]] in optimal_df.columns and feature_names[PAY["i"]] in optimal_df.columns: 

585 ox = np.asarray(optimal_df[feature_names[PAY["j"]]].values, dtype=float) 

586 oy = np.asarray(optimal_df[feature_names[PAY["i"]]].values, dtype=float) 

587 if np.isfinite(ox).all() and np.isfinite(oy).all(): 

588 pmu = float(optimal_df["pred_target_mean"].values[0]) 

589 psd = float(optimal_df["pred_target_sd"].values[0]) 

590 fig.add_trace(go.Scattergl( 

591 x=ox, y=oy, mode="markers", 

592 marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"), 

593 name="optimal", legendgroup="optimal", showlegend=(r == 0 and c == 0), 

594 hovertemplate=(f"predicted: {pmu:.2g} ± {psd:.2g}<br>" 

595 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

596 f"{feature_names[PAY['i']]}: %{{y:.6g}}<extra></extra>") 

597 ), row=r+1, col=c+1) 

598 if suggest and (suggest_df is not None): 

599 have = (feature_names[PAY["j"]] in suggest_df.columns) and (feature_names[PAY["i"]] in suggest_df.columns) 

600 if have: 

601 sx = np.asarray(suggest_df[feature_names[PAY["j"]]].values, dtype=float) 

602 sy = np.asarray(suggest_df[feature_names[PAY["i"]]].values, dtype=float) 

603 keep_s = np.isfinite(sx) & np.isfinite(sy) 

604 if keep_s.any(): 

605 sx, sy = sx[keep_s], sy[keep_s] 

606 mu_s = suggest_df.loc[keep_s, "pred_target_mean"].values if "pred_target_mean" in suggest_df else None 

607 sd_s = suggest_df.loc[keep_s, "pred_target_sd"].values if "pred_target_sd" in suggest_df else None 

608 ps_s = suggest_df.loc[keep_s, "pred_p_success"].values if "pred_p_success" in suggest_df else None 

609 if (mu_s is not None) and (sd_s is not None) and (ps_s is not None): 

610 custom_s = np.column_stack([mu_s, sd_s, ps_s]) 

611 hover_s = ( 

612 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

613 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>" 

614 "pred: %{customdata[0]:.3g} ± %{customdata[1]:.3g}<br>" 

615 "p(success): %{customdata[2]:.2f}<extra>suggested</extra>" 

616 ) 

617 else: 

618 custom_s = None 

619 hover_s = ( 

620 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

621 f"{feature_names[PAY['i']]}: %{{y:.6g}}<extra>suggested</extra>" 

622 ) 

623 fig.add_trace(go.Scattergl( 

624 x=sx, y=sy, mode="markers", 

625 marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"), 

626 name="suggested", legendgroup="suggested", 

627 showlegend=(r == 0 and c == 0), 

628 customdata=custom_s, hovertemplate=hover_s 

629 ), row=r+1, col=c+1) 

630 

631 # axis types/ranges 

632 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="x", centers=x_vals, is_log=_is_log_feature(PAY["j"])) 

633 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="y", centers=y_vals, is_log=_is_log_feature(PAY["i"])) 

634 

635 elif kind == ("cat","num"): 

636 base = PAY["base"]; x_vals = PAY["x"]; labels = PAY["y"] 

637 nlab = len(labels) 

638 # heatmap (categories on Y) 

639 fig.add_trace(go.Heatmap( 

640 x=x_vals, y=np.arange(nlab), z=Z_t, 

641 coloraxis="coloraxis", zsmooth=False, showscale=False, 

642 hovertemplate=(f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

643 f"{base}: %{{text}}" 

644 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"), 

645 text=np.array(labels)[:, None].repeat(len(x_vals), axis=1), 

646 customdata=Zmu_raw 

647 ), row=r+1, col=c+1) 

648 # p(success) shading 

649 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)): 

650 mask = np.where(Zp < thr, 1.0, np.nan) 

651 fig.add_trace(go.Heatmap( 

652 x=x_vals, y=np.arange(nlab), z=mask, zmin=0, zmax=1, 

653 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]], 

654 showscale=False, hoverinfo="skip" 

655 ), row=r+1, col=c+1) 

656 # categorical ticks 

657 fig.update_yaxes(tickmode="array", tickvals=list(range(nlab)), ticktext=labels, row=r+1, col=c+1) 

658 # data overlays: numeric vs categorical with jitter on Y 

659 if base in df_raw_f.columns and feature_names[PAY["j"]] in df_raw_f.columns: 

660 cat_series = df_raw_f[base].astype("string") 

661 cat_to_idx = {lab: i for i, lab in enumerate(labels)} 

662 y_map = cat_series.map(cat_to_idx) 

663 ok = y_map.notna().to_numpy() 

664 y_idx = y_map.to_numpy(dtype=float) 

665 jitter = 0.10 * (np.random.default_rng(0).standard_normal(size=len(y_idx))) 

666 yj = y_idx + jitter 

667 xd = df_raw_f[feature_names[PAY["j"]]].to_numpy(dtype=float) 

668 show_leg = (r == 0 and c == 0) 

669 fig.add_trace(go.Scattergl( 

670 x=xd[success_mask & ok], y=yj[success_mask & ok], mode="markers", 

671 marker=dict(size=4, color="black", line=dict(width=0)), 

672 name="data (success)", legendgroup="data_succ", showlegend=show_leg, 

673 hovertemplate=("trial_id: %{customdata[0]}<br>" 

674 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

675 f"{base}: %{{customdata[1]}}<br>" 

676 f"{tgt_col}: %{{customdata[2]:.4f}}<extra></extra>"), 

677 customdata=np.column_stack([ 

678 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask & ok], 

679 cat_series.to_numpy()[success_mask & ok], 

680 df_raw_f[tgt_col].to_numpy()[success_mask & ok], 

681 ]) 

682 ), row=r+1, col=c+1) 

683 fig.add_trace(go.Scattergl( 

684 x=xd[fail_mask & ok], y=yj[fail_mask & ok], mode="markers", 

685 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)), 

686 name="data (failed)", legendgroup="data_fail", showlegend=show_leg, 

687 hovertemplate=("trial_id: %{customdata[0]}<br>" 

688 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

689 f"{base}: %{{customdata[1]}}<br>" 

690 "status: failed (NaN target)<extra></extra>"), 

691 customdata=np.column_stack([ 

692 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask & ok], 

693 cat_series.to_numpy()[fail_mask & ok], 

694 ]) 

695 ), row=r+1, col=c+1) 

696 # axes: x numeric; y categorical range 

697 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="x", centers=x_vals, is_log=_is_log_feature(PAY["j"])) 

698 fig.update_yaxes(range=[-0.5, nlab - 0.5], row=r+1, col=c+1) 

699 

700 elif kind == ("num","cat"): 

701 base = PAY["base"]; y_vals = PAY["y"]; labels = PAY["x"] 

702 nlab = len(labels) 

703 # heatmap (categories on X) 

704 fig.add_trace(go.Heatmap( 

705 x=np.arange(nlab), y=y_vals, z=Z_t, 

706 coloraxis="coloraxis", zsmooth=False, showscale=False, 

707 hovertemplate=(f"{base}: %{{text}}<br>" 

708 f"{feature_names[PAY['i']]}: %{{y:.6g}}" 

709 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"), 

710 text=np.array(labels)[None, :].repeat(len(y_vals), axis=0), 

711 customdata=Zmu_raw 

712 ), row=r+1, col=c+1) 

713 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)): 

714 mask = np.where(Zp < thr, 1.0, np.nan) 

715 fig.add_trace(go.Heatmap( 

716 x=np.arange(nlab), y=y_vals, z=mask, zmin=0, zmax=1, 

717 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]], 

718 showscale=False, hoverinfo="skip" 

719 ), row=r+1, col=c+1) 

720 fig.update_xaxes(tickmode="array", tickvals=list(range(nlab)), ticktext=labels, row=r+1, col=c+1) 

721 # data overlays with jitter on X 

722 if base in df_raw_f.columns and feature_names[PAY["i"]] in df_raw_f.columns: 

723 cat_series = df_raw_f[base].astype("string") 

724 cat_to_idx = {lab: i for i, lab in enumerate(labels)} 

725 x_map = cat_series.map(cat_to_idx) 

726 ok = x_map.notna().to_numpy() 

727 x_idx = x_map.to_numpy(dtype=float) 

728 jitter = 0.10 * (np.random.default_rng(0).standard_normal(size=len(x_idx))) 

729 xj = x_idx + jitter 

730 yd = df_raw_f[feature_names[PAY["i"]]].to_numpy(dtype=float) 

731 show_leg = (r == 0 and c == 0) 

732 fig.add_trace(go.Scattergl( 

733 x=xj[success_mask & ok], y=yd[success_mask & ok], mode="markers", 

734 marker=dict(size=4, color="black", line=dict(width=0)), 

735 name="data (success)", legendgroup="data_succ", showlegend=show_leg, 

736 hovertemplate=("trial_id: %{customdata[0]}<br>" 

737 f"{base}: %{{customdata[1]}}<br>" 

738 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>" 

739 f"{tgt_col}: %{{customdata[2]:.4f}}<extra></extra>"), 

740 customdata=np.column_stack([ 

741 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask & ok], 

742 cat_series.to_numpy()[success_mask & ok], 

743 df_raw_f[tgt_col].to_numpy()[success_mask & ok], 

744 ]) 

745 ), row=r+1, col=c+1) 

746 fig.add_trace(go.Scattergl( 

747 x=xj[fail_mask & ok], y=yd[fail_mask & ok], mode="markers", 

748 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)), 

749 name="data (failed)", legendgroup="data_fail", showlegend=show_leg, 

750 hovertemplate=("trial_id: %{customdata[0]}<br>" 

751 f"{base}: %{{customdata[1]}}<br>" 

752 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>" 

753 "status: failed (NaN target)<extra></extra>"), 

754 customdata=np.column_stack([ 

755 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask & ok], 

756 cat_series.to_numpy()[fail_mask & ok], 

757 ]) 

758 ), row=r+1, col=c+1) 

759 # axes: x categorical; y numeric 

760 fig.update_xaxes(range=[-0.5, nlab - 0.5], row=r+1, col=c+1) 

761 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="y", centers=y_vals, is_log=_is_log_feature(PAY["i"])) 

762 

763 elif kind == ("cat","cat"): 

764 labels_y = PAY["y"] 

765 labels_x = PAY["x"] 

766 ny, nx = len(labels_y), len(labels_x) 

767 

768 # Build customdata carrying (row_label, col_label) for hovertemplate. 

769 custom = np.dstack(( 

770 np.array(labels_y, dtype=object)[:, None].repeat(nx, axis=1), 

771 np.array(labels_x, dtype=object)[None, :].repeat(ny, axis=0), 

772 )) 

773 

774 # Heatmap over categorical indices 

775 fig.add_trace(go.Heatmap( 

776 x=np.arange(nx), 

777 y=np.arange(ny), 

778 z=Z_t, 

779 coloraxis="coloraxis", 

780 zsmooth=False, 

781 showscale=False, 

782 hovertemplate=( 

783 "row: %{customdata[0]}<br>" 

784 "col: %{customdata[1]}<br>" 

785 "E[target|success]: %{z:.3f}<extra></extra>" 

786 ), 

787 customdata=custom, 

788 ), row=r+1, col=c+1) 

789 

790 # p(success) shading overlays 

791 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)): 

792 mask = np.where(Zp < thr, 1.0, np.nan) 

793 fig.add_trace(go.Heatmap( 

794 x=np.arange(nx), 

795 y=np.arange(ny), 

796 z=mask, 

797 zmin=0, 

798 zmax=1, 

799 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]], 

800 showscale=False, 

801 hoverinfo="skip", 

802 ), row=r+1, col=c+1) 

803 

804 # Categorical tick labels on both axes 

805 fig.update_xaxes( 

806 tickmode="array", 

807 tickvals=list(range(nx)), 

808 ticktext=labels_x, 

809 range=[-0.5, nx - 0.5], 

810 row=r+1, 

811 col=c+1, 

812 ) 

813 fig.update_yaxes( 

814 tickmode="array", 

815 tickvals=list(range(ny)), 

816 ticktext=labels_y, 

817 range=[-0.5, ny - 0.5], 

818 row=r+1, 

819 col=c+1, 

820 ) 

821 

822 # --- outer axis labels 

823 def _panel_title(kind: str, key: object) -> str: 

824 return feature_names[int(key)] if kind == "num" else str(key) 

825 

826 for c, (_, key_c) in enumerate(panels): 

827 fig.update_xaxes(title_text=_panel_title(panels[c][0], key_c), row=k, col=c+1) 

828 for r, (kind_r, key_r) in enumerate(panels): 

829 fig.update_yaxes(title_text=_panel_title(kind_r, key_r), row=r+1, col=1) 

830 

831 # --- title 

832 def _fmt_c(v): 

833 if isinstance(v, slice): 

834 a = f"{v.start:g}" if v.start is not None else "" 

835 b = f"{v.stop:g}" if v.stop is not None else "" 

836 return f"[{a},{b}]" 

837 if isinstance(v, (list, tuple, np.ndarray)): 

838 try: 

839 return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]" 

840 except Exception: 

841 return "[" + ",".join(map(str, v)) + "]" 

842 return str(v) 

843 

844 title_parts = [f"2D partial dependence of expected {tgt_col}"] 

845 

846 # numeric constraints shown 

847 for name, val in kw_num.items(): 

848 title_parts.append(f"{name}={_fmt_c(val)}") 

849 # categorical constraints: fixed shown as base=Label; allowed ranges omitted in title 

850 for base, lab in cat_fixed.items(): 

851 title_parts.append(f"{base}={lab}") 

852 title = " — ".join([title_parts[0], ", ".join(title_parts[1:])]) if len(title_parts) > 1 else title_parts[0] 

853 

854 # --- layout 

855 cell = 250 

856 z_title = "E[target|success]" + (" (log10)" if use_log_scale_for_target else "") 

857 if use_log_scale_for_target and global_shift > 0: 

858 z_title += f" (shift Δ={global_shift:.3g})" 

859 

860 width = width if (width and width > 0) else cell * k 

861 width = max(width, 400) 

862 height = height if (height and height > 0) else cell * k 

863 height = max(height, 400) 

864 

865 fig.update_layout( 

866 template="simple_white", 

867 width=width, 

868 height=height, 

869 title=title, 

870 legend_title_text="", 

871 coloraxis=dict( 

872 colorscale=colorscale, 

873 cmin=cmin_t, cmax=cmax_t, 

874 colorbar=dict( 

875 title=z_title, 

876 thickness=10, # thinner bar 

877 len=0.55, # shorter bar (fraction of plot height) 

878 lenmode="fraction", 

879 x=1.02, y=0.5, # just right of plot, vertically centered 

880 xanchor="left", yanchor="middle", 

881 ), 

882 ), 

883 legend=dict( 

884 orientation="v", 

885 x=1.02, xanchor="left", # to the right of the colorbar 

886 y=1.0, yanchor="top", 

887 bgcolor="rgba(255,255,255,0.85)" 

888 ), 

889 margin=dict(t=90, r=100), # room for title + legend + colorbar 

890 ) 

891 

892 if output: 

893 write_image(fig, output) 

894 if show: 

895 fig.show("browser") 

896 return fig 

897 

898 

899def plot1d( 

900 model: xr.Dataset | Path | str, 

901 output: Path | None = None, 

902 csv_out: Path | None = None, 

903 grid_size: int = 300, 

904 line_color: str = "rgb(31,119,180)", 

905 band_alpha: float = 0.25, 

906 show: bool = False, 

907 use_log_scale_for_target_y: bool = True, # log-y for target 

908 log_y_epsilon: float = 1e-9, 

909 optimal: bool = True, 

910 suggest: int = 0, 

911 width:int|None = None, 

912 height:int|None = None, 

913 seed: int|None = 42, 

914 **kwargs, 

915) -> go.Figure: 

916 """ 

917 Vertical 1D PD panels of E[target|success] vs each *free* feature. 

918 Scalars (fix & hide), slices (restrict sweep & x-range), lists/tuples (discrete grids). 

919 Categorical bases (e.g. language) are plotted as a single categorical subplot 

920 when not fixed; passing --language "Linear A" fixes that base and removes it 

921 from the plotted axes. 

922 """ 

923 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model) 

924 pred_success, pred_loss = _build_predictors(ds) 

925 

926 feature_names = [str(n) for n in ds["feature"].values.tolist()] 

927 transforms = [str(t) for t in ds["feature_transform"].values.tolist()] 

928 X_mean = ds["feature_mean"].values.astype(float) 

929 X_std = ds["feature_std"].values.astype(float) 

930 

931 df_raw = _raw_dataframe_from_dataset(ds) 

932 Xn_train = ds["Xn_train"].values.astype(float) 

933 n_rows, p = Xn_train.shape 

934 

935 # --- one-hot categorical groups --- 

936 groups = opt._onehot_groups(feature_names) # { base: {"labels":[...], "name_by_label":{label:member}, "members":[...]} } 

937 bases = set(groups.keys()) 

938 name_to_idx = {name: j for j, name in enumerate(feature_names)} 

939 

940 # --- canonicalize kwargs: numeric vs categorical (base) --- 

941 idx_map = _canon_key_set(ds) 

942 kw_num_raw: dict[str, object] = {} 

943 kw_cat_raw: dict[str, object] = {} 

944 for k, v in kwargs.items(): 

945 if k in bases: 

946 kw_cat_raw[k] = v 

947 continue 

948 if k in idx_map: 

949 kw_num_raw[idx_map[k]] = v 

950 continue 

951 import re as _re 

952 nk = _re.sub(r"[^a-z0-9]+", "", str(k).lower()) 

953 if nk in idx_map: 

954 kw_num_raw[idx_map[nk]] = v 

955 

956 # --- resolve categorical constraints: fixed (single) vs allowed (multiple) --- 

957 cat_fixed: dict[str, str] = {} 

958 cat_allowed: dict[str, list[str]] = {} 

959 for base, val in kw_cat_raw.items(): 

960 labels = groups[base]["labels"] 

961 if isinstance(val, str): 

962 if val not in labels: 

963 raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}") 

964 cat_fixed[base] = val 

965 elif isinstance(val, (list, tuple, set)): 

966 chosen = [x for x in val if isinstance(x, str) and x in labels] 

967 if not chosen: 

968 raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}") 

969 # multiple -> treat as allowed subset (NOT fixed) 

970 cat_allowed[base] = list(dict.fromkeys(chosen)) 

971 else: 

972 raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.") 

973 

974 # --- filter rows by fixed categoricals (affects medians/percentiles & overlays) --- 

975 row_mask = np.ones(n_rows, dtype=bool) 

976 for base, label in cat_fixed.items(): 

977 if base in df_raw.columns: 

978 row_mask &= (df_raw[base].astype("string") == pd.Series([label]*len(df_raw), dtype="string")).to_numpy() 

979 else: 

980 member_name = groups[base]["name_by_label"][label] 

981 j = name_to_idx[member_name] 

982 raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float) 

983 row_mask &= (raw_j >= 0.5) 

984 

985 # --- helpers to transform original <-> standardized for feature j --- 

986 def _orig_to_std(j: int, x, transforms, mu, sd): 

987 x = np.asarray(x, dtype=float) 

988 if transforms[j] == "log10": 

989 x = np.where(x <= 0, np.nan, x) 

990 x = np.log10(x) 

991 return (x - mu[j]) / sd[j] 

992 

993 # --- numeric constraint split (STANDARDIZED) --- 

994 fixed_scalars: dict[int, float] = {} 

995 range_windows: dict[int, tuple[float, float]] = {} 

996 choice_values: dict[int, np.ndarray] = {} 

997 for name, val in kw_num_raw.items(): 

998 if name not in name_to_idx: 

999 continue 

1000 j = name_to_idx[name] 

1001 if isinstance(val, slice): 

1002 lo = _orig_to_std(j, float(val.start), transforms, X_mean, X_std) 

1003 hi = _orig_to_std(j, float(val.stop), transforms, X_mean, X_std) 

1004 lo, hi = float(min(lo, hi)), float(max(lo, hi)) 

1005 range_windows[j] = (lo, hi) 

1006 elif isinstance(val, (list, tuple, np.ndarray)): 

1007 arr = _orig_to_std(j, np.asarray(val, dtype=float), transforms, X_mean, X_std) 

1008 choice_values[j] = np.asarray(arr, dtype=float) 

1009 else: 

1010 fixed_scalars[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std)) 

1011 

1012 # --- apply categorical fixed as standardized scalar fixes on each one-hot member --- 

1013 for base, label in cat_fixed.items(): 

1014 labels = groups[base]["labels"] 

1015 for lab in labels: 

1016 member_name = groups[base]["name_by_label"][lab] 

1017 j = name_to_idx[member_name] 

1018 raw_val = 1.0 if lab == label else 0.0 

1019 fixed_scalars[j] = float(_orig_to_std(j, raw_val, transforms, X_mean, X_std)) 

1020 

1021 # --- enforce row-level filters for categorical allowed sets and numeric constraints --- 

1022 for base, allowed in cat_allowed.items(): 

1023 if base in df_raw.columns: 

1024 series = df_raw[base].astype("string").fillna("<NA>") 

1025 allowed_set = {str(x) for x in allowed} 

1026 allowed_mask = series.isin(allowed_set).fillna(False).to_numpy() 

1027 row_mask &= allowed_mask 

1028 else: 

1029 allowed_masks = [] 

1030 for label in allowed: 

1031 member_name = groups[base]["name_by_label"].get(label) 

1032 if member_name is None: 

1033 continue 

1034 j = name_to_idx[member_name] 

1035 raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float) 

1036 allowed_masks.append(raw_j >= 0.5) 

1037 if allowed_masks: 

1038 combined = np.logical_or.reduce(allowed_masks) 

1039 row_mask &= combined 

1040 else: 

1041 row_mask &= False 

1042 

1043 for name, val in kw_num_raw.items(): 

1044 if name not in name_to_idx: 

1045 continue 

1046 j = name_to_idx[name] 

1047 if name in df_raw.columns: 

1048 raw_vals = pd.to_numeric(df_raw[name], errors="coerce").to_numpy(dtype=float) 

1049 else: 

1050 raw_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float) 

1051 

1052 mask = np.ones_like(row_mask, dtype=bool) 

1053 if isinstance(val, slice): 

1054 lo_raw = -np.inf if val.start is None else float(val.start) 

1055 hi_raw = np.inf if val.stop is None else float(val.stop) 

1056 if hi_raw < lo_raw: 

1057 lo_raw, hi_raw = hi_raw, lo_raw 

1058 mask &= (raw_vals >= lo_raw) & (raw_vals <= hi_raw) 

1059 elif isinstance(val, (list, tuple, set, np.ndarray)): 

1060 arr = np.asarray(list(val) if not isinstance(val, np.ndarray) else val, dtype=float) 

1061 arr = arr[np.isfinite(arr)] 

1062 if arr.size == 0: 

1063 mask &= False 

1064 else: 

1065 mask &= np.any(np.isclose(raw_vals[:, None], arr[None, :], rtol=1e-6, atol=1e-9), axis=1) 

1066 else: 

1067 target = float(val) 

1068 mask &= np.isclose(raw_vals, target, rtol=1e-6, atol=1e-9) 

1069 

1070 row_mask &= mask 

1071 

1072 if not np.any(row_mask): 

1073 raise ValueError("No experiments match the provided constraints; cannot plot data points.") 

1074 

1075 df_raw_f = df_raw.loc[row_mask].reset_index(drop=True) 

1076 Xn_train_f = Xn_train[row_mask, :] 

1077 

1078 # --- overlays conditioned on the same kwargs (numeric + categorical) --- 

1079 optimal_df = opt.optimal(model, count=1, seed=seed, **kwargs) if optimal else None 

1080 suggest_df = opt.suggest(model, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None 

1081 

1082 # --- base standardized point (median over filtered rows), then apply scalar fixes --- 

1083 base_std = np.median(Xn_train_f, axis=0) 

1084 for j, vstd in fixed_scalars.items(): 

1085 base_std[j] = vstd 

1086 

1087 # --- plotted panels: numeric free features + categorical bases not fixed --- 

1088 onehot_members = set() 

1089 for base, g in groups.items(): 

1090 onehot_members.update(g["members"]) 

1091 free_numeric_idx = [j for j in range(p) if (j not in fixed_scalars) and (feature_names[j] not in onehot_members)] 

1092 free_cat_bases = [b for b in bases if b not in cat_fixed] # optional: filtered by cat_allowed later 

1093 

1094 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases] 

1095 if not panels: 

1096 raise ValueError("All features are fixed (or categorical only with single category chosen); nothing to plot.") 

1097 

1098 # --- empirical 1–99% from filtered rows for numeric bounds --- 

1099 p01p99 = [np.percentile(Xn_train_f[:, j], [1, 99]) for j in range(p)] 

1100 def _grid_1d(j: int, n: int) -> np.ndarray: 

1101 p01, p99 = p01p99[j] 

1102 if j in choice_values: 

1103 vals = np.asarray(choice_values[j], dtype=float) 

1104 vals = vals[(vals >= p01) & (vals <= p99)] 

1105 return np.unique(np.sort(vals)) if vals.size else np.array([np.median(Xn_train_f[:, j])], dtype=float) 

1106 lo, hi = p01, p99 

1107 if j in range_windows: 

1108 rlo, rhi = range_windows[j] 

1109 lo, hi = max(lo, rlo), min(hi, rhi) 

1110 if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo: 

1111 lo, hi = p01, max(p01 + 1e-9, p99) 

1112 return np.linspace(lo, hi, n) 

1113 

1114 # --- one-hot member names (robust) --- 

1115 onehot_member_names: set[str] = set() 

1116 for base, g in groups.items(): 

1117 # names recorded by the detector 

1118 onehot_member_names.update(g["members"]) 

1119 # fallback pattern match in case detector missed anything 

1120 prefix = f"{base}=" 

1121 onehot_member_names.update([nm for nm in feature_names if nm.startswith(prefix)]) 

1122 

1123 # --- build panel list: numeric free features + categorical bases (not fixed) --- 

1124 free_numeric_idx = [ 

1125 j for j, nm in enumerate(feature_names) 

1126 if (j not in fixed_scalars) and (nm not in onehot_member_names) 

1127 ] 

1128 free_cat_bases = [b for b in bases if b not in cat_fixed] 

1129 

1130 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases] 

1131 if not panels: 

1132 raise ValueError("All features are fixed (or only single-category categoricals remain); nothing to plot.") 

1133 

1134 # sanity: ensure we didn't accidentally keep any one-hot member columns 

1135 assert all( 

1136 (feature_names[key] not in onehot_member_names) if kind == "num" else True 

1137 for kind, key in panels 

1138 ), "internal: one-hot member leaked into numeric panels" 

1139 

1140 # --- figure scaffold with clean titles --- 

1141 def _panel_title(kind: str, key: object) -> str: 

1142 return feature_names[int(key)] if kind == "num" else str(key) 

1143 

1144 fig = make_subplots( 

1145 rows=len(panels), 

1146 cols=1, 

1147 shared_xaxes=False, 

1148 ) 

1149 

1150 # --- masks/data from filtered rows --- 

1151 tgt_col = str(ds.attrs["target"]) 

1152 success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy() 

1153 fail_mask = ~success_mask 

1154 losses_success = df_raw_f.loc[success_mask, tgt_col].to_numpy().astype(float) 

1155 trial_ids_success = df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask] 

1156 trial_ids_fail = df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask] 

1157 band_fill_rgba = _rgb_to_rgba(line_color, band_alpha) 

1158 

1159 tidy_rows: list[dict] = [] 

1160 

1161 row_pos = 0 

1162 for kind, key in panels: 

1163 row_pos += 1 

1164 

1165 if kind == "num": 

1166 j = key 

1167 if feature_names[j] in df_raw.columns: 

1168 series_full = pd.to_numeric(df_raw[feature_names[j]], errors="coerce") 

1169 x_full_raw = series_full.to_numpy(dtype=float) 

1170 else: 

1171 x_full_raw = feature_raw_from_artifact_or_reconstruct( 

1172 ds, j, feature_names[j], transforms[j] 

1173 ).astype(float) 

1174 x_data_all = x_full_raw[row_mask] 

1175 

1176 finite_raw = x_full_raw[np.isfinite(x_full_raw)] 

1177 if transforms[j] == "log10": 

1178 finite_raw = finite_raw[finite_raw > 0] 

1179 

1180 grid = _grid_1d(j, grid_size) 

1181 if (j not in range_windows) and (j not in choice_values) and finite_raw.size: 

1182 finite_std = _orig_to_std(j, finite_raw, transforms, X_mean, X_std) 

1183 grid_min = float(np.nanmin(np.concatenate([grid, finite_std]))) 

1184 grid_max = float(np.nanmax(np.concatenate([grid, finite_std]))) 

1185 if grid_max > grid_min: 

1186 grid = np.linspace(grid_min, grid_max, grid_size) 

1187 else: 

1188 grid = np.array([grid_min], dtype=float) 

1189 

1190 Xn_grid = np.repeat(base_std[None, :], len(grid), axis=0) 

1191 Xn_grid[:, j] = grid 

1192 

1193 # # --- DEBUG: confirm the feature is actually changing in standardized space --- 

1194 # print(f"[{feature_names[j]}] std grid head: {grid[:6]}") 

1195 # print(f"[{feature_names[j]}] std grid ptp (range): {np.ptp(grid)}") 

1196 # print(f"[{feature_names[j]}] Xn_grid[:2, j]: {Xn_grid[:2, j]}") 

1197 # print(f"[{feature_names[j]}] Xn 1–99%: {p01p99[j]}") 

1198 

1199 p_grid = pred_success(Xn_grid) 

1200 mu_grid, sd_grid = pred_loss(Xn_grid, include_observation_noise=True) 

1201 # print(feature_names[j], "mu range:", float(np.ptp(mu_grid))) 

1202 

1203 x_internal = grid * X_std[j] + X_mean[j] 

1204 x_display = _inverse_transform(transforms[j], x_internal) 

1205 

1206 # print(f"[{feature_names[j]}] orig head: {x_display[:6]}") 

1207 # print(f"[{feature_names[j]}] orig ptp (range): {np.ptp(x_display)}") 

1208 

1209 if use_log_scale_for_target_y: 

1210 mu_plot = np.maximum(mu_grid, log_y_epsilon) 

1211 lo_plot = np.maximum(mu_grid - 2.0 * sd_grid, log_y_epsilon) 

1212 hi_plot = np.maximum(mu_grid + 2.0 * sd_grid, log_y_epsilon) 

1213 losses_s_plot = np.maximum(losses_success, log_y_epsilon) if losses_success.size else losses_success 

1214 else: 

1215 mu_plot = mu_grid 

1216 lo_plot = mu_grid - 2.0 * sd_grid 

1217 hi_plot = mu_grid + 2.0 * sd_grid 

1218 losses_s_plot = losses_success 

1219 

1220 y_arrays = [lo_plot, hi_plot] + ([losses_s_plot] if losses_s_plot.size else []) 

1221 y_low = float(np.nanmin([np.nanmin(a) for a in y_arrays])) 

1222 y_high = float(np.nanmax([np.nanmax(a) for a in y_arrays])) 

1223 pad = 0.05 * (y_high - y_low + 1e-12) 

1224 y0_plot = (y_low - pad) if not use_log_scale_for_target_y else max(y_low / 1.5, log_y_epsilon) 

1225 y1_tmp = (y_high + pad) if not use_log_scale_for_target_y else y_high * 1.2 

1226 y_failed_band = y1_tmp + (y_high - y_low + 1e-12) * (0.08 if not use_log_scale_for_target_y else 0.3) 

1227 if use_log_scale_for_target_y and y_failed_band <= log_y_epsilon: 

1228 y_failed_band = max(10.0 * log_y_epsilon, y_high * 2.0) 

1229 y1_plot = y_failed_band + (0.02 if not use_log_scale_for_target_y else 0.05) * (y_high - y_low + 1e-12) 

1230 

1231 _add_low_success_shading_1d(fig, row_pos, x_display, p_grid, y0_plot, y1_plot) 

1232 

1233 show_legend = (row_pos == 1) 

1234 fig.add_trace(go.Scatter(x=x_display, y=lo_plot, mode="lines", 

1235 line=dict(width=0, color=line_color), 

1236 name="±2σ", legendgroup="band", showlegend=False, hoverinfo="skip"), 

1237 row=row_pos, col=1) 

1238 fig.add_trace(go.Scatter(x=x_display, y=hi_plot, mode="lines", fill="tonexty", 

1239 line=dict(width=0, color=line_color), fillcolor=band_fill_rgba, 

1240 name="±2σ", legendgroup="band", showlegend=show_legend, 

1241 hovertemplate="E[target|success]: %{y:.3f}<extra>±2σ</extra>"), 

1242 row=row_pos, col=1) 

1243 fig.add_trace(go.Scatter(x=x_display, y=mu_plot, mode="lines", 

1244 line=dict(width=2, color=line_color), 

1245 name="E[target|success]", legendgroup="mean", showlegend=show_legend, 

1246 hovertemplate=f"{feature_names[j]}: %{{x:.6g}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"), 

1247 row=row_pos, col=1) 

1248 

1249 # experimental points 

1250 x_succ = x_data_all[success_mask] 

1251 if x_succ.size: 

1252 fig.add_trace(go.Scattergl( 

1253 x=x_succ, y=losses_s_plot, mode="markers", 

1254 marker=dict(size=5, color="black", line=dict(width=0)), 

1255 name="data (success)", legendgroup="data_s", showlegend=show_legend, 

1256 hovertemplate=("trial_id: %{customdata}<br>" 

1257 f"{feature_names[j]}: %{{x:.6g}}<br>" 

1258 f"{tgt_col}: %{{y:.4f}}<extra></extra>"), 

1259 customdata=trial_ids_success 

1260 ), row=row_pos, col=1) 

1261 

1262 x_fail = x_data_all[fail_mask] 

1263 if x_fail.size: 

1264 y_fail_plot = np.full_like(x_fail, y_failed_band, dtype=float) 

1265 fig.add_trace(go.Scattergl( 

1266 x=x_fail, y=y_fail_plot, mode="markers", 

1267 marker=dict(size=6, color="red", line=dict(color="black", width=0.8)), 

1268 name="data (failed)", legendgroup="data_f", showlegend=show_legend, 

1269 hovertemplate=("trial_id: %{customdata}<br>" 

1270 f"{feature_names[j]}: %{{x:.6g}}<br>" 

1271 "status: failed (NaN target)<extra></extra>"), 

1272 customdata=trial_ids_fail 

1273 ), row=row_pos, col=1) 

1274 

1275 # overlays 

1276 if optimal_df is not None and feature_names[j] in optimal_df.columns: 

1277 x_opt = optimal_df[feature_names[j]].values 

1278 y_opt = optimal_df["pred_target_mean"].values 

1279 y_sd = optimal_df["pred_target_sd"].values 

1280 fig.add_trace(go.Scattergl( 

1281 x=x_opt, y=y_opt, mode="markers", 

1282 marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"), 

1283 name="optimal", legendgroup="optimal", showlegend=show_legend, 

1284 hovertemplate=(f"predicted: %{{y:.3g}} ± {y_sd[0]:.3g}<br>" 

1285 f"{feature_names[j]}: %{{x:.6g}}<extra></extra>") 

1286 ), row=row_pos, col=1) 

1287 

1288 if suggest_df is not None and feature_names[j] in suggest_df.columns: 

1289 x_sug = suggest_df[feature_names[j]].values 

1290 y_sug = suggest_df["pred_target_mean"].values 

1291 y_sd = suggest_df["pred_target_sd"].values 

1292 fig.add_trace(go.Scattergl( 

1293 x=x_sug, y=y_sug, mode="markers", 

1294 marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"), 

1295 name="suggested", legendgroup="suggested", showlegend=show_legend, 

1296 hovertemplate=(f"predicted: %{{y:.3g}} ± {{y_sd:.3g}}<br>" 

1297 f"{feature_names[j]}: %{{x:.6g}}<extra></extra>") 

1298 ), row=row_pos, col=1) 

1299 

1300 # axes 

1301 _maybe_log_axis(fig, row_pos, 1, feature_names[j], axis="x", transforms=transforms, j=j) 

1302 fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1) 

1303 _set_yaxis_range(fig, row=row_pos, col=1, 

1304 y0=y0_plot, y1=y1_plot, 

1305 log=use_log_scale_for_target_y, eps=log_y_epsilon) 

1306 

1307 # restrict x-range if constrained 

1308 is_log_x = (transforms[j] == "log10") 

1309 def _std_to_orig(val_std: float) -> float: 

1310 vi = val_std * X_std[j] + X_mean[j] 

1311 return float(_inverse_transform(transforms[j], np.array([vi]))[0]) 

1312 

1313 x_min_override = x_max_override = None 

1314 if j in range_windows: 

1315 lo_std, hi_std = range_windows[j] 

1316 x_min_override = min(_std_to_orig(lo_std), _std_to_orig(hi_std)) 

1317 x_max_override = max(_std_to_orig(lo_std), _std_to_orig(hi_std)) 

1318 elif j in choice_values: 

1319 ints = choice_values[j] * X_std[j] + X_mean[j] 

1320 origs = _inverse_transform(transforms[j], ints) 

1321 x_min_override = float(np.min(origs)) 

1322 x_max_override = float(np.max(origs)) 

1323 else: 

1324 finite = finite_raw 

1325 if finite.size: 

1326 x_min_override = float(np.min(finite)) 

1327 x_max_override = float(np.max(finite)) 

1328 

1329 if (x_min_override is not None) and (x_max_override is not None): 

1330 if is_log_x: 

1331 x0 = max(x_min_override, 1e-12) 

1332 x1 = max(x_max_override, x0 * (1 + 1e-9)) 

1333 pad = (x1 / x0) ** 0.03 

1334 fig.update_xaxes(type="log", 

1335 range=[np.log10(x0 / pad), np.log10(x1 * pad)], 

1336 row=row_pos, col=1) 

1337 else: 

1338 span = (x_max_override - x_min_override) or 1.0 

1339 pad = 0.02 * span 

1340 fig.update_xaxes(range=[x_min_override - pad, x_max_override + pad], 

1341 row=row_pos, col=1) 

1342 

1343 fig.update_xaxes(title_text=feature_names[j], row=row_pos, col=1) 

1344 

1345 # tidy rows 

1346 for xd, xi, mu_i, sd_i, p_i in zip(x_display, x_internal, mu_grid, sd_grid, p_grid): 

1347 tidy_rows.append({ 

1348 "feature": feature_names[j], 

1349 "x_display": float(xd), 

1350 "x_internal": float(xi), 

1351 "target_conditional_mean": float(mu_i), 

1352 "target_conditional_sd": float(sd_i), 

1353 "success_probability": float(p_i), 

1354 }) 

1355 

1356 else: 

1357 base = key # categorical base 

1358 labels_all = groups[base]["labels"] 

1359 labels = cat_allowed.get(base, labels_all) 

1360 

1361 # Build standardized design for each label at the base point 

1362 Xn_grid = np.repeat(base_std[None, :], len(labels), axis=0) 

1363 for r, lab in enumerate(labels): 

1364 for lab2 in labels_all: 

1365 member_name = groups[base]["name_by_label"][lab2] 

1366 j2 = name_to_idx[member_name] 

1367 raw_val = 1.0 if (lab2 == lab) else 0.0 

1368 # standardized set: 

1369 Xi = (raw_val - X_mean[j2]) / X_std[j2] 

1370 Xn_grid[r, j2] = Xi 

1371 

1372 p_vec = pred_success(Xn_grid) 

1373 mu_vec, sd_vec = pred_loss(Xn_grid, include_observation_noise=True) 

1374 print(feature_names[j], "mu range:", float(np.ptp(mu_grid))) 

1375 

1376 # y transform 

1377 if use_log_scale_for_target_y: 

1378 mu_plot = np.maximum(mu_vec, log_y_epsilon) 

1379 lo_plot = np.maximum(mu_vec - 2.0 * sd_vec, log_y_epsilon) 

1380 hi_plot = np.maximum(mu_vec + 2.0 * sd_vec, log_y_epsilon) 

1381 losses_s_plot = np.maximum(df_raw_f.loc[success_mask, tgt_col].to_numpy().astype(float), log_y_epsilon) if success_mask.any() else np.array([]) 

1382 else: 

1383 mu_plot = mu_vec 

1384 lo_plot = mu_vec - 2.0 * sd_vec 

1385 hi_plot = mu_vec + 2.0 * sd_vec 

1386 losses_s_plot = df_raw_f.loc[success_mask, tgt_col].to_numpy().astype(float) if success_mask.any() else np.array([]) 

1387 

1388 # y-range 

1389 y_arrays = [lo_plot, hi_plot] + ([losses_s_plot] if losses_s_plot.size else []) 

1390 y_low = float(np.nanmin([np.nanmin(a) for a in y_arrays])) if y_arrays else 0.0 

1391 y_high = float(np.nanmax([np.nanmax(a) for a in y_arrays])) if y_arrays else 1.0 

1392 pad = 0.05 * (y_high - y_low + 1e-12) 

1393 y0_plot = (y_low - pad) if not use_log_scale_for_target_y else max(y_low / 1.5, log_y_epsilon) 

1394 y1_tmp = (y_high + pad) if not use_log_scale_for_target_y else y_high * 1.2 

1395 y_failed_band = y1_tmp + (y_high - y_low + 1e-12) * (0.08 if not use_log_scale_for_target_y else 0.3) 

1396 if use_log_scale_for_target_y and y_failed_band <= log_y_epsilon: 

1397 y_failed_band = max(10.0 * log_y_epsilon, y_high * 2.0) 

1398 y1_plot = y_failed_band + (0.02 if not use_log_scale_for_target_y else 0.05) * (y_high - y_low + 1e-12) 

1399 

1400 # x positions are 0..K-1 with tick labels = category names 

1401 x_pos = np.arange(len(labels), dtype=float) 

1402 

1403 # shading per-category threshold regions using shapes 

1404 def _shade_for_thresh(thr: float, alpha: float): 

1405 for k_i, p_i in enumerate(p_vec): 

1406 if p_i < thr: 

1407 fig.add_shape( 

1408 type="rect", 

1409 xref=f"x{'' if row_pos==1 else row_pos}", 

1410 yref=f"y{'' if row_pos==1 else row_pos}", 

1411 x0=k_i - 0.5, x1=k_i + 0.5, 

1412 y0=y0_plot, y1=y1_plot, 

1413 line=dict(width=0), 

1414 fillcolor=f"rgba(128,128,128,{alpha})", 

1415 layer="below", 

1416 row=row_pos, col=1 

1417 ) 

1418 _shade_for_thresh(0.8, 0.40) 

1419 _shade_for_thresh(0.5, 0.25) 

1420 

1421 show_legend = (row_pos == 1) 

1422 

1423 # mean with error bars (±2σ) 

1424 fig.add_trace(go.Scatter( 

1425 x=x_pos, y=mu_plot, mode="lines+markers", 

1426 line=dict(width=2, color=line_color), 

1427 marker=dict(size=7, color=line_color), 

1428 error_y=dict(type="data", array=(hi_plot - mu_plot), arrayminus=(mu_plot - lo_plot), visible=True), 

1429 name="E[target|success]", legendgroup="mean", showlegend=show_legend, 

1430 hovertemplate=(f"{base}: %{{text}}<br>E[target|success]: %{{y:.3f}}" 

1431 "<br>±2σ shown as error bar<extra></extra>"), 

1432 text=labels 

1433 ), row=row_pos, col=1) 

1434 

1435 # experimental points: map each row's label to index 

1436 if base in df_raw_f.columns: 

1437 lab_series = df_raw_f[base].astype("string") 

1438 else: 

1439 # reconstruct from one-hot members 

1440 member_cols = [groups[base]["name_by_label"][lab] for lab in labels_all] 

1441 idx_max = df_raw_f[member_cols].to_numpy().argmax(axis=1) 

1442 lab_series = pd.Series([labels_all[i] for i in idx_max], dtype="string") 

1443 

1444 label_to_idx = {lab: i for i, lab in enumerate(labels)} 

1445 x_idx_all = lab_series.map(lambda s: label_to_idx.get(str(s), np.nan)).to_numpy(dtype=float) 

1446 x_idx_succ = x_idx_all[success_mask] 

1447 x_idx_fail = x_idx_all[fail_mask] 

1448 

1449 # jitter for visibility 

1450 rng = np.random.default_rng(0) 

1451 jitter = lambda n: (rng.random(n) - 0.5) * 0.15 

1452 

1453 if x_idx_succ.size: 

1454 fig.add_trace(go.Scattergl( 

1455 x=x_idx_succ + jitter(x_idx_succ.size), 

1456 y=losses_s_plot, 

1457 mode="markers", 

1458 marker=dict(size=5, color="black", line=dict(width=0)), 

1459 name="data (success)", legendgroup="data_s", showlegend=show_legend, 

1460 hovertemplate=("trial_id: %{customdata}<br>" 

1461 f"{base}: %{{text}}<br>" 

1462 f"{tgt_col}: %{{y:.4f}}<extra></extra>"), 

1463 text=[labels[int(i)] if np.isfinite(i) and int(i) < len(labels) else "?" for i in x_idx_succ], 

1464 customdata=trial_ids_success 

1465 ), row=row_pos, col=1) 

1466 

1467 if x_idx_fail.size: 

1468 y_fail_plot = np.full_like(x_idx_fail, y_failed_band, dtype=float) 

1469 fig.add_trace(go.Scattergl( 

1470 x=x_idx_fail + jitter(x_idx_fail.size), y=y_fail_plot, mode="markers", 

1471 marker=dict(size=6, color="red", line=dict(color="black", width=0.8)), 

1472 name="data (failed)", legendgroup="data_f", showlegend=show_legend, 

1473 hovertemplate=("trial_id: %{customdata}<br>" 

1474 f"{base}: %{{text}}<br>" 

1475 "status: failed (NaN target)<extra></extra>"), 

1476 text=[labels[int(i)] if np.isfinite(i) and int(i) < len(labels) else "?" for i in x_idx_fail], 

1477 customdata=trial_ids_fail 

1478 ), row=row_pos, col=1) 

1479 

1480 # overlays for categorical base: map label to x index 

1481 if optimal_df is not None and (base in optimal_df.columns): 

1482 lab_opt = str(optimal_df[base].values[0]) 

1483 if lab_opt in label_to_idx: 

1484 x_opt = [float(label_to_idx[lab_opt])] 

1485 y_opt = optimal_df["pred_target_mean"].values 

1486 y_sd = optimal_df["pred_target_sd"].values 

1487 fig.add_trace(go.Scattergl( 

1488 x=x_opt, y=y_opt, mode="markers", 

1489 marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"), 

1490 name="optimal", legendgroup="optimal", showlegend=show_legend, 

1491 hovertemplate=(f"predicted: %{{y:.3g}} ± {y_sd[0]:.3g}<br>" 

1492 f"{base}: {lab_opt}<extra></extra>") 

1493 ), row=row_pos, col=1) 

1494 

1495 if suggest_df is not None and (base in suggest_df.columns): 

1496 labs_sug = suggest_df[base].astype(str).tolist() 

1497 xs = [label_to_idx[l] for l in labs_sug if l in label_to_idx] 

1498 if xs: 

1499 keep_mask = [l in label_to_idx for l in labs_sug] 

1500 y_sug = suggest_df.loc[keep_mask, "pred_target_mean"].values 

1501 fig.add_trace(go.Scattergl( 

1502 x=np.array(xs, dtype=float), y=y_sug, mode="markers", 

1503 marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"), 

1504 name="suggested", legendgroup="suggested", showlegend=show_legend, 

1505 hovertemplate=(f"{base}: %{{text}}<br>" 

1506 "predicted: %{{y:.3g}}<extra>suggested</extra>"), 

1507 text=[labels[int(i)] for i in xs] 

1508 ), row=row_pos, col=1) 

1509 

1510 # axes: categorical ticks 

1511 fig.update_xaxes( 

1512 tickmode="array", 

1513 tickvals=x_pos.tolist(), 

1514 ticktext=labels, 

1515 row=row_pos, col=1 

1516 ) 

1517 fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1) 

1518 _set_yaxis_range(fig, row=row_pos, col=1, 

1519 y0=y0_plot, y1=y1_plot, 

1520 log=use_log_scale_for_target_y, eps=log_y_epsilon) 

1521 fig.update_xaxes(title_text=base, row=row_pos, col=1) 

1522 

1523 # tidy rows 

1524 for lab, mu_i, sd_i, p_i in zip(labels, mu_vec, sd_vec, p_vec): 

1525 tidy_rows.append({ 

1526 "feature": base, 

1527 "x_display": str(lab), 

1528 "x_internal": float("nan"), 

1529 "target_conditional_mean": float(mu_i), 

1530 "target_conditional_sd": float(sd_i), 

1531 "success_probability": float(p_i), 

1532 }) 

1533 

1534 # title w/ constraints summary 

1535 def _fmt_c(v): 

1536 if isinstance(v, slice): 

1537 a = "" if v.start is None else f"{v.start:g}" 

1538 b = "" if v.stop is None else f"{v.stop:g}" 

1539 return f"[{a},{b}]" 

1540 if isinstance(v, (list, tuple, np.ndarray)): 

1541 try: 

1542 return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]" 

1543 except Exception: 

1544 return "[" + ",".join(map(str, v)) + "]" 

1545 try: 

1546 return f"{float(v):g}" 

1547 except Exception: 

1548 return str(v) 

1549 

1550 parts = [f"1D partial dependence of expected {tgt_col}"] 

1551 if kw_num_raw: 

1552 parts.append(", ".join(f"{k}={_fmt_c(v)}" for k, v in kw_num_raw.items())) 

1553 if cat_fixed: 

1554 parts.append(", ".join(f"{b}={lab}" for b, lab in cat_fixed.items())) 

1555 if cat_allowed: 

1556 parts.append(", ".join(f"{b}∈{{{', '.join(v)}}}" for b, v in cat_allowed.items())) 

1557 title = " — ".join(parts) if len(parts) > 1 else parts[0] 

1558 

1559 width = width if (width and width > 0) else 1200 

1560 height = height if (height and height > 0) else 1200 

1561 

1562 fig.update_layout( 

1563 height=height, 

1564 width=width, 

1565 template="simple_white", 

1566 title=title, 

1567 legend_title_text="" 

1568 ) 

1569 

1570 if output: 

1571 write_image(fig, output) 

1572 if csv_out: 

1573 csv_out = Path(csv_out) 

1574 csv_out.parent.mkdir(parents=True, exist_ok=True) 

1575 pd.DataFrame(tidy_rows).to_csv(str(csv_out), index=False) 

1576 if show: 

1577 fig.show("browser") 

1578 

1579 return fig 

1580 

1581 

1582# ============================================================================= 

1583# Helpers: dataset → predictors & featurization 

1584# ============================================================================= 

def _build_predictors(ds: xr.Dataset):
    """Rebuild the two fast GP predictors stored in a fitted artifact.

    The training matrices and MAP hyperparameters are read from ``ds``. Both
    kernel matrices are Cholesky-factorized once. Two closures are returned
    that reuse those factorizations:

    - ``predict_success_probability(Xn)`` returns the success-head mean
      ``beta0 + K(Xn, X) @ alpha``, clipped to [0, 1].
    - ``predict_conditional_target(Xn, include_observation_noise=True)``
      returns ``(mean, sd)`` of the target head. The mean has
      ``conditional_loss_mean`` added back. The sd comes from the diagonal
      predictive variance, floored at 1e-12. When
      ``include_observation_noise`` is true, ``map_loss_sigma**2`` is added
      to the variance.
    """
    def _array(name: str) -> np.ndarray:
        return ds[name].values.astype(float)

    def _scalar(name: str) -> float:
        return float(ds[name].values)

    # Training data
    X_all = _array("Xn_train")              # (N, p)
    y_succ = _array("y_success")            # (N,)
    X_ok = _array("Xn_success_only")        # (Ns, p)
    y_centered = _array("y_loss_centered")

    # conditional_loss_mean may be stored as a data variable or as a dataset attribute.
    if "conditional_loss_mean" in ds:
        target_offset = _scalar("conditional_loss_mean")
    else:
        target_offset = float(ds.attrs.get("conditional_loss_mean", 0.0))

    # MAP hyperparameters of the success head
    ell_s = _array("map_success_ell")       # (p,)
    eta_s = _scalar("map_success_eta")
    sigma_s = _scalar("map_success_sigma")
    beta0_s = _scalar("map_success_beta0")

    # MAP hyperparameters of the target (loss | success) head
    ell_l = _array("map_loss_ell")          # (p,)
    eta_l = _scalar("map_loss_eta")
    sigma_l = _scalar("map_loss_sigma")
    mean_c = _scalar("map_loss_mean_const")

    def _factorize(X, ell, eta, sigma, resid):
        """Cholesky-factor K(X, X) + sigma^2 I (jittered) and solve for the weights."""
        gram = kernel_m52_ard(X, X, ell, eta) + (sigma ** 2) * np.eye(X.shape[0])
        chol = np.linalg.cholesky(add_jitter(gram))
        return chol, solve_chol(chol, resid)

    _, alpha_s = _factorize(X_all, ell_s, eta_s, sigma_s, y_succ - beta0_s)
    L_l, alpha_l = _factorize(X_ok, ell_l, eta_l, sigma_l, y_centered - mean_c)

    def predict_success_probability(Xn: np.ndarray) -> np.ndarray:
        cross = kernel_m52_ard(Xn, X_all, ell_s, eta_s)
        return np.clip(beta0_s + cross @ alpha_s, 0.0, 1.0)

    def predict_conditional_target(
        Xn: np.ndarray,
        include_observation_noise: bool = True
    ):
        cross = kernel_m52_ard(Xn, X_ok, ell_l, eta_l)
        mean = (mean_c + cross @ alpha_l) + target_offset

        # Diagonal of the latent predictive covariance, floored for numerical safety.
        w = solve_lower(L_l, cross.T)  # (Ns, Nt)
        latent_var = np.maximum(kernel_diag_m52(Xn, ell_l, eta_l) - np.sum(w * w, axis=0), 1e-12)
        total_var = latent_var + sigma_l ** 2 if include_observation_noise else latent_var
        return mean, np.sqrt(total_var)

    return predict_success_probability, predict_conditional_target

1645 

1646 

def _raw_dataframe_from_dataset(ds: xr.Dataset) -> pd.DataFrame:
    """Collect the artifact's per-row data variables into a DataFrame for plotting.

    Only variables whose sole dimension is ``row`` (with the dataset's row
    length) become columns, in ``ds.data_vars`` order. Multi-dimensional
    variables are skipped. If no ``trial_id`` column results, one is appended
    as ``0..n_rows-1`` so hover labels always have an identifier.
    """
    def _is_row_vector(var) -> bool:
        return (
            len(var.dims) == 1
            and "row" in var.dims
            and var.sizes["row"] == ds.sizes["row"]
        )

    columns = {name: ds[name].values for name in ds.data_vars if _is_row_vector(ds[name])}
    if "trial_id" in columns:
        return pd.DataFrame(columns)
    columns["trial_id"] = np.arange(ds.sizes["row"], dtype=int)
    return pd.DataFrame(columns)

1659 

1660 

def _apply_fixed_to_base(
    base_std: np.ndarray,
    fixed: dict[str, float],
    feature_names: list[str],
    transforms: list[str],
    X_mean: np.ndarray,
    X_std: np.ndarray,
) -> np.ndarray:
    """Return a copy of ``base_std`` with some features pinned to fixed values.

    ``fixed`` maps feature name to a value in ORIGINAL units. Each value is
    forward-transformed with that feature's transform, then standardized with
    ``X_mean``/``X_std``, and written into the copy. ``base_std`` itself is
    left untouched.

    Raises:
        KeyError: if a key of ``fixed`` is not one of ``feature_names``.
    """
    position = dict(zip(feature_names, range(len(feature_names))))
    result = base_std.copy()
    for feat, value in fixed.items():
        idx = position.get(feat)
        if idx is None:
            raise KeyError(f"Fixed variable '{feat}' is not a model feature.")
        internal = _forward_transform(transforms[idx], float(value))
        result[idx] = (internal - X_mean[idx]) / X_std[idx]
    return result

1679 

1680 

def _denormalize_then_inverse_transform(j: int, x_std: np.ndarray, transforms, X_mean, X_std) -> np.ndarray:
    """Map standardized values of feature ``j`` back to ORIGINAL units.

    First undoes the z-scoring (``x * X_std[j] + X_mean[j]``), then applies the
    inverse of ``transforms[j]``.
    """
    scale, shift = X_std[j], X_mean[j]
    internal = x_std * scale + shift
    return _inverse_transform(transforms[j], internal)

1684 

1685 

1686def _forward_transform(tr: str, x: float | np.ndarray) -> np.ndarray: 

1687 if tr == "log10": 

1688 x = np.asarray(x, dtype=float) 

1689 return np.log10(np.maximum(x, 1e-12)) 

1690 return np.asarray(x, dtype=float) 

1691 

1692 

def _inverse_transform(tr: str, x: np.ndarray) -> np.ndarray:
    """Undo the feature transform for values in internal space.

    ``"log10"`` features are exponentiated (``10 ** x``). For every other
    transform, ``x`` is returned unchanged: the same object, with no copy or
    conversion.
    """
    return 10.0 ** x if tr == "log10" else x

1697 

1698 

def _maybe_log_axis(fig: go.Figure, row: int, col: int, name: str, axis: str = "x", transforms: list[str] | None = None, j: int | None = None):
    """Switch one subplot axis to log scale when the feature is log10-transformed.

    If both ``transforms`` and ``j`` are given, the test is
    ``transforms[j] == "log10"``. Otherwise a name heuristic is used: the
    lowercased name contains ``"learning_rate"`` or equals ``"lr"``.
    ``axis == "x"`` updates the x axis and any other value updates the y axis.
    Nothing is changed when the feature is not judged logarithmic.
    """
    if transforms is None or j is None:
        lowered = name.lower()
        is_log = "learning_rate" in lowered or lowered == "lr"
    else:
        is_log = transforms[j] == "log10"

    if not is_log:
        return

    update_axis = fig.update_xaxes if axis == "x" else fig.update_yaxes
    update_axis(type="log", row=row, col=col)

1711 

1712 

def _rgb_string_to_tuple(s: str) -> tuple[int, int, int]:
    """Parse the first three channels of an ``"rgb(...)"`` / ``"rgba(...)"`` string.

    Each channel is converted via ``int(float(...))``, which truncates
    fractional values. Any alpha channel is ignored. A ``ValueError`` is raised
    when fewer than three channels are present or a channel is not numeric.
    """
    open_at = s.find("(") + 1
    close_at = s.find(")")
    parts = s[open_at:close_at].split(",")
    r, g, b = [int(float(part)) for part in parts[:3]]
    return r, g, b

1717 

1718 

def _rgb_to_rgba(rgb: str, alpha: float) -> str:
    """Return ``"rgba(r,g,b,a)"`` with the colour of ``rgb`` and opacity ``alpha``.

    Accepted colour forms:
      - ``"rgb(r,g,b)"`` or ``"rgba(r,g,b,a)"``; any existing alpha is discarded.
      - Hex ``"#rgb"``, ``"#rgba"``, ``"#rrggbb"`` or ``"#rrggbbaa"``; any hex
        alpha is discarded.

    Anything else (named colours, malformed strings, non-strings) falls back to
    ``(31,119,180)`` instead of raising, because the result is only used as a
    cosmetic fill colour.

    ``alpha`` is clamped to [0, 1] and formatted with three decimals.
    """
    hex_digits = "0123456789abcdefABCDEF"
    channels = None
    text = rgb.strip() if isinstance(rgb, str) else ""

    if text.startswith("#"):
        # Previously hex colours hit the fallback, so a hex line colour produced
        # a default-blue band that did not match the line.
        digits = text[1:]
        if len(digits) in (3, 4):
            digits = "".join(ch * 2 for ch in digits[:3])
        if len(digits) in (6, 8) and all(ch in hex_digits for ch in digits):
            channels = tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
    else:
        try:
            channels = _rgb_string_to_tuple(rgb)
        except (AttributeError, TypeError, ValueError, OverflowError):
            # Deliberate best-effort: unparseable colours use the fallback below.
            channels = None

    if channels is None:
        channels = (31, 119, 180)
    r, g, b = channels
    a = min(max(float(alpha), 0.0), 1.0)
    return f"rgba({r},{g},{b},{a:.3f})"

1726 

1727 

def _add_low_success_shading_1d(fig: go.Figure, row_idx: int, x_vals: np.ndarray, p: np.ndarray, y0: float, y1: float):
    """Shade x-intervals of subplot row ``row_idx`` where predicted success is low.

    Grey rectangles spanning [y0, y1] are added below the traces. The first
    pass covers each run of consecutive grid points with ``p < 0.5``
    (alpha 0.25). A second pass covers runs with ``p < 0.8`` (alpha 0.40).
    Points below 0.5 therefore receive both fills.

    Each rectangle runs from the first to the last ``x_vals`` entry of its run,
    so a single-point run gives a zero-width rectangle. Axis references assume
    row ``i`` uses ``x{i}``/``y{i}`` (``x``/``y`` for row 1), as in a
    single-column subplot grid. TODO: confirm for other layouts.
    """
    suffix = "" if row_idx == 1 else str(row_idx)
    xref, yref = f"x{suffix}", f"y{suffix}"

    def _runs(mask: np.ndarray):
        """Return (x_start, x_end) for every maximal run of True in ``mask``."""
        flags = np.concatenate(([0], mask.astype(int), [0]))
        edges = np.diff(flags)
        run_starts = np.flatnonzero(edges == 1)
        run_ends = np.flatnonzero(edges == -1) - 1
        return [(x_vals[a], x_vals[b]) for a, b in zip(run_starts, run_ends)]

    shading = (
        (0.5, "rgba(128,128,128,0.25)"),
        (0.8, "rgba(128,128,128,0.40)"),
    )
    for threshold, fill in shading:
        for x0, x1 in _runs(p < threshold):
            fig.add_shape(
                type="rect", x0=x0, x1=x1, y0=y0, y1=y1, xref=xref, yref=yref,
                line=dict(width=0), fillcolor=fill, layer="below",
            )

1745 

1746 

def _set_yaxis_range(fig, *, row: int, col: int, y0: float, y1: float, log: bool, eps: float = 1e-12):
    """Set the y-axis range of subplot (row, col) to [y0, y1].

    Linear mode (``log=False``): the axis type is set to ``"-"`` and the range
    is passed through unchanged.

    Log mode: ``y0`` is floored at ``eps``, and ``y1`` is kept at least
    ``y0 * (1 + 1e-6)``. Plotly takes log-axis ranges in log10 units, so both
    ends go through ``np.log10``.
    """
    if not log:
        fig.update_yaxes(type="-", range=[y0, y1], row=row, col=col)
        return

    lower = max(y0, eps)
    upper = max(y1, lower * (1.0 + 1e-6))
    fig.update_yaxes(type="log", range=[np.log10(lower), np.log10(upper)], row=row, col=col)

1755 

1756 

1757def optimum_plot1d( 

1758 model: xr.Dataset | Path | str, 

1759 output: Path | None = None, 

1760 csv_out: Path | None = None, 

1761 grid_size: int = 300, 

1762 line_color: str = "rgb(31,119,180)", 

1763 band_alpha: float = 0.25, 

1764 show: bool = False, 

1765 use_log_scale_for_target_y: bool = True, 

1766 log_y_epsilon: float = 1e-9, 

1767 optimal: bool = True, 

1768 suggest: int = 0, # optional overlay 

1769 width: int | None = None, 

1770 height: int | None = None, 

1771 seed: int | None = 42, 

1772 **kwargs, # constraints in ORIGINAL units (as in your plot1d) 

1773) -> go.Figure: 

1774 """ 

1775 1D partial-dependence panels anchored at the *optimal* hyperparameter setting: 

1776 - Compute x* = argmin/argmax mean posterior from opt.optimal(...) 

1777 - For each feature, sweep that feature; keep all *other* features fixed at x*. 

1778 Supports numeric constraints (scalars/slices/choices) and categorical bases. 

1779 """ 

1780 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model) 

1781 pred_success, pred_loss = _build_predictors(ds) 

1782 

1783 # --- metadata --- 

1784 feature_names = [str(n) for n in ds["feature"].values.tolist()] 

1785 transforms = [str(t) for t in ds["feature_transform"].values.tolist()] 

1786 X_mean = ds["feature_mean"].values.astype(float) 

1787 X_std = ds["feature_std"].values.astype(float) 

1788 

1789 df_raw = _raw_dataframe_from_dataset(ds) 

1790 Xn_train = ds["Xn_train"].values.astype(float) 

1791 n_rows, p = Xn_train.shape 

1792 

1793 # --- one-hot categorical groups --- 

1794 groups = opt._onehot_groups(feature_names) 

1795 bases = set(groups.keys()) 

1796 name_to_idx = {name: j for j, name in enumerate(feature_names)} 

1797 

1798 # --- canonicalize kwargs: numeric vs categorical (base) --- 

1799 idx_map = _canon_key_set(ds) # your helper: maps normalized names -> exact feature column 

1800 kw_num_raw: dict[str, object] = {} 

1801 kw_cat_raw: dict[str, object] = {} 

1802 for k, v in kwargs.items(): 

1803 if k in bases: 

1804 kw_cat_raw[k] = v 

1805 elif k in idx_map: 

1806 kw_num_raw[idx_map[k]] = v 

1807 else: 

1808 import re as _re 

1809 nk = _re.sub(r"[^a-z0-9]+", "", str(k).lower()) 

1810 if nk in idx_map: 

1811 kw_num_raw[idx_map[nk]] = v 

1812 

1813 # --- resolve categorical constraints: fixed vs allowed subset --- 

1814 cat_fixed: dict[str, str] = {} 

1815 cat_allowed: dict[str, list[str]] = {} 

1816 for base, val in kw_cat_raw.items(): 

1817 labels = groups[base]["labels"] 

1818 if isinstance(val, str): 

1819 if val not in labels: 

1820 raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}") 

1821 cat_fixed[base] = val 

1822 elif isinstance(val, (list, tuple, set)): 

1823 chosen = [x for x in val if isinstance(x, str) and x in labels] 

1824 if not chosen: 

1825 raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}") 

1826 cat_allowed[base] = list(dict.fromkeys(chosen)) 

1827 else: 

1828 raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.") 

1829 

1830 # --- row mask for experimental points (matches constraints) --- 

1831 row_mask = np.ones(n_rows, dtype=bool) 

1832 for base, label in cat_fixed.items(): 

1833 if base in df_raw.columns: 

1834 series = df_raw[base].astype("string") 

1835 row_mask &= series.eq(label).fillna(False).to_numpy() 

1836 else: 

1837 member_name = groups[base]["name_by_label"][label] 

1838 j = name_to_idx[member_name] 

1839 raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float) 

1840 row_mask &= (raw_j >= 0.5) 

1841 

1842 for base, allowed in cat_allowed.items(): 

1843 if base in cat_fixed: 

1844 continue # already fixed 

1845 allowed_labels = [str(x) for x in allowed] 

1846 if base in df_raw.columns: 

1847 series = df_raw[base].astype("string").fillna("<NA>") 

1848 allowed_mask = series.isin(set(allowed_labels)).fillna(False).to_numpy() 

1849 row_mask &= allowed_mask 

1850 else: 

1851 allowed_masks: list[np.ndarray] = [] 

1852 for label in allowed_labels: 

1853 member_name = groups[base]["name_by_label"].get(label) 

1854 if member_name is None: 

1855 continue 

1856 j = name_to_idx[member_name] 

1857 raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float) 

1858 allowed_masks.append(raw_j >= 0.5) 

1859 if allowed_masks: 

1860 row_mask &= np.logical_or.reduce(allowed_masks) 

1861 else: 

1862 row_mask &= False 

1863 

1864 for name, val in kw_num_raw.items(): 

1865 if name not in name_to_idx: 

1866 continue 

1867 j = name_to_idx[name] 

1868 if name in df_raw.columns: 

1869 raw_vals = pd.to_numeric(df_raw[name], errors="coerce").to_numpy(dtype=float) 

1870 else: 

1871 raw_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float) 

1872 

1873 mask = np.isfinite(raw_vals) 

1874 if isinstance(val, slice): 

1875 lo_raw = -np.inf if val.start is None else float(val.start) 

1876 hi_raw = np.inf if val.stop is None else float(val.stop) 

1877 if hi_raw < lo_raw: 

1878 lo_raw, hi_raw = hi_raw, lo_raw 

1879 mask &= (raw_vals >= lo_raw) & (raw_vals <= hi_raw) 

1880 elif isinstance(val, (list, tuple, set, np.ndarray)): 

1881 arr = np.asarray(list(val) if not isinstance(val, np.ndarray) else val, dtype=float) 

1882 arr = arr[np.isfinite(arr)] 

1883 if arr.size == 0: 

1884 mask &= False 

1885 else: 

1886 mask &= np.any(np.isclose(raw_vals[:, None], arr[None, :], rtol=1e-6, atol=1e-9), axis=1) 

1887 else: 

1888 target = float(val) 

1889 mask &= np.isclose(raw_vals, target, rtol=1e-6, atol=1e-9) 

1890 

1891 row_mask &= mask 

1892 

1893 if not np.any(row_mask): 

1894 raise ValueError("No experiments match the provided constraints; cannot plot data points.") 

1895 

1896 df_raw_f = df_raw.loc[row_mask].reset_index(drop=True) 

1897 

1898 # ---------- 1) Find the *optimal* base point (original units) ---------- 

1899 opt_df = opt.optimal(model, count=1, seed=seed, **kwargs) # uses your gradient-based optimal() 

1900 # We’ll use this row both for overlays and as the anchor point. 

1901 # Expect numeric feature columns and categorical base columns present. 

1902 x_opt_std = np.zeros(p, dtype=float) 

1903 

1904 # Fill numerics from optimal row (orig -> internal -> std) 

1905 def _to_std_single(j: int, x_orig: float) -> float: 

1906 xi = x_orig 

1907 if transforms[j] == "log10": 

1908 xi = np.log10(np.maximum(x_orig, 1e-300)) 

1909 return float((xi - X_mean[j]) / X_std[j]) 

1910 

1911 # Mark one-hot member names 

1912 onehot_members: set[str] = set() 

1913 for base, g in groups.items(): 

1914 onehot_members.update(g["members"]) 

1915 

1916 # numeric features (skip one-hot members) 

1917 for j, nm in enumerate(feature_names): 

1918 if nm in onehot_members: 

1919 continue 

1920 if nm in opt_df.columns: 

1921 x_opt_std[j] = _to_std_single(j, float(opt_df.iloc[0][nm])) 

1922 else: 

1923 # Fall back to dataset median if not present (rare) 

1924 x_opt_std[j] = float(np.median(Xn_train[:, j])) 

1925 

1926 # Categorical bases: set one-hot block to the optimal label (or fixed) 

1927 for base, g in groups.items(): 

1928 # priority: fixed in kwargs → else from optimal row → else keep current (median/std) 

1929 if base in cat_fixed: 

1930 label = cat_fixed[base] 

1931 elif base in opt_df.columns: 

1932 label = str(opt_df.iloc[0][base]) 

1933 else: 

1934 # fallback: most frequent label in data 

1935 if base in df_raw.columns: 

1936 label = str(df_raw[base].astype("string").mode(dropna=True).iloc[0]) 

1937 else: 

1938 label = g["labels"][0] 

1939 

1940 for lab in g["labels"]: 

1941 member_name = g["name_by_label"][lab] 

1942 j2 = name_to_idx[member_name] 

1943 raw = 1.0 if lab == label else 0.0 

1944 # raw (0/1) → standardized using the artifact stats 

1945 x_opt_std[j2] = (raw - X_mean[j2]) / X_std[j2] 

1946 

1947 # ---------- 2) Numeric constraints in STANDARDIZED space ---------- 

1948 def _orig_to_std(j: int, x, transforms, mu, sd): 

1949 x = np.asarray(x, dtype=float) 

1950 if transforms[j] == "log10": 

1951 x = np.where(x <= 0, np.nan, x) 

1952 x = np.log10(x) 

1953 return (x - mu[j]) / sd[j] 

1954 

1955 fixed_scalars: dict[int, float] = {} 

1956 range_windows: dict[int, tuple[float, float]] = {} 

1957 choice_values: dict[int, np.ndarray] = {} 

1958 

1959 for name, val in kw_num_raw.items(): 

1960 if name not in name_to_idx: 

1961 continue 

1962 j = name_to_idx[name] 

1963 if isinstance(val, slice): 

1964 lo = _orig_to_std(j, float(val.start), transforms, X_mean, X_std) 

1965 hi = _orig_to_std(j, float(val.stop), transforms, X_mean, X_std) 

1966 lo, hi = float(min(lo, hi)), float(max(lo, hi)) 

1967 range_windows[j] = (lo, hi) 

1968 elif isinstance(val, (list, tuple, np.ndarray)): 

1969 arr = _orig_to_std(j, np.asarray(val, dtype=float), transforms, X_mean, X_std) 

1970 choice_values[j] = np.asarray(arr, dtype=float) 

1971 else: 

1972 fixed_scalars[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std)) 

1973 

1974 # apply numeric fixed overrides on the base point 

1975 for j, vstd in fixed_scalars.items(): 

1976 x_opt_std[j] = vstd 

1977 

1978 # ---------- 3) Panels: sweep ONE var at a time around x* ---------- 

1979 # numeric free = not one-hot member and not fixed via kwargs 

1980 free_numeric_idx = [ 

1981 j for j, nm in enumerate(feature_names) 

1982 if (nm not in onehot_members) and (j not in fixed_scalars) 

1983 ] 

1984 # categorical bases: sweep if not fixed; otherwise not shown 

1985 free_cat_bases = [b for b in bases if b not in cat_fixed] 

1986 

1987 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases] 

1988 if not panels: 

1989 raise ValueError("All features are fixed at the optimum (or categoricals fixed); nothing to plot.") 

1990 

1991 # empirical 1–99% per feature (for default sweep range) 

1992 Xn_p01 = np.percentile(Xn_train, 1, axis=0) 

1993 Xn_p99 = np.percentile(Xn_train, 99, axis=0) 

1994 

1995 def _grid_1d(j: int, n: int) -> np.ndarray: 

1996 # default range in std space 

1997 lo, hi = float(Xn_p01[j]), float(Xn_p99[j]) 

1998 if j in range_windows: 

1999 lo = max(lo, range_windows[j][0]) 

2000 hi = min(hi, range_windows[j][1]) 

2001 if j in choice_values: 

2002 vals = np.asarray(choice_values[j], dtype=float) 

2003 vals = vals[(vals >= lo) & (vals <= hi)] 

2004 return np.unique(np.sort(vals)) if vals.size else np.array([x_opt_std[j]], dtype=float) 

2005 if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo: 

2006 lo, hi = x_opt_std[j] - 1.0, x_opt_std[j] + 1.0 

2007 return np.linspace(lo, hi, n) 

2008 

2009 # figure scaffold 

2010 subplot_titles = [feature_names[int(k)] if t == "num" else str(k) for t, k in panels] 

2011 fig = make_subplots(rows=len(panels), cols=1, shared_xaxes=False, subplot_titles=subplot_titles) 

2012 

2013 tgt_col = str(ds.attrs["target"]) 

2014 success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy() 

2015 fail_mask = ~success_mask 

2016 losses_success = df_raw_f.loc[success_mask, tgt_col].to_numpy().astype(float) 

2017 trial_ids_success = df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask] 

2018 trial_ids_fail = df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask] 

2019 band_fill_rgba = _rgb_to_rgba(line_color, band_alpha) 

2020 

2021 # optional overlay 

2022 suggest_df = opt.suggest(model, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None 

2023 

2024 tidy_rows: list[dict] = [] 

2025 row_pos = 0 

2026 for kind, key in panels: 

2027 row_pos += 1 

2028 

2029 if kind == "num": 

2030 j = int(key) 

2031 if feature_names[j] in df_raw.columns: 

2032 series_full = pd.to_numeric(df_raw[feature_names[j]], errors="coerce") 

2033 x_full_raw = series_full.to_numpy(dtype=float) 

2034 else: 

2035 x_full_raw = feature_raw_from_artifact_or_reconstruct( 

2036 ds, j, feature_names[j], transforms[j] 

2037 ).astype(float) 

2038 x_data_all = x_full_raw[row_mask] 

2039 finite_raw = x_full_raw[np.isfinite(x_full_raw)] 

2040 if transforms[j] == "log10": 

2041 finite_raw = finite_raw[finite_raw > 0] 

2042 

2043 grid = _grid_1d(j, grid_size) 

2044 if (j not in range_windows) and (j not in choice_values) and finite_raw.size: 

2045 finite_std = _orig_to_std(j, finite_raw, transforms, X_mean, X_std) 

2046 grid_min = float(np.nanmin(np.concatenate([grid, finite_std]))) 

2047 grid_max = float(np.nanmax(np.concatenate([grid, finite_std]))) 

2048 if grid_max > grid_min: 

2049 grid = np.linspace(grid_min, grid_max, grid_size) 

2050 else: 

2051 grid = np.array([grid_min], dtype=float) 

2052 Xn_grid = np.repeat(x_opt_std[None, :], len(grid), axis=0) 

2053 Xn_grid[:, j] = grid 

2054 

2055 p_grid = pred_success(Xn_grid) 

2056 mu_grid, sd_grid = pred_loss(Xn_grid, include_observation_noise=True) 

2057 

2058 x_internal = grid * X_std[j] + X_mean[j] 

2059 x_display = _inverse_transform(transforms[j], x_internal) 

2060 

2061 # y transform 

2062 if use_log_scale_for_target_y: 

2063 mu_plot = np.maximum(mu_grid, log_y_epsilon) 

2064 lo_plot = np.maximum(mu_grid - 2.0 * sd_grid, log_y_epsilon) 

2065 hi_plot = np.maximum(mu_grid + 2.0 * sd_grid, log_y_epsilon) 

2066 losses_s_plot = np.maximum(losses_success, log_y_epsilon) if losses_success.size else losses_success 

2067 else: 

2068 mu_plot = mu_grid 

2069 lo_plot = mu_grid - 2.0 * sd_grid 

2070 hi_plot = mu_grid + 2.0 * sd_grid 

2071 losses_s_plot = losses_success 

2072 

2073 y_arrays = [lo_plot, hi_plot] + ([losses_s_plot] if losses_s_plot.size else []) 

2074 y_low = float(np.nanmin([np.nanmin(a) for a in y_arrays])) 

2075 y_high = float(np.nanmax([np.nanmax(a) for a in y_arrays])) 

2076 pad = 0.05 * (y_high - y_low + 1e-12) 

2077 y0_plot = (y_low - pad) if not use_log_scale_for_target_y else max(y_low / 1.5, log_y_epsilon) 

2078 y1_tmp = (y_high + pad) if not use_log_scale_for_target_y else y_high * 1.2 

2079 y_failed_band = y1_tmp + (y_high - y_low + 1e-12) * (0.08 if not use_log_scale_for_target_y else 0.3) 

2080 if use_log_scale_for_target_y and y_failed_band <= log_y_epsilon: 

2081 y_failed_band = max(10.0 * log_y_epsilon, y_high * 2.0) 

2082 y1_plot = y_failed_band + (0.02 if not use_log_scale_for_target_y else 0.05) * (y_high - y_low + 1e-12) 

2083 

2084 _add_low_success_shading_1d(fig, row_pos, x_display, p_grid, y0_plot, y1_plot) 

2085 

2086 show_legend = (row_pos == 1) 

2087 # ±2σ band 

2088 fig.add_trace(go.Scatter(x=x_display, y=lo_plot, mode="lines", 

2089 line=dict(width=0, color=line_color), 

2090 name="±2σ", legendgroup="band", showlegend=False, hoverinfo="skip"), 

2091 row=row_pos, col=1) 

2092 fig.add_trace(go.Scatter(x=x_display, y=hi_plot, mode="lines", fill="tonexty", 

2093 line=dict(width=0, color=line_color), fillcolor=band_fill_rgba, 

2094 name="±2σ", legendgroup="band", showlegend=show_legend, 

2095 hovertemplate="E[target|success]: %{y:.3f}<extra>±2σ</extra>"), 

2096 row=row_pos, col=1) 

2097 # mean 

2098 fig.add_trace(go.Scatter(x=x_display, y=mu_plot, mode="lines", 

2099 line=dict(width=2, color=line_color), 

2100 name="E[target|success]", legendgroup="mean", showlegend=show_legend, 

2101 hovertemplate=f"{feature_names[j]}: %{{x:.6g}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"), 

2102 row=row_pos, col=1) 

2103 

2104 # experimental points at y (filtered to constraint-satisfied rows) 

2105 x_succ = x_data_all[success_mask] 

2106 if x_succ.size: 

2107 fig.add_trace(go.Scattergl( 

2108 x=x_succ, y=losses_s_plot, mode="markers", 

2109 marker=dict(size=5, color="black", line=dict(width=0)), 

2110 name="data (success)", legendgroup="data_s", showlegend=show_legend, 

2111 hovertemplate=("trial_id: %{customdata}<br>" 

2112 f"{feature_names[j]}: %{{x:.6g}}<br>" 

2113 f"{tgt_col}: %{{y:.4f}}<extra></extra>"), 

2114 customdata=trial_ids_success 

2115 ), row=row_pos, col=1) 

2116 

2117 x_fail = x_data_all[fail_mask] 

2118 if x_fail.size: 

2119 y_fail_plot = np.full_like(x_fail, y_failed_band, dtype=float) 

2120 fig.add_trace(go.Scattergl( 

2121 x=x_fail, y=y_fail_plot, mode="markers", 

2122 marker=dict(size=6, color="red", line=dict(color="black", width=0.8)), 

2123 name="data (failed)", legendgroup="data_f", showlegend=show_legend, 

2124 hovertemplate=("trial_id: %{customdata}<br>" 

2125 f"{feature_names[j]}: %{{x:.6g}}<br>" 

2126 "status: failed (NaN target)<extra></extra>"), 

2127 customdata=trial_ids_fail 

2128 ), row=row_pos, col=1) 

2129 

2130 # overlays: optimal (single point) and suggested (optional many) 

2131 if optimal and feature_names[j] in opt_df.columns: 

2132 x_opt_disp = float(opt_df.iloc[0][feature_names[j]]) 

2133 y_opt = float(opt_df.iloc[0]["pred_target_mean"]) 

2134 y_opt_sd = float(opt_df.iloc[0].get("pred_target_sd", np.nan)) 

2135 fig.add_trace(go.Scattergl( 

2136 x=[x_opt_disp], y=[y_opt], mode="markers", 

2137 marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"), 

2138 name="optimal", legendgroup="optimal", showlegend=show_legend, 

2139 hovertemplate=(f"predicted: %{{y:.3g}}" 

2140 + ("" if np.isnan(y_opt_sd) else f" ± {y_opt_sd:.3g}") 

2141 + f"<br>{feature_names[j]}: %{{x:.6g}}<extra></extra>") 

2142 ), row=row_pos, col=1) 

2143 

2144 if suggest and (suggest_df is not None) and (feature_names[j] in suggest_df.columns): 

2145 xs = suggest_df[feature_names[j]].values.astype(float) 

2146 ys = suggest_df["pred_target_mean"].values.astype(float) 

2147 ysd = suggest_df.get("pred_target_sd", pd.Series([np.nan]*len(suggest_df))).values 

2148 fig.add_trace(go.Scattergl( 

2149 x=xs, y=ys, mode="markers", 

2150 marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"), 

2151 name="suggested", legendgroup="suggested", showlegend=show_legend, 

2152 hovertemplate=("predicted: %{y:.3g}" 

2153 + (" ± %{customdata:.3g}" if not np.isnan(ysd).all() else "") 

2154 + f"<br>{feature_names[j]}: %{{x:.6g}}<extra>suggested</extra>"), 

2155 customdata=ysd 

2156 ), row=row_pos, col=1) 

2157 

2158 # axes + ranges 

2159 _maybe_log_axis(fig, row_pos, 1, feature_names[j], axis="x", transforms=transforms, j=j) 

2160 fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1) 

2161 _set_yaxis_range(fig, row=row_pos, col=1, 

2162 y0=y0_plot, y1=y1_plot, 

2163 log=use_log_scale_for_target_y, eps=log_y_epsilon) 

2164 fig.update_xaxes(title_text=feature_names[j], row=row_pos, col=1) 

2165 is_log_x = (transforms[j] == "log10") 

2166 

2167 # If a constraint limited the sweep, respect it on the displayed axis 

2168 def _std_to_orig(val_std: float) -> float: 

2169 vi = val_std * X_std[j] + X_mean[j] 

2170 return float(_inverse_transform(transforms[j], np.array([vi]))[0]) 

2171 

2172 if j in range_windows: 

2173 lo_std, hi_std = range_windows[j] 

2174 x_min_override = min(_std_to_orig(lo_std), _std_to_orig(hi_std)) 

2175 x_max_override = max(_std_to_orig(lo_std), _std_to_orig(hi_std)) 

2176 span = (x_max_override - x_min_override) or 1.0 

2177 pad = 0.02 * span 

2178 fig.update_xaxes(range=[x_min_override - pad, x_max_override + pad], row=row_pos, col=1) 

2179 elif j in choice_values and choice_values[j].size: 

2180 ints = choice_values[j] * X_std[j] + X_mean[j] 

2181 origs = _inverse_transform(transforms[j], ints) 

2182 span = float(np.max(origs) - np.min(origs)) or 1.0 

2183 pad = 0.05 * span 

2184 fig.update_xaxes(range=[float(np.min(origs) - pad), float(np.max(origs) + pad)], row=row_pos, col=1) 

2185 else: 

2186 if finite_raw.size: 

2187 if is_log_x: 

2188 x0 = max(float(np.min(finite_raw)), 1e-12) 

2189 x1 = max(float(np.max(finite_raw)), x0 * (1 + 1e-9)) 

2190 pad = (x1 / x0) ** 0.03 

2191 fig.update_xaxes( 

2192 range=[np.log10(x0 / pad), np.log10(x1 * pad)], 

2193 row=row_pos, 

2194 col=1, 

2195 ) 

2196 else: 

2197 x0 = float(np.min(finite_raw)) 

2198 x1 = float(np.max(finite_raw)) 

2199 span = (x1 - x0) or 1.0 

2200 pad = 0.02 * span 

2201 fig.update_xaxes(range=[x0 - pad, x1 + pad], row=row_pos, col=1) 

2202 

2203 # tidy rows 

2204 for xd, xi, mu_i, sd_i, p_i in zip(x_display, x_internal, mu_grid, sd_grid, p_grid): 

2205 tidy_rows.append({ 

2206 "feature": feature_names[j], 

2207 "x_display": float(xd), 

2208 "x_internal": float(xi), 

2209 "target_conditional_mean": float(mu_i), 

2210 "target_conditional_sd": float(sd_i), 

2211 "success_probability": float(p_i), 

2212 }) 

2213 

2214 else: 

2215 base = str(key) 

2216 labels_all = groups[base]["labels"] 

2217 labels = cat_allowed.get(base, labels_all) 

2218 

2219 # Evaluate each label with numerics and other bases fixed at x_opt_std 

2220 Xn_grid = np.repeat(x_opt_std[None, :], len(labels), axis=0) 

2221 for r, lab in enumerate(labels): 

2222 for lab2 in labels_all: 

2223 member_name = groups[base]["name_by_label"][lab2] 

2224 j2 = name_to_idx[member_name] 

2225 raw_val = 1.0 if (lab2 == lab) else 0.0 

2226 Xn_grid[r, j2] = (raw_val - X_mean[j2]) / X_std[j2] 

2227 

2228 p_vec = pred_success(Xn_grid) 

2229 mu_vec, sd_vec = pred_loss(Xn_grid, include_observation_noise=True) 

2230 

2231 # y transform 

2232 if use_log_scale_for_target_y: 

2233 mu_plot = np.maximum(mu_vec, log_y_epsilon) 

2234 lo_plot = np.maximum(mu_vec - 2.0 * sd_vec, log_y_epsilon) 

2235 hi_plot = np.maximum(mu_vec + 2.0 * sd_vec, log_y_epsilon) 

2236 losses_s_plot = np.maximum(df_raw_f.loc[success_mask, tgt_col].to_numpy().astype(float), log_y_epsilon) if success_mask.any() else np.array([]) 

2237 else: 

2238 mu_plot = mu_vec 

2239 lo_plot = mu_vec - 2.0 * sd_vec 

2240 hi_plot = mu_vec + 2.0 * sd_vec 

2241 losses_s_plot = df_raw_f.loc[success_mask, tgt_col].to_numpy().astype(float) if success_mask.any() else np.array([]) 

2242 

2243 y_arrays = [lo_plot, hi_plot] + ([losses_s_plot] if losses_s_plot.size else []) 

2244 y_low = float(np.nanmin([np.nanmin(a) for a in y_arrays])) if y_arrays else 0.0 

2245 y_high = float(np.nanmax([np.nanmax(a) for a in y_arrays])) if y_arrays else 1.0 

2246 pad = 0.05 * (y_high - y_low + 1e-12) 

2247 y0_plot = (y_low - pad) if not use_log_scale_for_target_y else max(y_low / 1.5, log_y_epsilon) 

2248 y1_tmp = (y_high + pad) if not use_log_scale_for_target_y else y_high * 1.2 

2249 y_failed_band = y1_tmp + (y_high - y_low + 1e-12) * (0.08 if not use_log_scale_for_target_y else 0.3) 

2250 if use_log_scale_for_target_y and y_failed_band <= log_y_epsilon: 

2251 y_failed_band = max(10.0 * log_y_epsilon, y_high * 2.0) 

2252 y1_plot = y_failed_band + (0.02 if not use_log_scale_for_target_y else 0.05) * (y_high - y_low + 1e-12) 

2253 

2254 # x = 0..K-1 with tick labels 

2255 x_pos = np.arange(len(labels), dtype=float) 

2256 

2257 # grey out infeasible (p<thr) 

2258 def _shade_for_thresh(thr: float, alpha: float): 

2259 for k_i, p_i in enumerate(p_vec): 

2260 if p_i < thr: 

2261 fig.add_shape( 

2262 type="rect", 

2263 xref=f"x{'' if row_pos==1 else row_pos}", 

2264 yref=f"y{'' if row_pos==1 else row_pos}", 

2265 x0=k_i - 0.5, x1=k_i + 0.5, 

2266 y0=y0_plot, y1=y1_plot, 

2267 line=dict(width=0), 

2268 fillcolor=f"rgba(128,128,128,{alpha})", 

2269 layer="below", 

2270 row=row_pos, col=1 

2271 ) 

2272 _shade_for_thresh(0.8, 0.40) 

2273 _shade_for_thresh(0.5, 0.25) 

2274 

2275 show_legend = (row_pos == 1) 

2276 fig.add_trace(go.Scatter( 

2277 x=x_pos, y=mu_plot, mode="lines+markers", 

2278 line=dict(width=2, color=line_color), 

2279 marker=dict(size=7, color=line_color), 

2280 error_y=dict(type="data", array=(hi_plot - mu_plot), arrayminus=(mu_plot - lo_plot), visible=True), 

2281 name="E[target|success]", legendgroup="mean", showlegend=show_legend, 

2282 hovertemplate=(f"{base}: %{{text}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"), 

2283 text=labels 

2284 ), row=row_pos, col=1) 

2285 

2286 # overlay optimal point for this base (single label at x*=opt) 

2287 if optimal and base in opt_df.columns: 

2288 lab_opt = str(opt_df.iloc[0][base]) 

2289 if lab_opt in labels: 

2290 xi = float(labels.index(lab_opt)) 

2291 y_opt = float(opt_df.iloc[0]["pred_target_mean"]) 

2292 y_opt_sd = float(opt_df.iloc[0].get("pred_target_sd", np.nan)) 

2293 fig.add_trace(go.Scattergl( 

2294 x=[xi], y=[y_opt], mode="markers", 

2295 marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"), 

2296 name="optimal", legendgroup="optimal", showlegend=show_legend, 

2297 hovertemplate=(f"predicted: %{{y:.3g}}" 

2298 + ("" if np.isnan(y_opt_sd) else f" ± {y_opt_sd:.3g}") 

2299 + f"<br>{base}: {lab_opt}<extra></extra>") 

2300 ), row=row_pos, col=1) 

2301 

2302 # overlay suggestions (optional) 

2303 if suggest and (suggest_df is not None) and (base in suggest_df.columns): 

2304 labs_sug = suggest_df[base].astype(str).tolist() 

2305 xs = [labels.index(l) for l in labs_sug if l in labels] 

2306 if xs: 

2307 keep_mask = [l in labels for l in labs_sug] 

2308 y_sug = suggest_df.loc[keep_mask, "pred_target_mean"].values 

2309 fig.add_trace(go.Scattergl( 

2310 x=np.array(xs, dtype=float), y=y_sug, mode="markers", 

2311 marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"), 

2312 name="suggested", legendgroup="suggested", showlegend=show_legend, 

2313 hovertemplate=(f"{base}: %{{text}}<br>" 

2314 "predicted: %{{y:.3g}}<extra>suggested</extra>"), 

2315 text=[labels[int(i)] for i in xs] 

2316 ), row=row_pos, col=1) 

2317 

2318 fig.update_xaxes( 

2319 tickmode="array", 

2320 tickvals=x_pos.tolist(), 

2321 ticktext=labels, 

2322 title_text=base, 

2323 row=row_pos, col=1 

2324 ) 

2325 fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1) 

2326 _set_yaxis_range(fig, row=row_pos, col=1, 

2327 y0=y0_plot, y1=y1_plot, 

2328 log=use_log_scale_for_target_y, eps=log_y_epsilon) 

2329 

2330 # tidy rows 

2331 for lab, mu_i, sd_i, p_i in zip(labels, mu_vec, sd_vec, p_vec): 

2332 tidy_rows.append({ 

2333 "feature": base, 

2334 "x_display": str(lab), 

2335 "x_internal": float("nan"), 

2336 "target_conditional_mean": float(mu_i), 

2337 "target_conditional_sd": float(sd_i), 

2338 "success_probability": float(p_i), 

2339 }) 

2340 

2341 # ---- layout & IO ---- 

2342 parts = [f"1D PD at optimal setting of all other hyperparameters ({ds.attrs.get('target', 'target')})"] 

2343 if kw_num_raw: 

2344 def _fmt_c(v): 

2345 if isinstance(v, slice): 

2346 a = "" if v.start is None else f"{v.start:g}" 

2347 b = "" if v.stop is None else f"{v.stop:g}" 

2348 return f"[{a},{b}]" 

2349 if isinstance(v, (list, tuple, np.ndarray)): 

2350 try: 

2351 return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]" 

2352 except Exception: 

2353 return "[" + ",".join(map(str, v)) + "]" 

2354 try: 

2355 return f"{float(v):g}" 

2356 except Exception: 

2357 return str(v) 

2358 parts.append(", ".join(f"{k}={_fmt_c(v)}" for k, v in kw_num_raw.items())) 

2359 if cat_fixed: 

2360 parts.append(", ".join(f"{b}={lab}" for b, lab in cat_fixed.items())) 

2361 title = " — ".join(parts) 

2362 

2363 width = width if (width and width > 0) else 1200 

2364 height = height if (height and height > 0) else 1200 

2365 fig.update_layout(height=height, width=width, template="simple_white", title=title, legend_title_text="") 

2366 

2367 if output: 

2368 write_image(fig, output) 

2369 if csv_out: 

2370 csv_out = Path(csv_out); csv_out.parent.mkdir(parents=True, exist_ok=True) 

2371 pd.DataFrame(tidy_rows).to_csv(str(csv_out), index=False) 

2372 if show: 

2373 fig.show("browser") 

2374 return fig 

2375 

2376 

def optimum_plot2d(
    model: xr.Dataset | Path | str,
    output: Path | None = None,
    grid_size: int = 70,
    use_log_scale_for_target: bool = False,
    log_shift_epsilon: float = 1e-9,
    colorscale: str = "RdBu",
    show: bool = False,
    n_contours: int = 12,
    optimal: bool = True,
    suggest: int = 0,
    width: int | None = None,
    height: int | None = None,
    seed: int | None = 42,
    **kwargs,
) -> go.Figure:
    """2D PD panels anchored at the optimal hyperparameter setting.

    Builds a k x k grid of panels, one row/column per free numeric feature.
    Off-diagonal cell (i, j) shows the predicted E[target|success] over a 2D
    grid of feature j (x) and feature i (y). Every other feature is held at
    the anchor point derived from ``opt.optimal``. Diagonal cells show a
    symmetrised 1D sweep. Regions whose predicted success probability is
    below 0.8 / 0.5 are greyed out. Observed trials (success / failed), the
    optimum and optional suggestions are overlaid.

    Args:
        model: fitted model dataset, or a path loadable by ``xr.load_dataset``.
        output: optional output path handed to ``write_image``.
        grid_size: number of grid points per continuous numeric axis.
        use_log_scale_for_target: colour by log10 of the predicted mean. If any
            predicted value is <= 0, a single global shift is added to all
            cells first and reported in the colour-bar title.
        log_shift_epsilon: floor used by the log colour transform.
        colorscale: Plotly colour-scale name.
        show: open the figure in a browser.
        n_contours: contour levels drawn per cell (at least 2).
        optimal: overlay the optimum marker.
        suggest: if > 0, overlay this many points from ``opt.suggest``.
        width, height: figure size in pixels (1100 when unset or <= 0).
        seed: forwarded to ``opt.optimal`` / ``opt.suggest``.
        **kwargs: constraints, also forwarded to ``opt.optimal`` /
            ``opt.suggest``. A categorical base name maps to a label, or to an
            iterable holding exactly one valid label. A numeric feature name
            (exact, or case/punctuation-insensitive) maps to one of:
            - a scalar (fixed value);
            - ``slice(lo, hi)`` (range; either end may be ``None``);
            - a list/tuple/set/ndarray of allowed values.

    Returns:
        The Plotly figure.

    Raises:
        ValueError: unknown categorical label, a categorical base left with
            several labels, no trial matching the constraints, or every
            numeric feature fixed.
    """
    import re as _re

    ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
    pred_success, pred_loss = _build_predictors(ds)

    feature_names = [str(n) for n in ds["feature"].values.tolist()]
    transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
    X_mean = ds["feature_mean"].values.astype(float)
    X_std = ds["feature_std"].values.astype(float)
    name_to_idx = {name: j for j, name in enumerate(feature_names)}

    df_raw = _raw_dataframe_from_dataset(ds)
    Xn_train = ds["Xn_train"].values.astype(float)
    n_rows = Xn_train.shape[0]

    groups = opt._onehot_groups(feature_names)
    bases = set(groups.keys())

    # ---- split kwargs into categorical and numeric constraints ----
    idx_map = _canon_key_set(ds)
    kw_num_raw: dict[str, object] = {}
    kw_cat_raw: dict[str, object] = {}
    for k, v in kwargs.items():
        if k in bases:
            kw_cat_raw[k] = v
            continue
        if k in idx_map:
            kw_num_raw[idx_map[k]] = v
            continue
        nk = _re.sub(r"[^a-z0-9]+", "", str(k).lower())
        if nk in idx_map:
            kw_num_raw[idx_map[nk]] = v

    cat_fixed: dict[str, str] = {}
    for base, val in kw_cat_raw.items():
        labels = groups[base]["labels"]
        if isinstance(val, str):
            if val not in labels:
                raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}")
            cat_fixed[base] = val
        else:
            chosen = [x for x in (list(val) if isinstance(val, (list, tuple, set)) else [val])
                      if isinstance(x, str) and x in labels]
            if not chosen:
                raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}")
            if len(chosen) == 1:
                cat_fixed[base] = chosen[0]
            else:
                raise ValueError("optimum_plot2d currently requires categorical bases to be fixed.")

    # ---- select the observed trials that satisfy every constraint ----
    row_mask = np.ones(n_rows, dtype=bool)
    for base, label in cat_fixed.items():
        if base in df_raw.columns:
            series = df_raw[base].astype("string")
            row_mask &= series.eq(label).fillna(False).to_numpy()
        else:
            member_name = groups[base]["name_by_label"][label]
            j = name_to_idx[member_name]
            raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float)
            row_mask &= (raw_j >= 0.5)

    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if name in df_raw.columns:
            raw_vals = pd.to_numeric(df_raw[name], errors="coerce").to_numpy(dtype=float)
        else:
            raw_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float)
        mask = np.isfinite(raw_vals)
        if isinstance(val, slice):
            lo_raw = -np.inf if val.start is None else float(val.start)
            hi_raw = np.inf if val.stop is None else float(val.stop)
            if hi_raw < lo_raw:
                lo_raw, hi_raw = hi_raw, lo_raw
            mask &= (raw_vals >= lo_raw) & (raw_vals <= hi_raw)
        elif isinstance(val, (list, tuple, set, np.ndarray)):
            arr = np.asarray(list(val) if not isinstance(val, np.ndarray) else val, dtype=float)
            arr = arr[np.isfinite(arr)]
            if arr.size == 0:
                mask &= False
            else:
                mask &= np.any(np.isclose(raw_vals[:, None], arr[None, :], rtol=1e-6, atol=1e-9), axis=1)
        else:
            target = float(val)
            mask &= np.isclose(raw_vals, target, rtol=1e-6, atol=1e-9)
        row_mask &= mask

    if not np.any(row_mask):
        raise ValueError("No experiments match the provided constraints; nothing to plot.")

    df_raw_f = df_raw.loc[row_mask].reset_index(drop=True)

    # ---- anchor point: the optimum, expressed in standardized space ----
    opt_df = opt.optimal(model, count=1, seed=seed, **kwargs)
    x_opt_std = np.zeros(len(feature_names), dtype=float)

    def _to_std_single(j: int, x_orig: float) -> float:
        # original units -> (optional log10) -> standardized
        xi = x_orig
        if transforms[j] == "log10":
            xi = np.log10(np.maximum(x_orig, 1e-300))
        return float((xi - X_mean[j]) / X_std[j])

    onehot_members: set[str] = set()
    for g in groups.values():
        onehot_members.update(g["members"])

    for j, name in enumerate(feature_names):
        if name in onehot_members:
            continue
        if name in opt_df.columns:
            x_opt_std[j] = _to_std_single(j, float(opt_df.iloc[0][name]))
        else:
            # optimizer did not report this feature: fall back to the training median
            x_opt_std[j] = float(np.median(Xn_train[:, j]))

    for base, g in groups.items():
        # label priority: explicit constraint > optimizer result > most common observed > first label
        if base in cat_fixed:
            label = cat_fixed[base]
        elif base in opt_df.columns:
            label = str(opt_df.iloc[0][base])
        elif base in df_raw.columns:
            label = str(df_raw[base].astype("string").mode(dropna=True).iloc[0])
        else:
            label = g["labels"][0]
        for lab in g["labels"]:
            member_name = g["name_by_label"][lab]
            j = name_to_idx[member_name]
            raw = 1.0 if lab == label else 0.0
            x_opt_std[j] = (raw - X_mean[j]) / X_std[j]

    def _orig_to_std(j: int, x, transforms, mu, sd):
        # vectorized original -> standardized; non-positive values of log features become NaN
        x = np.asarray(x, dtype=float)
        if transforms[j] == "log10":
            x = np.where(x <= 0, np.nan, x)
            x = np.log10(x)
        return (x - mu[j]) / sd[j]

    def _bound_to_std(j: int, raw: float, default: float) -> float:
        """Convert one slice endpoint to std space; open (±inf) or invalid ends map to *default*."""
        if not np.isfinite(raw):
            return default
        val_std = float(_orig_to_std(j, raw, transforms, X_mean, X_std))
        return val_std if np.isfinite(val_std) else default

    # ---- numeric constraints in standardized space ----
    fixed_scalars_std: dict[int, float] = {}
    range_windows_std: dict[int, tuple[float, float]] = {}
    choice_values_std: dict[int, np.ndarray] = {}

    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if isinstance(val, slice):
            # either end may be None (open), matching the row-mask filter above
            lo_raw = -np.inf if val.start is None else float(val.start)
            hi_raw = np.inf if val.stop is None else float(val.stop)
            lo = _bound_to_std(j, lo_raw, -np.inf)
            hi = _bound_to_std(j, hi_raw, np.inf)
            range_windows_std[j] = (float(min(lo, hi)), float(max(lo, hi)))
        elif isinstance(val, (list, tuple, set, np.ndarray)):
            seq = val if isinstance(val, np.ndarray) else list(val)
            arr = _orig_to_std(j, np.asarray(seq, dtype=float), transforms, X_mean, X_std)
            choice_values_std[j] = np.asarray(arr, dtype=float)
        else:
            fixed_scalars_std[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std))

    for j, v in fixed_scalars_std.items():
        x_opt_std[j] = v

    free_numeric_idx = [
        j for j, name in enumerate(feature_names)
        if (j not in fixed_scalars_std) and (name not in onehot_members)
    ]
    if len(free_numeric_idx) == 0:
        raise ValueError("All numeric features are fixed at the optimum; nothing to plot.")

    # ---- per-feature evaluation grids ----
    grids_std_num: dict[int, np.ndarray] = {}
    raw_full_cache: dict[int, np.ndarray] = {}
    Xn_p01 = np.percentile(Xn_train, 1, axis=0)
    Xn_p99 = np.percentile(Xn_train, 99, axis=0)

    def _grid_std_num(j: int) -> np.ndarray:
        lo, hi = float(Xn_p01[j]), float(Xn_p99[j])
        if j in range_windows_std:
            lo = max(lo, range_windows_std[j][0])
            hi = min(hi, range_windows_std[j][1])
        if j in choice_values_std:
            vals = np.asarray(choice_values_std[j], dtype=float)
            vals = vals[(vals >= lo) & (vals <= hi)]
            return np.unique(np.sort(vals)) if vals.size else np.array([x_opt_std[j]], dtype=float)
        if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
            lo, hi = x_opt_std[j] - 1.0, x_opt_std[j] + 1.0
        return np.linspace(lo, hi, grid_size)

    for j in free_numeric_idx:
        if feature_names[j] in df_raw.columns:
            raw_vals = pd.to_numeric(df_raw[feature_names[j]], errors="coerce").to_numpy(dtype=float)
        else:
            raw_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float)
        raw_full_cache[j] = raw_vals
        grid = _grid_std_num(j)
        if (j not in range_windows_std) and (j not in choice_values_std):
            # unconstrained: widen the grid so it covers every observed value
            finite_raw = raw_vals[np.isfinite(raw_vals)]
            if transforms[j] == "log10":
                finite_raw = finite_raw[finite_raw > 0]
            if finite_raw.size:
                finite_std = _orig_to_std(j, finite_raw, transforms, X_mean, X_std)
                grid_min = float(np.nanmin(np.concatenate([grid, finite_std])))
                grid_max = float(np.nanmax(np.concatenate([grid, finite_std])))
                if grid_max > grid_min:
                    grid = np.linspace(grid_min, grid_max, grid_size)
        grids_std_num[j] = grid

    def _data_axis_range(jj: int):
        """Padded axis range covering all observed values of feature jj, or None.

        Returned in log10 units for log10 features, which Plotly log axes expect.
        None when the feature is range/choice constrained or has no usable data.
        """
        if jj not in raw_full_cache or jj in range_windows_std or jj in choice_values_std:
            return None
        vals = raw_full_cache[jj]
        finite_raw = vals[np.isfinite(vals)]
        if transforms[jj] == "log10":
            finite_raw = finite_raw[finite_raw > 0]
        if not finite_raw.size:
            return None
        v0 = float(np.min(finite_raw))
        v1 = float(np.max(finite_raw))
        if transforms[jj] == "log10":
            v0 = max(v0, 1e-12)
            v1 = max(v1, v0 * (1 + 1e-9))
            pad = (v1 / v0) ** 0.03
            return [np.log10(v0 / pad), np.log10(v1 * pad)]
        span = (v1 - v0) or 1.0
        pad = 0.02 * span
        return [v0 - pad, v1 + pad]

    axis_ranges = {jj: _data_axis_range(jj) for jj in free_numeric_idx}

    subplot_titles = [feature_names[j] for j in free_numeric_idx]
    k = len(free_numeric_idx)
    fig = make_subplots(
        rows=k,
        cols=k,
        shared_xaxes=True,
        shared_yaxes=True,
        horizontal_spacing=0.01,
        vertical_spacing=0.01,
        subplot_titles=subplot_titles,
    )

    optimal_df = opt_df.copy() if optimal else None
    suggest_df = opt.suggest(ds, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None

    tgt_col = str(ds.attrs.get("target", "target"))
    success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy()
    fail_mask = ~success_mask
    trial_ids = df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()
    tgt_vals = df_raw_f[tgt_col].to_numpy()

    # ---- evaluate the surrogate on every cell ----
    all_blocks: list[np.ndarray] = []
    cell_payload: dict[tuple[int, int], dict] = {}
    base_std = x_opt_std.copy()

    def _denorm_inv_opt(j: int, std_vals: np.ndarray) -> np.ndarray:
        internal = std_vals * X_std[j] + X_mean[j]
        return _inverse_transform(transforms[j], internal)

    for row_idx, i in enumerate(free_numeric_idx):
        for col_idx, j in enumerate(free_numeric_idx):
            xg = grids_std_num[j]
            yg = grids_std_num[i]
            if i == j:
                # diagonal: 1D sweep, symmetrised into a square image
                Xn_1d = np.repeat(base_std[None, :], len(xg), axis=0)
                Xn_1d[:, j] = xg
                mu_1d, _ = pred_loss(Xn_1d, include_observation_noise=True)
                p_1d = pred_success(Xn_1d)
                Zmu = 0.5 * (mu_1d[:, None] + mu_1d[None, :])
                Zp = np.minimum(p_1d[:, None], p_1d[None, :])
                x_orig = _denorm_inv_opt(j, xg)
                y_orig = x_orig
            else:
                XX, YY = np.meshgrid(xg, yg)
                Xn_grid = np.repeat(base_std[None, :], XX.size, axis=0)
                Xn_grid[:, j] = XX.ravel()
                Xn_grid[:, i] = YY.ravel()
                mu_flat, _ = pred_loss(Xn_grid, include_observation_noise=True)
                p_flat = pred_success(Xn_grid)
                Zmu = mu_flat.reshape(YY.shape)
                Zp = p_flat.reshape(YY.shape)
                x_orig = _denorm_inv_opt(j, xg)
                y_orig = _denorm_inv_opt(i, yg)
            cell_payload[(row_idx, col_idx)] = dict(i=i, j=j, x=x_orig, y=y_orig, Zmu=Zmu, Zp=Zp)
            all_blocks.append(Zmu.ravel())

    # ---- shared colour axis ----
    def _color_xform(z_raw: np.ndarray) -> tuple[np.ndarray, float]:
        # returns (transformed values, shift applied before log10)
        if not use_log_scale_for_target:
            return z_raw, 0.0
        zmin = float(np.nanmin(z_raw))
        shift = 0.0 if zmin > 0 else -zmin + float(log_shift_epsilon)
        return np.log10(np.maximum(z_raw + shift, log_shift_epsilon)), shift

    z_all = np.concatenate(all_blocks) if all_blocks else np.array([0.0, 1.0])
    z_all_t, global_shift = _color_xform(z_all)
    cmin_t = float(np.nanmin(z_all_t))
    cmax_t = float(np.nanmax(z_all_t))
    cs = get_colorscale(colorscale)

    def _to_color_space(z_raw: np.ndarray) -> np.ndarray:
        # Every cell must use the GLOBAL shift so it agrees with cmin/cmax and the colour bar.
        if not use_log_scale_for_target:
            return z_raw
        return np.log10(np.maximum(z_raw + global_shift, log_shift_epsilon))

    def _contour_line_color(level_raw: float) -> str:
        # grey whose lightness is the inverse of the underlying colour's luminance
        zt = np.log10(max(level_raw + global_shift, log_shift_epsilon)) if use_log_scale_for_target else level_raw
        t = 0.5 if cmax_t == cmin_t else (zt - cmin_t) / (cmax_t - cmin_t)
        rgb = sample_colorscale(cs, [float(np.clip(t, 0.0, 1.0))])[0]
        r, g, b = _rgb_string_to_tuple(rgb)
        lum = (0.2126*r + 0.7152*g + 0.0722*b)/255.0
        grey = int(round((1.0 - lum) * 255))
        return f"rgba({grey},{grey},{grey},0.9)"

    def _data_vals_for_feature(j_full: int) -> np.ndarray:
        """Raw values of feature j_full for the trials in df_raw_f (same row order)."""
        vals = raw_full_cache.get(j_full)
        if vals is None:
            name = feature_names[j_full]
            if name in df_raw.columns:
                vals = pd.to_numeric(df_raw[name], errors="coerce").to_numpy(dtype=float)
            else:
                vals = feature_raw_from_artifact_or_reconstruct(ds, j_full, name, transforms[j_full]).astype(float)
        return vals[row_mask]

    # ---- draw every cell ----
    for (r, c), payload in cell_payload.items():
        Zmu_raw = payload["Zmu"]
        Zp = payload["Zp"]
        Z_t = _to_color_space(Zmu_raw)
        x_vals = payload["x"]
        y_vals = payload["y"]
        if payload["i"] == payload["j"]:
            diag_vals = np.asarray(x_vals, dtype=float)
            x_vals = diag_vals
            y_vals = diag_vals
        fig.add_trace(go.Heatmap(
            x=x_vals, y=y_vals, z=Z_t,
            coloraxis="coloraxis", zsmooth=False, showscale=False,
            hovertemplate=(f"{feature_names[payload['j']]}: %{{x:.6g}}<br>"
                           f"{feature_names[payload['i']]}: %{{y:.6g}}"
                           "<br>E[target|success]: %{customdata:.3f}<extra></extra>"),
            customdata=Zmu_raw
        ), row=r+1, col=c+1)

        # grey overlays where predicted success probability is low (they stack)
        for thr, alpha in ((0.5, 0.25), (0.8, 0.40)):
            mask = np.where(Zp < thr, 1.0, np.nan)
            fig.add_trace(go.Heatmap(
                x=x_vals, y=y_vals, z=mask, zmin=0, zmax=1,
                colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]],
                showscale=False, hoverinfo="skip"
            ), row=r+1, col=c+1)

        zmin_r, zmax_r = float(np.nanmin(Zmu_raw)), float(np.nanmax(Zmu_raw))
        levels = np.linspace(zmin_r, zmax_r, max(n_contours, 2))
        for lev in levels:
            color = _contour_line_color(lev)
            fig.add_trace(go.Contour(
                x=x_vals, y=y_vals, z=Zmu_raw,
                autocontour=False,
                contours=dict(coloring="lines", showlabels=False, start=lev, end=lev, size=1e-9),
                line=dict(width=1),
                colorscale=[[0, color], [1, color]],
                showscale=False, hoverinfo="skip"
            ), row=r+1, col=c+1)

        xd = _data_vals_for_feature(payload["j"])
        yd = _data_vals_for_feature(payload["i"])
        show_leg = (r == 0 and c == 0)
        fig.add_trace(go.Scattergl(
            x=xd[success_mask], y=yd[success_mask], mode="markers",
            marker=dict(size=4, color="black", line=dict(width=0)),
            name="data (success)", legendgroup="data_succ", showlegend=show_leg,
            hovertemplate=("trial_id: %{customdata[0]}<br>"
                           f"{feature_names[payload['j']]}: %{{x:.6g}}<br>"
                           f"{feature_names[payload['i']]}: %{{y:.6g}}<br>"
                           f"{tgt_col}: %{{customdata[1]:.4f}}<extra></extra>"),
            customdata=np.column_stack([
                trial_ids[success_mask],
                tgt_vals[success_mask],
            ])
        ), row=r+1, col=c+1)
        fig.add_trace(go.Scattergl(
            x=xd[fail_mask], y=yd[fail_mask], mode="markers",
            marker=dict(size=5, color="red", line=dict(color="black", width=0.8)),
            name="data (failed)", legendgroup="data_fail", showlegend=show_leg,
            hovertemplate=("trial_id: %{customdata}<br>"
                           f"{feature_names[payload['j']]}: %{{x:.6g}}<br>"
                           f"{feature_names[payload['i']]}: %{{y:.6g}}<br>"
                           "status: failed (NaN target)<extra></extra>"),
            customdata=trial_ids[fail_mask]
        ), row=r+1, col=c+1)

        if (
            optimal
            and optimal_df is not None
            and feature_names[payload["j"]] in optimal_df.columns
            and feature_names[payload["i"]] in optimal_df.columns
        ):
            ox = np.asarray(optimal_df[feature_names[payload["j"]]].values, dtype=float)
            oy = np.asarray(optimal_df[feature_names[payload["i"]]].values, dtype=float)
            pmu = float(optimal_df["pred_target_mean"].values[0])
            psd = float(optimal_df.get("pred_target_sd", pd.Series([np.nan])).values[0])
            fig.add_trace(go.Scattergl(
                x=ox, y=oy, mode="markers",
                marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                name="optimal", legendgroup="optimal", showlegend=show_leg,
                hovertemplate=(f"predicted: {pmu:.3g}"
                               + ("" if np.isnan(psd) else f" ± {psd:.3g}")
                               + f"<br>{feature_names[payload['j']]}: %{{x:.6g}}"
                               f"<br>{feature_names[payload['i']]}: %{{y:.6g}}<extra></extra>")
            ), row=r+1, col=c+1)

        if (
            suggest
            and suggest_df is not None
            and feature_names[payload["j"]] in suggest_df.columns
            and feature_names[payload["i"]] in suggest_df.columns
        ):
            xs = suggest_df[feature_names[payload["j"]]].values.astype(float)
            ys = suggest_df[feature_names[payload["i"]]].values.astype(float)
            ymu = suggest_df["pred_target_mean"].values.astype(float)
            # force float so np.isnan works even if the column came back as object dtype
            ysd = np.asarray(
                suggest_df.get("pred_target_sd", pd.Series([np.nan]*len(suggest_df))).values,
                dtype=float,
            )
            fig.add_trace(go.Scattergl(
                x=xs, y=ys, mode="markers",
                marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                name="suggested", legendgroup="suggested", showlegend=show_leg,
                hovertemplate=("predicted: %{customdata[0]:.3f}"
                               + (" ± %{customdata[1]:.3f}" if not np.isnan(ysd).all() else "")
                               + f"<br>{feature_names[payload['j']]}: %{{x:.6g}}"
                               f"<br>{feature_names[payload['i']]}: %{{y:.6g}}<extra>suggested</extra>"),
                customdata=np.column_stack([ymu, ysd])
            ), row=r+1, col=c+1)

        _maybe_log_axis(fig, row=r+1, col=c+1, name=feature_names[payload["j"]], axis="x", transforms=transforms, j=payload["j"])
        _maybe_log_axis(fig, row=r+1, col=c+1, name=feature_names[payload["i"]], axis="y", transforms=transforms, j=payload["i"])
        if r == k - 1:
            fig.update_xaxes(title_text=feature_names[payload["j"]], row=r+1, col=c+1)
        else:
            fig.update_xaxes(tickmode=None, row=r+1, col=c+1)
        if c == 0:
            fig.update_yaxes(title_text=feature_names[payload["i"]], row=r+1, col=c+1)
        else:
            fig.update_yaxes(tickmode=None, row=r+1, col=c+1)

        x_range = axis_ranges.get(payload["j"])
        if x_range is not None:
            fig.update_xaxes(range=x_range, row=r+1, col=c+1)
        y_range = axis_ranges.get(payload["i"])
        if y_range is not None:
            fig.update_yaxes(range=y_range, row=r+1, col=c+1)

    # ---- layout & IO ----
    z_title = "E[target|success]" + (" (log10)" if use_log_scale_for_target else "")
    if use_log_scale_for_target and global_shift > 0:
        z_title += f" (shift Δ={global_shift:.3g})"

    width = width if (width and width > 0) else 1100
    height = height if (height and height > 0) else 1100
    fig.update_layout(
        height=height,
        width=width,
        template="simple_white",
        coloraxis=dict(
            colorscale=colorscale,
            cmin=cmin_t, cmax=cmax_t,
            colorbar=dict(
                title=z_title,
                thickness=10,          # thinner bar
                len=0.55,              # shorter bar (fraction of plot height)
                lenmode="fraction",
                x=1.02, y=0.5,         # just right of plot, vertically centered
                xanchor="left", yanchor="middle",
            ),
        ),
        legend=dict(
            orientation="v",
            x=1.02, xanchor="left",    # to the right of the colorbar
            y=1.0, yanchor="top",
            bgcolor="rgba(255,255,255,0.85)"
        ),
        title=f"2D PD at optimal setting of all other hyperparameters ({tgt_col})",
        legend_title_text="",
    )

    if output:
        write_image(fig, output)
    if show:
        fig.show("browser")
    return fig

2865 

2866 

2867def write_image(fig, output:Path|str): 

2868 """Write a Plotly figure to an image file (PNG, JPEG, etc). Requires kaleido.""" 

2869 output = Path(output) 

2870 output.parent.mkdir(parents=True, exist_ok=True) 

2871 if output.suffix.lower() == ".html": 

2872 fig.write_html(str(output)) 

2873 else: 

2874 fig.write_image(str(output))