Coverage for psyop/viz.py: 38.34%

1158 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-10 06:02 +0000

1# -*- coding: utf-8 -*- 

2from pathlib import Path 

3import numpy as np 

4import pandas as pd 

5import xarray as xr 

6 

7import plotly.graph_objects as go 

8from plotly.subplots import make_subplots 

9from plotly.colors import get_colorscale, sample_colorscale 

10 

11from .model import ( 

12 kernel_diag_m52, 

13 kernel_m52_ard, 

14 add_jitter, 

15 solve_chol, 

16 solve_lower, 

17 feature_raw_from_artifact_or_reconstruct, 

18) 

19from . import opt 

20 

21 

22def _canon_key_set(ds) -> dict[str, str]: 

23 feats = [str(x) for x in ds["feature"].values.tolist()] 

24 def _norm(s: str) -> str: 

25 import re 

26 return re.sub(r"[^a-z0-9]+", "", s.lower()) 

27 return {**{f: f for f in feats}, **{_norm(f): f for f in feats}} 

28 

29 

30def _edges_from_centers(vals: np.ndarray, is_log: bool) -> tuple[float, float]: 

31 """Return (min_edge, max_edge) that tightly bound a heatmap with given center coords.""" 

32 v = np.asarray(vals, float) 

33 v = v[np.isfinite(v)] 

34 if v.size == 0: 

35 return (0.0, 1.0) 

36 if v.size == 1: 

37 # tiny pad 

38 if is_log: 

39 lo = max(v[0] / 1.5, 1e-12) 

40 hi = v[0] * 1.5 

41 else: 

42 span = max(abs(v[0]) * 0.5, 1e-9) 

43 lo, hi = v[0] - span, v[0] + span 

44 return float(lo), float(hi) 

45 

46 if is_log: 

47 lv = np.log10(v) 

48 l0 = lv[0] - 0.5 * (lv[1] - lv[0]) 

49 lN = lv[-1] + 0.5 * (lv[-1] - lv[-2]) 

50 lo = 10.0 ** l0 

51 hi = 10.0 ** lN 

52 lo = max(lo, 1e-12) 

53 hi = max(hi, lo * 1.0000001) 

54 return float(lo), float(hi) 

55 else: 

56 d0 = v[1] - v[0] 

57 dN = v[-1] - v[-2] 

58 lo = v[0] - 0.5 * d0 

59 hi = v[-1] + 0.5 * dN 

60 if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo: 

61 lo, hi = float(v.min()), float(v.max()) 

62 return float(lo), float(hi) 

63 

def _update_axis_type_and_range(
    fig, *, row: int, col: int, axis: str, centers: np.ndarray, is_log: bool
):
    """Fit one subplot axis to the outer edges of a heatmap's tiles.

    The edges come from ``_edges_from_centers(centers, is_log)``, so the tiles
    meet the axis frame exactly.

    Args:
        fig: Plotly figure built with subplots.
        row, col: 1-based subplot position to update.
        axis: ``"x"`` updates the x-axis; any other value updates the y-axis.
        centers: Heatmap center coordinates along this axis.
        is_log: If true, the axis is switched to ``type="log"`` and its range is
            given in log10 units. Otherwise only the range is set and the axis
            type is left untouched.
    """
    edge_lo, edge_hi = _edges_from_centers(centers, is_log)
    apply_update = fig.update_xaxes if axis == "x" else fig.update_yaxes
    if not is_log:
        apply_update(range=[edge_lo, edge_hi], row=row, col=col)
        return
    # Plotly expects log-axis ranges expressed as log10 of the data bounds.
    apply_update(
        type="log",
        range=[np.log10(edge_lo), np.log10(edge_hi)],
        row=row,
        col=col,
    )

79 

80 

81def plot2d( 

82 model: xr.Dataset | Path | str, 

83 output: Path | None = None, 

84 grid_size: int = 70, 

85 use_log_scale_for_target: bool = False, # log10 colors for heatmaps 

86 log_shift_epsilon: float = 1e-9, 

87 colorscale: str = "RdBu", 

88 show: bool = False, 

89 n_contours: int = 12, 

90 optimal: bool = True, 

91 suggest: int = 0, 

92 seed: int|None = 42, 

93 width:int|None = None, 

94 height:int|None = None, 

95 **kwargs, 

96) -> go.Figure: 

97 """ 

98 2D Partial Dependence of E[target|success] (pairwise features), including 

99 categorical variables as single axes (one row/column per base). 

100 

101 Conditioning via kwargs (original units): 

102 - numeric: fixed scalar, slice(lo,hi), list/tuple choices 

103 - categorical base (e.g. language="Linear A" or language=("Linear A","Linear B")): 

104 * single string: fixed to that label (axis removed and clamped) 

105 * list/tuple of labels: restrict the categorical axis to those labels 

106 

107 Notes: 

108 * one-hot member features (e.g. language=Linear A) never appear as axes. 

109 * when a categorical axis is present, we render a heatmap over category index 

110 with tick labels set to the category names; data overlays use jitter. 

111 """ 

112 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model) 

113 pred_success, pred_loss = _build_predictors(ds) 

114 

115 # --- features & transforms 

116 feature_names = [str(x) for x in ds["feature"].values.tolist()] 

117 transforms = [str(t) for t in ds["feature_transform"].values.tolist()] 

118 X_mean = ds["feature_mean"].values.astype(float) 

119 X_std = ds["feature_std"].values.astype(float) 

120 name_to_idx = {nm: i for i, nm in enumerate(feature_names)} 

121 

122 # one-hot groups 

123 groups = opt._onehot_groups(feature_names) # { base: {"labels":[...], "name_by_label":{label->member}, "members":[...] } } 

124 bases = set(groups.keys()) 

125 onehot_member_names = {m for g in groups.values() for m in g["members"]} 

126 

127 # raw df + train design 

128 df_raw = _raw_dataframe_from_dataset(ds) 

129 Xn_train = ds["Xn_train"].values.astype(float) 

130 n_rows = Xn_train.shape[0] 

131 

132 # --- split kwargs into numeric vs categorical (keys are canonical already when coming from CLI) 

133 kw_num: dict[str, object] = {} 

134 kw_cat: dict[str, object] = {} 

135 for k, v in (kwargs or {}).items(): 

136 if k in bases: 

137 kw_cat[k] = v 

138 elif k in name_to_idx: 

139 kw_num[k] = v 

140 else: 

141 # unknown/ignored 

142 pass 

143 

144 # --- resolve categorical constraints: 

145 # - cat_fixed[base] -> fixed single label (axis removed and clamped) 

146 # - cat_allowed[base] -> labels that are allowed on that axis (if not fixed) 

147 cat_fixed: dict[str, str] = {} 

148 cat_allowed: dict[str, list[str]] = {} 

149 for base in bases: 

150 labels = list(groups[base]["labels"]) 

151 if base not in kw_cat: 

152 cat_allowed[base] = labels # unrestricted axis (if not fixed numerically later) 

153 continue 

154 val = kw_cat[base] 

155 if isinstance(val, str): 

156 if val not in labels: 

157 raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}") 

158 cat_fixed[base] = val 

159 elif isinstance(val, (list, tuple, set)): 

160 chosen = [x for x in val if isinstance(x, str) and x in labels] 

161 if not chosen: 

162 raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}") 

163 # if only one remains, treat as fixed; else allowed list 

164 if len(chosen) == 1: 

165 cat_fixed[base] = chosen[0] 

166 else: 

167 cat_allowed[base] = chosen 

168 else: 

169 raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.") 

170 

171 # --- filter rows to categorical *fixed* selections for medians/percentiles & overlays 

172 row_mask = np.ones(n_rows, dtype=bool) 

173 for base, label in cat_fixed.items(): 

174 if base in df_raw.columns: 

175 row_mask &= (df_raw[base].astype("string") == pd.Series([label]*len(df_raw), dtype="string")).to_numpy() 

176 else: 

177 member = groups[base]["name_by_label"][label] 

178 j = name_to_idx[member] 

179 raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member, transforms[j]).astype(float) 

180 row_mask &= (raw_j >= 0.5) 

181 

182 df_raw_f = df_raw.loc[row_mask].reset_index(drop=True) if cat_fixed else df_raw 

183 Xn_train_f = Xn_train[row_mask, :] if cat_fixed else Xn_train 

184 

185 # --- numeric constraints (standardized) 

186 def _orig_to_std(j: int, x, transforms, mu, sd): 

187 x = np.asarray(x, dtype=float) 

188 if transforms[j] == "log10": 

189 x = np.where(x <= 0, np.nan, x) 

190 x = np.log10(x) 

191 return (x - mu[j]) / sd[j] 

192 

193 fixed_scalars_std: dict[int, float] = {} 

194 range_windows_std: dict[int, tuple[float, float]] = {} 

195 choice_values_std: dict[int, np.ndarray] = {} 

196 

197 for name, val in kw_num.items(): 

198 j = name_to_idx[name] 

199 if isinstance(val, slice): 

200 lo = _orig_to_std(j, float(val.start), transforms, X_mean, X_std) 

201 hi = _orig_to_std(j, float(val.stop), transforms, X_mean, X_std) 

202 lo, hi = float(min(lo, hi)), float(max(lo, hi)) 

203 range_windows_std[j] = (lo, hi) 

204 elif isinstance(val, (list, tuple, np.ndarray)): 

205 arr = _orig_to_std(j, np.asarray(val, dtype=float), transforms, X_mean, X_std) 

206 choice_values_std[j] = np.asarray(arr, dtype=float) 

207 else: 

208 fixed_scalars_std[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std)) 

209 

210 # --- apply categorical *fixed* selections as standardized 0/1 on their member features 

211 for base, label in cat_fixed.items(): 

212 labels = groups[base]["labels"] 

213 for lab in labels: 

214 member = groups[base]["name_by_label"][lab] 

215 j = name_to_idx[member] 

216 raw_val = 1.0 if (lab == label) else 0.0 

217 fixed_scalars_std[j] = float(_orig_to_std(j, raw_val, transforms, X_mean, X_std)) 

218 

219 # --- free axes = numeric features not scalar-fixed & not one-hot members, plus categorical bases not fixed 

220 free_numeric_idx = [ 

221 j for j, nm in enumerate(feature_names) 

222 if (j not in fixed_scalars_std) and (nm not in onehot_member_names) 

223 ] 

224 free_cat_bases = [b for b in bases if b not in cat_fixed] # we already filtered by allowed above 

225 

226 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases] 

227 if not panels: 

228 raise ValueError("All features are fixed (or only single-category categoricals remain); nothing to plot.") 

229 

230 # --- base point (median in standardized space of filtered rows), then apply scalar fixes 

231 base_std = np.median(Xn_train_f, axis=0) 

232 for j, vstd in fixed_scalars_std.items(): 

233 base_std[j] = vstd 

234 

235 # --- per-feature grids (numeric) over filtered 1–99% + respecting ranges/choices 

236 p01p99 = [np.percentile(Xn_train_f[:, j], [1, 99]) for j in range(len(feature_names))] 

237 def _grid_std_num(j: int) -> np.ndarray: 

238 p01, p99 = p01p99[j] 

239 if j in choice_values_std: 

240 vals = np.asarray(choice_values_std[j], dtype=float) 

241 vals = vals[(vals >= p01) & (vals <= p99)] 

242 return np.unique(np.sort(vals)) if vals.size else np.array([np.median(Xn_train_f[:, j])]) 

243 lo, hi = p01, p99 

244 if j in range_windows_std: 

245 rlo, rhi = range_windows_std[j] 

246 lo, hi = max(lo, rlo), min(hi, rhi) 

247 if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo: 

248 hi = lo + 1e-9 

249 return np.linspace(lo, hi, grid_size) 

250 

251 grids_std_num = {j: _grid_std_num(j) for j in free_numeric_idx} 

252 

253 # --- helpers for categorical evaluation --------------------------------- 

254 def _std_for_member(member_name: str, raw01: float) -> float: 

255 j = name_to_idx[member_name] 

256 return float(_orig_to_std(j, raw01, transforms, X_mean, X_std)) 

257 

258 def _apply_onehot_for_base(Xn_block: np.ndarray, base: str, label: str) -> None: 

259 # set the whole block's rows to the 0/1 standardized values for this label 

260 for lab in groups[base]["labels"]: 

261 member = groups[base]["name_by_label"][lab] 

262 j = name_to_idx[member] 

263 Xn_block[:, j] = _std_for_member(member, 1.0 if lab == label else 0.0) 

264 

265 def _denorm_inv(j: int, std_vals: np.ndarray) -> np.ndarray: 

266 internal = std_vals * X_std[j] + X_mean[j] 

267 return _inverse_transform(transforms[j], internal) 

268 

269 # 1) Robustly detect one-hot member columns. 

270 # Use both the detector output AND a fallback "base=" prefix scan, 

271 # so any columns like "language=Linear A" are guaranteed to be excluded. 

272 onehot_member_names: set[str] = set() 

273 for base, g in groups.items(): 

274 # detector-known members 

275 onehot_member_names.update(g["members"]) 

276 # prefix fallback 

277 prefix = f"{base}=" 

278 onehot_member_names.update([nm for nm in feature_names if nm.startswith(prefix)]) 

279 

280 # 2) Build panel list: keep numeric features that are not scalar-fixed AND 

281 # are not one-hot members; plus categorical bases that are not fixed. 

282 free_numeric_idx = [ 

283 j for j, nm in enumerate(feature_names) 

284 if (j not in fixed_scalars_std) and (nm not in onehot_member_names) 

285 ] 

286 free_cat_bases = [b for b in bases if b not in cat_fixed] 

287 

288 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases] 

289 if not panels: 

290 raise ValueError("All features are fixed (or only single-category categoricals remain); nothing to plot.") 

291 

292 # 3) Sanity check: no one-hot member should survive as a numeric panel. 

293 assert all( 

294 (feature_names[key] not in onehot_member_names) if kind == "num" else True 

295 for kind, key in panels 

296 ), "internal: one-hot member leaked into numeric panels" 

297 

298 # 4) Subplot scaffold (matrix layout k x k) with clear titles. 

299 def _panel_title(kind: str, key: object) -> str: 

300 return feature_names[int(key)] if kind == "num" else str(key) 

301 

302 k = len(panels) 

303 fig = make_subplots( 

304 rows=k, 

305 cols=k, 

306 shared_xaxes=False, 

307 shared_yaxes=False, 

308 horizontal_spacing=0.03, 

309 vertical_spacing=0.03, 

310 subplot_titles=[_panel_title(kind, key) for kind, key in panels], 

311 ) 

312 

313 # (Keep the rest of your cell-evaluation and rendering logic unchanged. 

314 # Because we filtered `onehot_member_names`, rows/columns like 

315 # "language=Linear A" / "language=Linear B" will no longer appear. 

316 # Categorical bases (e.g., "language") will show as a single axis.) 

317 

318 

319 # overlays prepared under the SAME constraints (pass original kwargs straight through) 

320 optimal_df = opt.optimal(ds, count=1, seed=seed, **kwargs) if optimal else None 

321 suggest_df = opt.suggest(ds, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None 

322 

323 # masks for data overlays (already filtered if cat_fixed) 

324 tgt_col = str(ds.attrs["target"]) 

325 success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy() 

326 fail_mask = ~success_mask 

327 

328 # collect Z blocks for global color bounds 

329 all_blocks: list[np.ndarray] = [] 

330 cell_payload: dict[tuple[int,int], dict] = {} 

331 

332 # --- build each cell payload (numeric/num, cat/num, num/cat, cat/cat) 

333 for r, (kind_r, key_r) in enumerate(panels): 

334 for c, (kind_c, key_c) in enumerate(panels): 

335 # X axis = column; Y axis = row 

336 if kind_r == "num" and kind_c == "num": 

337 i = int(key_r); j = int(key_c) 

338 xg = grids_std_num[j]; yg = grids_std_num[i] 

339 if i == j: 

340 grid = grids_std_num[j] 

341 Xn_1d = np.repeat(base_std[None, :], len(grid), axis=0) 

342 Xn_1d[:, j] = grid 

343 mu_1d, _ = pred_loss(Xn_1d, include_observation_noise=True) 

344 p_1d = pred_success(Xn_1d) 

345 Zmu = 0.5 * (mu_1d[:, None] + mu_1d[None, :]) 

346 Zp = np.minimum(p_1d[:, None], p_1d[None, :]) 

347 x_orig = _denorm_inv(j, grid) 

348 y_orig = x_orig 

349 else: 

350 XX, YY = np.meshgrid(xg, yg) 

351 Xn_grid = np.repeat(base_std[None, :], XX.size, axis=0) 

352 Xn_grid[:, j] = XX.ravel() 

353 Xn_grid[:, i] = YY.ravel() 

354 mu_flat, _ = pred_loss(Xn_grid, include_observation_noise=True) 

355 p_flat = pred_success(Xn_grid) 

356 Zmu = mu_flat.reshape(YY.shape) 

357 Zp = p_flat.reshape(YY.shape) 

358 x_orig = _denorm_inv(j, xg) 

359 y_orig = _denorm_inv(i, yg) 

360 cell_payload[(r, c)] = dict(kind=("num","num"), i=i, j=j, x=x_orig, y=y_orig, Zmu=Zmu, Zp=Zp) 

361 

362 elif kind_r == "cat" and kind_c == "num": 

363 base = str(key_r); j = int(key_c) 

364 labels = list(cat_allowed.get(base, groups[base]["labels"])) 

365 xg = grids_std_num[j] 

366 # build rows per label 

367 Zmu_rows = []; Zp_rows = [] 

368 for lab in labels: 

369 Xn_grid = np.repeat(base_std[None, :], len(xg), axis=0) 

370 Xn_grid[:, j] = xg 

371 _apply_onehot_for_base(Xn_grid, base, lab) 

372 mu_row, _ = pred_loss(Xn_grid, include_observation_noise=True) 

373 p_row = pred_success(Xn_grid) 

374 Zmu_rows.append(mu_row[None, :]) 

375 Zp_rows.append(p_row[None, :]) 

376 Zmu = np.concatenate(Zmu_rows, axis=0) # (n_labels, n_x) 

377 Zp = np.concatenate(Zp_rows, axis=0) 

378 x_orig = _denorm_inv(j, xg) 

379 y_cats = labels # categorical ticks 

380 cell_payload[(r,c)] = dict(kind=("cat","num"), base=base, j=j, x=x_orig, y=y_cats, Zmu=Zmu, Zp=Zp) 

381 

382 elif kind_r == "num" and kind_c == "cat": 

383 i = int(key_r); base = str(key_c) 

384 labels = list(cat_allowed.get(base, groups[base]["labels"])) 

385 yg = grids_std_num[i] 

386 # columns per label 

387 Zmu_cols = []; Zp_cols = [] 

388 for lab in labels: 

389 Xn_grid = np.repeat(base_std[None, :], len(yg), axis=0) 

390 Xn_grid[:, i] = yg 

391 _apply_onehot_for_base(Xn_grid, base, lab) 

392 mu_col, _ = pred_loss(Xn_grid, include_observation_noise=True) 

393 p_col = pred_success(Xn_grid) 

394 Zmu_cols.append(mu_col[:, None]) 

395 Zp_cols.append(p_col[:, None]) 

396 Zmu = np.concatenate(Zmu_cols, axis=1) # (n_y, n_labels) 

397 Zp = np.concatenate(Zp_cols, axis=1) 

398 x_cats = labels 

399 y_orig = _denorm_inv(i, yg) 

400 cell_payload[(r,c)] = dict(kind=("num","cat"), i=i, base=base, x=x_cats, y=y_orig, Zmu=Zmu, Zp=Zp) 

401 

402 else: # kind_r == "cat" and kind_c == "cat" 

403 base_r = str(key_r); base_c = str(key_c) 

404 labels_r = list(cat_allowed.get(base_r, groups[base_r]["labels"])) 

405 labels_c = list(cat_allowed.get(base_c, groups[base_c]["labels"])) 

406 Z = np.zeros((len(labels_r), len(labels_c)), dtype=float) 

407 P = np.zeros_like(Z) 

408 # evaluate each pair 

409 for rr, lab_r in enumerate(labels_r): 

410 for cc, lab_c in enumerate(labels_c): 

411 Xn_grid = base_std[None, :].copy() 

412 _apply_onehot_for_base(Xn_grid, base_r, lab_r) 

413 _apply_onehot_for_base(Xn_grid, base_c, lab_c) 

414 mu_val, _ = pred_loss(Xn_grid, include_observation_noise=True) 

415 p_val = pred_success(Xn_grid) 

416 Z[rr, cc] = float(mu_val[0]) 

417 P[rr, cc] = float(p_val[0]) 

418 cell_payload[(r,c)] = dict(kind=("cat","cat"), x=labels_c, y=labels_r, Zmu=Z, Zp=P) 

419 

420 all_blocks.append(cell_payload[(r,c)]["Zmu"].ravel()) 

421 

422 # --- color transform bounds 

423 def _color_xform(z_raw: np.ndarray) -> tuple[np.ndarray, float]: 

424 if not use_log_scale_for_target: 

425 return z_raw, 0.0 

426 zmin = float(np.nanmin(z_raw)) 

427 shift = 0.0 if zmin > 0 else -zmin + float(log_shift_epsilon) 

428 return np.log10(np.maximum(z_raw + shift, log_shift_epsilon)), shift 

429 

430 z_all = np.concatenate(all_blocks) if all_blocks else np.array([0.0, 1.0]) 

431 z_all_t, global_shift = _color_xform(z_all) 

432 cmin_t = float(np.nanmin(z_all_t)) 

433 cmax_t = float(np.nanmax(z_all_t)) 

434 cs = get_colorscale(colorscale) 

435 

436 def _contour_line_color(level_raw: float) -> str: 

437 zt = np.log10(max(level_raw + global_shift, log_shift_epsilon)) if use_log_scale_for_target else level_raw 

438 t = 0.5 if cmax_t == cmin_t else (zt - cmin_t) / (cmax_t - cmin_t) 

439 rgb = sample_colorscale(cs, [float(np.clip(t, 0.0, 1.0))])[0] 

440 r, g, b = _rgb_string_to_tuple(rgb) 

441 lum = (0.2126*r + 0.7152*g + 0.0722*b)/255.0 

442 grey = int(round((1.0 - lum) * 255)) 

443 return f"rgba({grey},{grey},{grey},0.9)" 

444 

445 # --- render cells 

446 def _is_log_feature(j: int) -> bool: return (transforms[j] == "log10") 

447 

448 for (r, c), PAY in cell_payload.items(): 

449 kind = PAY["kind"]; Zmu_raw = PAY["Zmu"]; Zp = PAY["Zp"] 

450 Z_t, _ = _color_xform(Zmu_raw) 

451 

452 # axes values (numeric arrays or category indices) 

453 if kind == ("num","num"): 

454 x_vals = PAY["x"]; y_vals = PAY["y"] 

455 fig.add_trace(go.Heatmap( 

456 x=x_vals, y=y_vals, z=Z_t, 

457 coloraxis="coloraxis", zsmooth=False, showscale=False, 

458 hovertemplate=(f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

459 f"{feature_names[PAY['i']]}: %{{y:.6g}}" 

460 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"), 

461 customdata=Zmu_raw 

462 ), row=r+1, col=c+1) 

463 

464 # p(success) shading + contours 

465 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)): 

466 mask = np.where(Zp < thr, 1.0, np.nan) 

467 fig.add_trace(go.Heatmap( 

468 x=x_vals, y=y_vals, z=mask, zmin=0, zmax=1, 

469 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]], 

470 showscale=False, hoverinfo="skip" 

471 ), row=r+1, col=c+1) 

472 

473 # contour lines 

474 zmin_r, zmax_r = float(np.nanmin(Zmu_raw)), float(np.nanmax(Zmu_raw)) 

475 levels = np.linspace(zmin_r, zmax_r, max(n_contours, 2)) 

476 for lev in levels: 

477 color = _contour_line_color(lev) 

478 fig.add_trace(go.Contour( 

479 x=x_vals, y=y_vals, z=Zmu_raw, 

480 autocontour=False, 

481 contours=dict(coloring="lines", showlabels=False, start=lev, end=lev, size=1e-9), 

482 line=dict(width=1), 

483 colorscale=[[0, color], [1, color]], 

484 showscale=False, hoverinfo="skip" 

485 ), row=r+1, col=c+1) 

486 

487 # data overlays (success/fail) 

488 def _data_vals_for_feature(j_full: int) -> np.ndarray: 

489 nm = feature_names[j_full] 

490 if nm in df_raw_f.columns: 

491 return df_raw_f[nm].to_numpy().astype(float) 

492 return feature_raw_from_artifact_or_reconstruct(ds, j_full, nm, transforms[j_full]).astype(float)[row_mask] \ 

493 if cat_fixed else \ 

494 feature_raw_from_artifact_or_reconstruct(ds, j_full, nm, transforms[j_full]).astype(float) 

495 

496 xd = _data_vals_for_feature(PAY["j"]) 

497 yd = _data_vals_for_feature(PAY["i"]) 

498 show_leg = (r == 0 and c == 0) 

499 fig.add_trace(go.Scattergl( 

500 x=xd[success_mask], y=yd[success_mask], mode="markers", 

501 marker=dict(size=4, color="black", line=dict(width=0)), 

502 name="data (success)", legendgroup="data_succ", showlegend=show_leg, 

503 hovertemplate=("trial_id: %{customdata[0]}<br>" 

504 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

505 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>" 

506 f"{tgt_col}: %{{customdata[1]:.4f}}<extra></extra>"), 

507 customdata=np.column_stack([ 

508 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask], 

509 df_raw_f[tgt_col].to_numpy()[success_mask], 

510 ]) 

511 ), row=r+1, col=c+1) 

512 fig.add_trace(go.Scattergl( 

513 x=xd[fail_mask], y=yd[fail_mask], mode="markers", 

514 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)), 

515 name="data (failed)", legendgroup="data_fail", showlegend=show_leg, 

516 hovertemplate=("trial_id: %{customdata}<br>" 

517 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

518 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>" 

519 "status: failed (NaN target)<extra></extra>"), 

520 customdata=df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask] 

521 ), row=r+1, col=c+1) 

522 

523 # overlays (optimal/suggest) on numeric axes only 

524 if optimal and (optimal_df is not None): 

525 if feature_names[PAY["j"]] in optimal_df.columns and feature_names[PAY["i"]] in optimal_df.columns: 

526 ox = np.asarray(optimal_df[feature_names[PAY["j"]]].values, dtype=float) 

527 oy = np.asarray(optimal_df[feature_names[PAY["i"]]].values, dtype=float) 

528 if np.isfinite(ox).all() and np.isfinite(oy).all(): 

529 pmu = float(optimal_df["pred_target_mean"].values[0]) 

530 psd = float(optimal_df["pred_target_sd"].values[0]) 

531 fig.add_trace(go.Scattergl( 

532 x=ox, y=oy, mode="markers", 

533 marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"), 

534 name="optimal", legendgroup="optimal", showlegend=(r == 0 and c == 0), 

535 hovertemplate=(f"predicted: {pmu:.2g} ± {psd:.2g}<br>" 

536 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

537 f"{feature_names[PAY['i']]}: %{{y:.6g}}<extra></extra>") 

538 ), row=r+1, col=c+1) 

539 if suggest and (suggest_df is not None): 

540 have = (feature_names[PAY["j"]] in suggest_df.columns) and (feature_names[PAY["i"]] in suggest_df.columns) 

541 if have: 

542 sx = np.asarray(suggest_df[feature_names[PAY["j"]]].values, dtype=float) 

543 sy = np.asarray(suggest_df[feature_names[PAY["i"]]].values, dtype=float) 

544 keep_s = np.isfinite(sx) & np.isfinite(sy) 

545 if keep_s.any(): 

546 sx, sy = sx[keep_s], sy[keep_s] 

547 mu_s = suggest_df.loc[keep_s, "pred_target_mean"].values if "pred_target_mean" in suggest_df else None 

548 sd_s = suggest_df.loc[keep_s, "pred_target_sd"].values if "pred_target_sd" in suggest_df else None 

549 ps_s = suggest_df.loc[keep_s, "pred_p_success"].values if "pred_p_success" in suggest_df else None 

550 if (mu_s is not None) and (sd_s is not None) and (ps_s is not None): 

551 custom_s = np.column_stack([mu_s, sd_s, ps_s]) 

552 hover_s = ( 

553 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

554 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>" 

555 "pred: %{customdata[0]:.3g} ± %{customdata[1]:.3g}<br>" 

556 "p(success): %{customdata[2]:.2f}<extra>suggested</extra>" 

557 ) 

558 else: 

559 custom_s = None 

560 hover_s = ( 

561 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

562 f"{feature_names[PAY['i']]}: %{{y:.6g}}<extra>suggested</extra>" 

563 ) 

564 fig.add_trace(go.Scattergl( 

565 x=sx, y=sy, mode="markers", 

566 marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"), 

567 name="suggested", legendgroup="suggested", 

568 showlegend=(r == 0 and c == 0), 

569 customdata=custom_s, hovertemplate=hover_s 

570 ), row=r+1, col=c+1) 

571 

572 # axis types/ranges 

573 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="x", centers=x_vals, is_log=_is_log_feature(PAY["j"])) 

574 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="y", centers=y_vals, is_log=_is_log_feature(PAY["i"])) 

575 

576 elif kind == ("cat","num"): 

577 base = PAY["base"]; x_vals = PAY["x"]; labels = PAY["y"] 

578 nlab = len(labels) 

579 # heatmap (categories on Y) 

580 fig.add_trace(go.Heatmap( 

581 x=x_vals, y=np.arange(nlab), z=Z_t, 

582 coloraxis="coloraxis", zsmooth=False, showscale=False, 

583 hovertemplate=(f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

584 f"{base}: %{{text}}" 

585 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"), 

586 text=np.array(labels)[:, None].repeat(len(x_vals), axis=1), 

587 customdata=Zmu_raw 

588 ), row=r+1, col=c+1) 

589 # p(success) shading 

590 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)): 

591 mask = np.where(Zp < thr, 1.0, np.nan) 

592 fig.add_trace(go.Heatmap( 

593 x=x_vals, y=np.arange(nlab), z=mask, zmin=0, zmax=1, 

594 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]], 

595 showscale=False, hoverinfo="skip" 

596 ), row=r+1, col=c+1) 

597 # categorical ticks 

598 fig.update_yaxes(tickmode="array", tickvals=list(range(nlab)), ticktext=labels, row=r+1, col=c+1) 

599 # data overlays: numeric vs categorical with jitter on Y 

600 if base in df_raw_f.columns and feature_names[PAY["j"]] in df_raw_f.columns: 

601 cat_series = df_raw_f[base].astype("string") 

602 cat_to_idx = {lab: i for i, lab in enumerate(labels)} 

603 y_map = cat_series.map(cat_to_idx) 

604 ok = y_map.notna().to_numpy() 

605 y_idx = y_map.to_numpy(dtype=float) 

606 jitter = 0.10 * (np.random.default_rng(0).standard_normal(size=len(y_idx))) 

607 yj = y_idx + jitter 

608 xd = df_raw_f[feature_names[PAY["j"]]].to_numpy(dtype=float) 

609 show_leg = (r == 0 and c == 0) 

610 fig.add_trace(go.Scattergl( 

611 x=xd[success_mask & ok], y=yj[success_mask & ok], mode="markers", 

612 marker=dict(size=4, color="black", line=dict(width=0)), 

613 name="data (success)", legendgroup="data_succ", showlegend=show_leg, 

614 hovertemplate=("trial_id: %{customdata[0]}<br>" 

615 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

616 f"{base}: %{{customdata[1]}}<br>" 

617 f"{tgt_col}: %{{customdata[2]:.4f}}<extra></extra>"), 

618 customdata=np.column_stack([ 

619 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask & ok], 

620 cat_series.to_numpy()[success_mask & ok], 

621 df_raw_f[tgt_col].to_numpy()[success_mask & ok], 

622 ]) 

623 ), row=r+1, col=c+1) 

624 fig.add_trace(go.Scattergl( 

625 x=xd[fail_mask & ok], y=yj[fail_mask & ok], mode="markers", 

626 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)), 

627 name="data (failed)", legendgroup="data_fail", showlegend=show_leg, 

628 hovertemplate=("trial_id: %{customdata[0]}<br>" 

629 f"{feature_names[PAY['j']]}: %{{x:.6g}}<br>" 

630 f"{base}: %{{customdata[1]}}<br>" 

631 "status: failed (NaN target)<extra></extra>"), 

632 customdata=np.column_stack([ 

633 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask & ok], 

634 cat_series.to_numpy()[fail_mask & ok], 

635 ]) 

636 ), row=r+1, col=c+1) 

637 # axes: x numeric; y categorical range 

638 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="x", centers=x_vals, is_log=_is_log_feature(PAY["j"])) 

639 fig.update_yaxes(range=[-0.5, nlab - 0.5], row=r+1, col=c+1) 

640 

641 elif kind == ("num","cat"): 

642 base = PAY["base"]; y_vals = PAY["y"]; labels = PAY["x"] 

643 nlab = len(labels) 

644 # heatmap (categories on X) 

645 fig.add_trace(go.Heatmap( 

646 x=np.arange(nlab), y=y_vals, z=Z_t, 

647 coloraxis="coloraxis", zsmooth=False, showscale=False, 

648 hovertemplate=(f"{base}: %{{text}}<br>" 

649 f"{feature_names[PAY['i']]}: %{{y:.6g}}" 

650 "<br>E[target|success]: %{customdata:.3f}<extra></extra>"), 

651 text=np.array(labels)[None, :].repeat(len(y_vals), axis=0), 

652 customdata=Zmu_raw 

653 ), row=r+1, col=c+1) 

654 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)): 

655 mask = np.where(Zp < thr, 1.0, np.nan) 

656 fig.add_trace(go.Heatmap( 

657 x=np.arange(nlab), y=y_vals, z=mask, zmin=0, zmax=1, 

658 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]], 

659 showscale=False, hoverinfo="skip" 

660 ), row=r+1, col=c+1) 

661 fig.update_xaxes(tickmode="array", tickvals=list(range(nlab)), ticktext=labels, row=r+1, col=c+1) 

662 # data overlays with jitter on X 

663 if base in df_raw_f.columns and feature_names[PAY["i"]] in df_raw_f.columns: 

664 cat_series = df_raw_f[base].astype("string") 

665 cat_to_idx = {lab: i for i, lab in enumerate(labels)} 

666 x_map = cat_series.map(cat_to_idx) 

667 ok = x_map.notna().to_numpy() 

668 x_idx = x_map.to_numpy(dtype=float) 

669 jitter = 0.10 * (np.random.default_rng(0).standard_normal(size=len(x_idx))) 

670 xj = x_idx + jitter 

671 yd = df_raw_f[feature_names[PAY["i"]]].to_numpy(dtype=float) 

672 show_leg = (r == 0 and c == 0) 

673 fig.add_trace(go.Scattergl( 

674 x=xj[success_mask & ok], y=yd[success_mask & ok], mode="markers", 

675 marker=dict(size=4, color="black", line=dict(width=0)), 

676 name="data (success)", legendgroup="data_succ", showlegend=show_leg, 

677 hovertemplate=("trial_id: %{customdata[0]}<br>" 

678 f"{base}: %{{customdata[1]}}<br>" 

679 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>" 

680 f"{tgt_col}: %{{customdata[2]:.4f}}<extra></extra>"), 

681 customdata=np.column_stack([ 

682 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[success_mask & ok], 

683 cat_series.to_numpy()[success_mask & ok], 

684 df_raw_f[tgt_col].to_numpy()[success_mask & ok], 

685 ]) 

686 ), row=r+1, col=c+1) 

687 fig.add_trace(go.Scattergl( 

688 x=xj[fail_mask & ok], y=yd[fail_mask & ok], mode="markers", 

689 marker=dict(size=5, color="red", line=dict(color="black", width=0.8)), 

690 name="data (failed)", legendgroup="data_fail", showlegend=show_leg, 

691 hovertemplate=("trial_id: %{customdata[0]}<br>" 

692 f"{base}: %{{customdata[1]}}<br>" 

693 f"{feature_names[PAY['i']]}: %{{y:.6g}}<br>" 

694 "status: failed (NaN target)<extra></extra>"), 

695 customdata=np.column_stack([ 

696 df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()[fail_mask & ok], 

697 cat_series.to_numpy()[fail_mask & ok], 

698 ]) 

699 ), row=r+1, col=c+1) 

700 # axes: x categorical; y numeric 

701 fig.update_xaxes(range=[-0.5, nlab - 0.5], row=r+1, col=c+1) 

702 _update_axis_type_and_range(fig, row=r+1, col=c+1, axis="y", centers=y_vals, is_log=_is_log_feature(PAY["i"])) 

703 

704 elif kind == ("cat","cat"): 

705 labels_y = PAY["y"] 

706 labels_x = PAY["x"] 

707 ny, nx = len(labels_y), len(labels_x) 

708 

709 # Build customdata carrying (row_label, col_label) for hovertemplate. 

710 custom = np.dstack(( 

711 np.array(labels_y, dtype=object)[:, None].repeat(nx, axis=1), 

712 np.array(labels_x, dtype=object)[None, :].repeat(ny, axis=0), 

713 )) 

714 

715 # Heatmap over categorical indices 

716 fig.add_trace(go.Heatmap( 

717 x=np.arange(nx), 

718 y=np.arange(ny), 

719 z=Z_t, 

720 coloraxis="coloraxis", 

721 zsmooth=False, 

722 showscale=False, 

723 hovertemplate=( 

724 "row: %{customdata[0]}<br>" 

725 "col: %{customdata[1]}<br>" 

726 "E[target|success]: %{z:.3f}<extra></extra>" 

727 ), 

728 customdata=custom, 

729 ), row=r+1, col=c+1) 

730 

731 # p(success) shading overlays 

732 for thr, alpha in ((0.5, 0.25), (0.8, 0.40)): 

733 mask = np.where(Zp < thr, 1.0, np.nan) 

734 fig.add_trace(go.Heatmap( 

735 x=np.arange(nx), 

736 y=np.arange(ny), 

737 z=mask, 

738 zmin=0, 

739 zmax=1, 

740 colorscale=[[0, "rgba(0,0,0,0)"], [1, f"rgba(128,128,128,{alpha})"]], 

741 showscale=False, 

742 hoverinfo="skip", 

743 ), row=r+1, col=c+1) 

744 

745 # Categorical tick labels on both axes 

746 fig.update_xaxes( 

747 tickmode="array", 

748 tickvals=list(range(nx)), 

749 ticktext=labels_x, 

750 range=[-0.5, nx - 0.5], 

751 row=r+1, 

752 col=c+1, 

753 ) 

754 fig.update_yaxes( 

755 tickmode="array", 

756 tickvals=list(range(ny)), 

757 ticktext=labels_y, 

758 range=[-0.5, ny - 0.5], 

759 row=r+1, 

760 col=c+1, 

761 ) 

762 

763 # --- outer axis labels 

764 def _panel_title(kind: str, key: object) -> str: 

765 return feature_names[int(key)] if kind == "num" else str(key) 

766 

767 for c, (_, key_c) in enumerate(panels): 

768 fig.update_xaxes(title_text=_panel_title(panels[c][0], key_c), row=k, col=c+1) 

769 for r, (kind_r, key_r) in enumerate(panels): 

770 fig.update_yaxes(title_text=_panel_title(kind_r, key_r), row=r+1, col=1) 

771 

772 # --- title 

773 def _fmt_c(v): 

774 if isinstance(v, slice): 

775 a = f"{v.start:g}" if v.start is not None else "" 

776 b = f"{v.stop:g}" if v.stop is not None else "" 

777 return f"[{a},{b}]" 

778 if isinstance(v, (list, tuple, np.ndarray)): 

779 try: 

780 return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]" 

781 except Exception: 

782 return "[" + ",".join(map(str, v)) + "]" 

783 return str(v) 

784 

785 title_parts = [f"2D partial dependence of expected {tgt_col}"] 

786 

787 # numeric constraints shown 

788 for name, val in kw_num.items(): 

789 title_parts.append(f"{name}={_fmt_c(val)}") 

790 # categorical constraints: fixed shown as base=Label; allowed ranges omitted in title 

791 for base, lab in cat_fixed.items(): 

792 title_parts.append(f"{base}={lab}") 

793 title = " — ".join([title_parts[0], ", ".join(title_parts[1:])]) if len(title_parts) > 1 else title_parts[0] 

794 

795 # --- layout 

796 cell = 250 

797 z_title = "E[target|success]" + (" (log10)" if use_log_scale_for_target else "") 

798 if use_log_scale_for_target and global_shift > 0: 

799 z_title += f" (shift Δ={global_shift:.3g})" 

800 

801 width = width if (width and width > 0) else cell * k 

802 width = max(width, 400) 

803 height = height if (height and height > 0) else cell * k 

804 height = max(height, 400) 

805 

806 fig.update_layout( 

807 template="simple_white", 

808 width=width, 

809 height=height, 

810 title=title, 

811 legend_title_text="", 

812 coloraxis=dict( 

813 colorscale=colorscale, 

814 cmin=cmin_t, cmax=cmax_t, 

815 colorbar=dict( 

816 title=z_title, 

817 thickness=10, # thinner bar 

818 len=0.55, # shorter bar (fraction of plot height) 

819 lenmode="fraction", 

820 x=1.02, y=0.5, # just right of plot, vertically centered 

821 xanchor="left", yanchor="middle", 

822 ), 

823 ), 

824 legend=dict( 

825 orientation="v", 

826 x=1.02, xanchor="left", # to the right of the colorbar 

827 y=1.0, yanchor="top", 

828 bgcolor="rgba(255,255,255,0.85)" 

829 ), 

830 margin=dict(t=90, r=100), # room for title + legend + colorbar 

831 ) 

832 

833 if output: 

834 output = Path(output) 

835 output.parent.mkdir(parents=True, exist_ok=True) 

836 fig.write_html(str(output), include_plotlyjs="cdn") 

837 if show: 

838 fig.show("browser") 

839 return fig 

840 

841 

def plot1d(
    model: xr.Dataset | Path | str,
    output: Path | None = None,
    csv_out: Path | None = None,
    grid_size: int = 300,
    line_color: str = "rgb(31,119,180)",
    band_alpha: float = 0.25,
    figure_height_per_row_px: int = 320,
    show: bool = False,
    use_log_scale_for_target_y: bool = True,  # log-y for target
    log_y_epsilon: float = 1e-9,
    optimal: bool = True,
    suggest: int = 0,
    width: int | None = None,
    height: int | None = None,
    seed: int | None = 42,
    **kwargs,
) -> go.Figure:
    """
    Plot vertical 1D partial-dependence panels of E[target|success].

    One panel is drawn per *free* feature, stacked in a single column:

    * numeric features get the posterior mean curve with a ±2σ band, the
      observed successful trials, failed trials (NaN target) drawn in a band
      above the data, and shading where the predicted success probability is
      low;
    * each one-hot categorical base (e.g. ``language``) that is not fixed gets
      a single categorical panel with one mean ±2σ point per label.

    Constraints are keyword arguments keyed by feature name (matched exactly,
    or after lower-casing and stripping non-alphanumerics) or by categorical
    base name:

    * scalar             -> fix the feature at that value and hide its panel;
    * ``slice(lo, hi)``  -> restrict the sweep and x-range; ``None`` at either
      end leaves that side open;
    * list/tuple/ndarray -> sweep only those discrete values (values outside
      the filtered training 1-99 % range are dropped);
    * base="Label"       -> fix the category: training rows are filtered to it
      and its panel is removed (e.g. ``language="Linear A"``);
    * base=[labels, ...] -> show only those labels in the categorical panel.

    The same keyword arguments are forwarded to ``opt.optimal`` /
    ``opt.suggest`` to compute the overlay markers.

    Args:
        model: Fitted model artifact, or a path loadable by ``xr.load_dataset``.
        output: If given, write the figure there as HTML (plotly.js from CDN).
        csv_out: If given, write the tidy per-grid-point predictions as CSV.
        grid_size: Number of sweep points per numeric panel.
        line_color: Colour of the mean curve; the ±2σ band uses the same colour
            (via ``_rgb_to_rgba``) with ``band_alpha`` opacity.
        band_alpha: Opacity of the ±2σ band.
        figure_height_per_row_px: Default figure height contributed per panel.
        show: Open the figure in a browser.
        use_log_scale_for_target_y: Use a log y-axis for the target.
        log_y_epsilon: Floor applied to plotted y-values when the y-axis is log.
        optimal: Overlay the best point returned by ``opt.optimal``.
        suggest: Number of ``opt.suggest`` candidates to overlay (0 = none).
        width: Figure width in px (default 1200).
        height: Figure height in px (default ``figure_height_per_row_px`` x panels).
        seed: Seed forwarded to ``opt.optimal`` / ``opt.suggest``.
        **kwargs: Feature / categorical constraints (see above).

    Returns:
        The plotly figure.

    Raises:
        ValueError: invalid categorical constraint, no training rows match the
            fixed categories, or every feature is fixed (nothing to plot).
    """
    import re

    ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
    pred_success, pred_loss = _build_predictors(ds)

    feature_names = [str(n) for n in ds["feature"].values.tolist()]
    transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
    X_mean = ds["feature_mean"].values.astype(float)
    X_std = ds["feature_std"].values.astype(float)

    df_raw = _raw_dataframe_from_dataset(ds)
    Xn_train = ds["Xn_train"].values.astype(float)
    n_rows, p = Xn_train.shape

    # --- one-hot categorical groups ---
    # { base: {"labels": [...], "name_by_label": {label: member}, "members": [...]} }
    groups = opt._onehot_groups(feature_names)
    bases = set(groups.keys())
    name_to_idx = {name: j for j, name in enumerate(feature_names)}

    # --- canonicalize kwargs: numeric (by feature name) vs categorical (by base) ---
    idx_map = _canon_key_set(ds)
    kw_num_raw: dict[str, object] = {}
    kw_cat_raw: dict[str, object] = {}
    for key, val in kwargs.items():
        if key in bases:
            kw_cat_raw[key] = val
            continue
        if key in idx_map:
            kw_num_raw[idx_map[key]] = val
            continue
        norm_key = re.sub(r"[^a-z0-9]+", "", str(key).lower())
        if norm_key in idx_map:
            kw_num_raw[idx_map[norm_key]] = val
        # Unrecognized keys are ignored here (they are still forwarded to opt.*).

    # --- resolve categorical constraints: fixed (single) vs allowed (multiple) ---
    cat_fixed: dict[str, str] = {}
    cat_allowed: dict[str, list[str]] = {}
    for base, val in kw_cat_raw.items():
        labels = groups[base]["labels"]
        if isinstance(val, str):
            if val not in labels:
                raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}")
            cat_fixed[base] = val
        elif isinstance(val, (list, tuple, set)):
            chosen = [x for x in val if isinstance(x, str) and x in labels]
            if not chosen:
                raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}")
            # multiple -> allowed subset (NOT fixed); dict.fromkeys de-duplicates, keeping order
            cat_allowed[base] = list(dict.fromkeys(chosen))
        else:
            raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.")

    # --- filter rows by fixed categoricals (affects medians/percentiles & overlays) ---
    row_mask = np.ones(n_rows, dtype=bool)
    for base, label in cat_fixed.items():
        if base in df_raw.columns:
            # Compare element-wise against the scalar (no index alignment); missing
            # labels compare as <NA>, which we treat as "not this category".
            matches = df_raw[base].astype("string").eq(label).fillna(False)
            row_mask &= matches.to_numpy(dtype=bool)
        else:
            member_name = groups[base]["name_by_label"][label]
            j = name_to_idx[member_name]
            raw_j = feature_raw_from_artifact_or_reconstruct(ds, j, member_name, transforms[j]).astype(float)
            row_mask &= (raw_j >= 0.5)
    if not row_mask.any():
        # Without rows there is no median base point and no percentile range.
        raise ValueError(
            f"No training rows match the fixed categories {cat_fixed}; nothing to condition on."
        )

    df_raw_f = df_raw.loc[row_mask].reset_index(drop=True)
    Xn_train_f = Xn_train[row_mask, :]

    # --- helpers to transform original <-> standardized for feature j ---
    def _orig_to_std(j: int, x) -> np.ndarray:
        """Map original-unit value(s) of feature j into the model's standardized space."""
        x = np.asarray(x, dtype=float)
        if transforms[j] == "log10":
            x = np.log10(np.where(x <= 0, np.nan, x))  # non-positive -> NaN
        return (x - X_mean[j]) / X_std[j]

    def _std_to_orig(j: int, val_std: float) -> float:
        """Inverse of ``_orig_to_std`` for a single standardized value."""
        vi = val_std * X_std[j] + X_mean[j]
        return float(_inverse_transform(transforms[j], np.array([vi]))[0])

    # --- numeric constraint split (STANDARDIZED) ---
    fixed_scalars: dict[int, float] = {}
    range_windows: dict[int, tuple[float, float]] = {}
    choice_values: dict[int, np.ndarray] = {}
    for name, val in kw_num_raw.items():
        if name not in name_to_idx:
            continue
        j = name_to_idx[name]
        if isinstance(val, slice):
            # An open end (None) is unbounded on that side.
            lo = -np.inf if val.start is None else float(_orig_to_std(j, float(val.start)))
            hi = np.inf if val.stop is None else float(_orig_to_std(j, float(val.stop)))
            range_windows[j] = (min(lo, hi), max(lo, hi))
        elif isinstance(val, (list, tuple, np.ndarray)):
            choice_values[j] = np.asarray(_orig_to_std(j, np.asarray(val, dtype=float)), dtype=float)
        else:
            fixed_scalars[j] = float(_orig_to_std(j, float(val)))

    # --- apply categorical fixed as standardized scalar fixes on each one-hot member ---
    for base, label in cat_fixed.items():
        for lab in groups[base]["labels"]:
            member_name = groups[base]["name_by_label"][lab]
            j = name_to_idx[member_name]
            raw_val = 1.0 if lab == label else 0.0
            fixed_scalars[j] = float(_orig_to_std(j, raw_val))

    # --- overlays conditioned on the same kwargs (numeric + categorical) ---
    optimal_df = opt.optimal(model, count=1, seed=seed, **kwargs) if optimal else None
    suggest_df = opt.suggest(model, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None

    # --- base standardized point (median over filtered rows), then apply scalar fixes ---
    base_std = np.median(Xn_train_f, axis=0)
    for j, vstd in fixed_scalars.items():
        base_std[j] = vstd

    # --- one-hot member names: detector output plus a "base=" prefix fallback ---
    onehot_member_names: set[str] = set()
    for base, g in groups.items():
        onehot_member_names.update(g["members"])
        prefix = f"{base}="
        onehot_member_names.update(nm for nm in feature_names if nm.startswith(prefix))

    # --- panels: numeric free features + categorical bases not fixed ---
    free_numeric_idx = [
        j for j, nm in enumerate(feature_names)
        if (j not in fixed_scalars) and (nm not in onehot_member_names)
    ]
    # Iterate the groups dict (not the `bases` set) so panel order is deterministic.
    free_cat_bases = [b for b in groups if b not in cat_fixed]
    panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases]
    if not panels:
        raise ValueError("All features are fixed (or only single-category categoricals remain); nothing to plot.")

    # --- empirical 1-99% from filtered rows for numeric bounds ---
    p01p99 = [np.percentile(Xn_train_f[:, j], [1, 99]) for j in range(p)]

    def _grid_1d(j: int, n: int) -> np.ndarray:
        """Standardized sweep grid for feature j, honouring choices / range windows."""
        p01, p99 = p01p99[j]
        if j in choice_values:
            vals = np.asarray(choice_values[j], dtype=float)
            vals = vals[(vals >= p01) & (vals <= p99)]
            return np.unique(vals) if vals.size else np.array([np.median(Xn_train_f[:, j])], dtype=float)
        lo, hi = p01, p99
        if j in range_windows:
            rlo, rhi = range_windows[j]
            lo, hi = max(lo, rlo), min(hi, rhi)
        if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
            lo, hi = p01, max(p01 + 1e-9, p99)
        return np.linspace(lo, hi, n)

    # --- figure scaffold with clean titles ---
    def _panel_title(kind: str, key: object) -> str:
        return feature_names[int(key)] if kind == "num" else str(key)

    fig = make_subplots(
        rows=len(panels),
        cols=1,
        shared_xaxes=False,
        subplot_titles=[_panel_title(kind, key) for kind, key in panels],
    )

    # --- masks/data from filtered rows ---
    tgt_col = str(ds.attrs["target"])
    success_mask = ~pd.isna(df_raw_f[tgt_col]).to_numpy()
    fail_mask = ~success_mask
    losses_success = df_raw_f.loc[success_mask, tgt_col].to_numpy().astype(float)
    trial_ids = df_raw_f.get("trial_id", pd.Series(np.arange(len(df_raw_f)))).to_numpy()
    trial_ids_success = trial_ids[success_mask]
    trial_ids_fail = trial_ids[fail_mask]
    band_fill_rgba = _rgb_to_rgba(line_color, band_alpha)

    tidy_rows: list[dict] = []

    def _raw_feature_values(j: int) -> np.ndarray:
        """Raw values of feature j for the filtered rows (dataframe column or artifact)."""
        name = feature_names[j]
        if name in df_raw_f.columns:
            return df_raw_f[name].to_numpy().astype(float)
        full_vals = feature_raw_from_artifact_or_reconstruct(ds, j, name, transforms[j]).astype(float)
        return full_vals[row_mask]

    def _to_plot_space(mu: np.ndarray, sd: np.ndarray):
        """Return (mean, mean-2σ, mean+2σ, observed losses) in y-axis space."""
        lo = mu - 2.0 * sd
        hi = mu + 2.0 * sd
        if use_log_scale_for_target_y:
            # A log axis cannot show values <= 0: floor everything at log_y_epsilon.
            return (np.maximum(mu, log_y_epsilon), np.maximum(lo, log_y_epsilon),
                    np.maximum(hi, log_y_epsilon), np.maximum(losses_success, log_y_epsilon))
        return mu, lo, hi, losses_success

    def _y_bounds(lo_plot, hi_plot, losses_plot):
        """Return (y0, y1, y_failed_band): axis range plus the row where failed trials are drawn."""
        arrays = [lo_plot, hi_plot] + ([losses_plot] if losses_plot.size else [])
        y_low = float(np.nanmin([np.nanmin(a) for a in arrays]))
        y_high = float(np.nanmax([np.nanmax(a) for a in arrays]))
        span = y_high - y_low + 1e-12
        if use_log_scale_for_target_y:
            y0 = max(y_low / 1.5, log_y_epsilon)
            y_band = y_high * 1.2 + span * 0.3
            if y_band <= log_y_epsilon:
                y_band = max(10.0 * log_y_epsilon, y_high * 2.0)
            y1 = y_band + 0.05 * span
        else:
            pad = 0.05 * span
            y0 = y_low - pad
            y_band = (y_high + pad) + span * 0.08
            y1 = y_band + 0.02 * span
        return y0, y1, y_band

    def _numeric_panel(row_pos: int, j: int, show_legend: bool) -> None:
        """Draw the PD curve, ±2σ band, observed data and overlays for numeric feature j."""
        fname = feature_names[j]
        grid = _grid_1d(j, grid_size)
        Xn_grid = np.repeat(base_std[None, :], len(grid), axis=0)
        Xn_grid[:, j] = grid

        p_grid = pred_success(Xn_grid)
        mu_grid, sd_grid = pred_loss(Xn_grid, include_observation_noise=True)

        x_internal = grid * X_std[j] + X_mean[j]
        x_display = _inverse_transform(transforms[j], x_internal)

        mu_plot, lo_plot, hi_plot, losses_s_plot = _to_plot_space(mu_grid, sd_grid)
        y0_plot, y1_plot, y_failed_band = _y_bounds(lo_plot, hi_plot, losses_s_plot)

        _add_low_success_shading_1d(fig, row_pos, x_display, p_grid, y0_plot, y1_plot)

        fig.add_trace(go.Scatter(x=x_display, y=lo_plot, mode="lines",
                                 line=dict(width=0, color=line_color),
                                 name="±2σ", legendgroup="band", showlegend=False, hoverinfo="skip"),
                      row=row_pos, col=1)
        fig.add_trace(go.Scatter(x=x_display, y=hi_plot, mode="lines", fill="tonexty",
                                 line=dict(width=0, color=line_color), fillcolor=band_fill_rgba,
                                 name="±2σ", legendgroup="band", showlegend=show_legend,
                                 hovertemplate="E[target|success]: %{y:.3f}<extra>±2σ</extra>"),
                      row=row_pos, col=1)
        fig.add_trace(go.Scatter(x=x_display, y=mu_plot, mode="lines",
                                 line=dict(width=2, color=line_color),
                                 name="E[target|success]", legendgroup="mean", showlegend=show_legend,
                                 hovertemplate=f"{fname}: %{{x:.6g}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"),
                      row=row_pos, col=1)

        # experimental points
        x_data_all = _raw_feature_values(j)

        x_succ = x_data_all[success_mask]
        if x_succ.size:
            fig.add_trace(go.Scattergl(
                x=x_succ, y=losses_s_plot, mode="markers",
                marker=dict(size=5, color="black", line=dict(width=0)),
                name="data (success)", legendgroup="data_s", showlegend=show_legend,
                hovertemplate=("trial_id: %{customdata}<br>"
                               f"{fname}: %{{x:.6g}}<br>"
                               f"{tgt_col}: %{{y:.4f}}<extra></extra>"),
                customdata=trial_ids_success,
            ), row=row_pos, col=1)

        x_fail = x_data_all[fail_mask]
        if x_fail.size:
            y_fail_plot = np.full_like(x_fail, y_failed_band, dtype=float)
            fig.add_trace(go.Scattergl(
                x=x_fail, y=y_fail_plot, mode="markers",
                marker=dict(size=6, color="red", line=dict(color="black", width=0.8)),
                name="data (failed)", legendgroup="data_f", showlegend=show_legend,
                hovertemplate=("trial_id: %{customdata}<br>"
                               f"{fname}: %{{x:.6g}}<br>"
                               "status: failed (NaN target)<extra></extra>"),
                customdata=trial_ids_fail,
            ), row=row_pos, col=1)

        # overlays
        if optimal_df is not None and fname in optimal_df.columns:
            y_sd = optimal_df["pred_target_sd"].values
            fig.add_trace(go.Scattergl(
                x=optimal_df[fname].values, y=optimal_df["pred_target_mean"].values, mode="markers",
                marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                name="optimal", legendgroup="optimal", showlegend=show_legend,
                hovertemplate=(f"predicted: %{{y:.3g}} ± {y_sd[0]:.3g}<br>"
                               f"{fname}: %{{x:.6g}}<extra></extra>"),
            ), row=row_pos, col=1)

        if suggest_df is not None and fname in suggest_df.columns:
            sd_sug = np.asarray(suggest_df["pred_target_sd"].values, dtype=float)
            fig.add_trace(go.Scattergl(
                x=suggest_df[fname].values, y=suggest_df["pred_target_mean"].values, mode="markers",
                marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                name="suggested", legendgroup="suggested", showlegend=show_legend,
                # σ differs per suggested point, so it is read from customdata.
                customdata=sd_sug,
                hovertemplate=("predicted: %{y:.3g} ± %{customdata:.3g}<br>"
                               f"{fname}: %{{x:.6g}}<extra></extra>"),
            ), row=row_pos, col=1)

        # axes
        _maybe_log_axis(fig, row_pos, 1, fname, axis="x", transforms=transforms, j=j)
        fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1)
        _set_yaxis_range(fig, row=row_pos, col=1,
                         y0=y0_plot, y1=y1_plot,
                         log=use_log_scale_for_target_y, eps=log_y_epsilon)

        # restrict x-range if constrained
        x_min_override = x_max_override = None
        if j in range_windows:
            lo_std, hi_std = range_windows[j]
            # Open-ended (or unrepresentable) window ends fall back to the swept grid's ends.
            if not np.isfinite(lo_std):
                lo_std = float(grid[0])
            if not np.isfinite(hi_std):
                hi_std = float(grid[-1])
            a, b = _std_to_orig(j, lo_std), _std_to_orig(j, hi_std)
            x_min_override, x_max_override = min(a, b), max(a, b)
        elif j in choice_values:
            origs = _inverse_transform(transforms[j], choice_values[j] * X_std[j] + X_mean[j])
            x_min_override = float(np.min(origs))
            x_max_override = float(np.max(origs))

        if (x_min_override is not None) and (x_max_override is not None):
            if transforms[j] == "log10":
                x0 = max(x_min_override, 1e-12)
                x1 = max(x_max_override, x0 * (1 + 1e-9))
                pad = (x1 / x0) ** 0.03  # multiplicative padding on a log axis
                fig.update_xaxes(type="log",
                                 range=[np.log10(x0 / pad), np.log10(x1 * pad)],
                                 row=row_pos, col=1)
            else:
                span = (x_max_override - x_min_override) or 1.0
                pad = 0.02 * span
                fig.update_xaxes(range=[x_min_override - pad, x_max_override + pad],
                                 row=row_pos, col=1)

        fig.update_xaxes(title_text=fname, row=row_pos, col=1)

        tidy_rows.extend(
            {
                "feature": fname,
                "x_display": float(xd),
                "x_internal": float(xi),
                "target_conditional_mean": float(mu_i),
                "target_conditional_sd": float(sd_i),
                "success_probability": float(p_i),
            }
            for xd, xi, mu_i, sd_i, p_i in zip(x_display, x_internal, mu_grid, sd_grid, p_grid)
        )

    def _categorical_panel(row_pos: int, base: str, show_legend: bool) -> None:
        """Draw one mean ±2σ point per label of categorical base, plus data and overlays."""
        labels_all = groups[base]["labels"]
        labels = cat_allowed.get(base, labels_all)
        name_by_label = groups[base]["name_by_label"]

        # Standardized design: base point with the one-hot block set to each label in turn.
        Xn_grid = np.repeat(base_std[None, :], len(labels), axis=0)
        for r, lab in enumerate(labels):
            for lab2 in labels_all:
                j2 = name_to_idx[name_by_label[lab2]]
                raw_val = 1.0 if (lab2 == lab) else 0.0
                Xn_grid[r, j2] = (raw_val - X_mean[j2]) / X_std[j2]

        p_vec = pred_success(Xn_grid)
        mu_vec, sd_vec = pred_loss(Xn_grid, include_observation_noise=True)

        mu_plot, lo_plot, hi_plot, losses_s_plot = _to_plot_space(mu_vec, sd_vec)
        y0_plot, y1_plot, y_failed_band = _y_bounds(lo_plot, hi_plot, losses_s_plot)

        # x positions are 0..K-1 with tick labels = category names
        x_pos = np.arange(len(labels), dtype=float)

        # per-category shading where predicted success probability is below a threshold
        axis_suffix = "" if row_pos == 1 else str(row_pos)
        for thr, alpha in ((0.8, 0.40), (0.5, 0.25)):
            for k_i, p_i in enumerate(p_vec):
                if p_i < thr:
                    fig.add_shape(
                        type="rect",
                        xref=f"x{axis_suffix}",
                        yref=f"y{axis_suffix}",
                        x0=k_i - 0.5, x1=k_i + 0.5,
                        y0=y0_plot, y1=y1_plot,
                        line=dict(width=0),
                        fillcolor=f"rgba(128,128,128,{alpha})",
                        layer="below",
                        row=row_pos, col=1,
                    )

        # mean with error bars (±2σ)
        fig.add_trace(go.Scatter(
            x=x_pos, y=mu_plot, mode="lines+markers",
            line=dict(width=2, color=line_color),
            marker=dict(size=7, color=line_color),
            error_y=dict(type="data", array=(hi_plot - mu_plot), arrayminus=(mu_plot - lo_plot), visible=True),
            name="E[target|success]", legendgroup="mean", showlegend=show_legend,
            hovertemplate=(f"{base}: %{{text}}<br>E[target|success]: %{{y:.3f}}"
                           "<br>±2σ shown as error bar<extra></extra>"),
            text=labels,
        ), row=row_pos, col=1)

        # experimental points: map each row's label to its x index
        if base in df_raw_f.columns:
            lab_series = df_raw_f[base].astype("string")
        else:
            # reconstruct the label from the one-hot members (column or artifact)
            onehot = np.column_stack([_raw_feature_values(name_to_idx[name_by_label[lab]]) for lab in labels_all])
            lab_series = pd.Series([labels_all[i] for i in onehot.argmax(axis=1)], dtype="string")

        label_to_idx = {lab: i for i, lab in enumerate(labels)}
        x_idx_all = lab_series.map(lambda s: label_to_idx.get(str(s), np.nan)).to_numpy(dtype=float)
        x_idx_succ = x_idx_all[success_mask]
        x_idx_fail = x_idx_all[fail_mask]

        def _idx_to_text(idx: np.ndarray) -> list[str]:
            return [labels[int(i)] if np.isfinite(i) and int(i) < len(labels) else "?" for i in idx]

        # horizontal jitter for visibility (fixed seed -> reproducible figure)
        rng = np.random.default_rng(0)

        def _jitter(n: int) -> np.ndarray:
            return (rng.random(n) - 0.5) * 0.15

        if x_idx_succ.size:
            fig.add_trace(go.Scattergl(
                x=x_idx_succ + _jitter(x_idx_succ.size),
                y=losses_s_plot,
                mode="markers",
                marker=dict(size=5, color="black", line=dict(width=0)),
                name="data (success)", legendgroup="data_s", showlegend=show_legend,
                hovertemplate=("trial_id: %{customdata}<br>"
                               f"{base}: %{{text}}<br>"
                               f"{tgt_col}: %{{y:.4f}}<extra></extra>"),
                text=_idx_to_text(x_idx_succ),
                customdata=trial_ids_success,
            ), row=row_pos, col=1)

        if x_idx_fail.size:
            y_fail_plot = np.full_like(x_idx_fail, y_failed_band, dtype=float)
            fig.add_trace(go.Scattergl(
                x=x_idx_fail + _jitter(x_idx_fail.size), y=y_fail_plot, mode="markers",
                marker=dict(size=6, color="red", line=dict(color="black", width=0.8)),
                name="data (failed)", legendgroup="data_f", showlegend=show_legend,
                hovertemplate=("trial_id: %{customdata}<br>"
                               f"{base}: %{{text}}<br>"
                               "status: failed (NaN target)<extra></extra>"),
                text=_idx_to_text(x_idx_fail),
                customdata=trial_ids_fail,
            ), row=row_pos, col=1)

        # overlays for categorical base: map label to x index
        if optimal_df is not None and (base in optimal_df.columns):
            lab_opt = str(optimal_df[base].values[0])
            if lab_opt in label_to_idx:
                y_sd = optimal_df["pred_target_sd"].values
                fig.add_trace(go.Scattergl(
                    x=[float(label_to_idx[lab_opt])], y=optimal_df["pred_target_mean"].values, mode="markers",
                    marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"),
                    name="optimal", legendgroup="optimal", showlegend=show_legend,
                    hovertemplate=(f"predicted: %{{y:.3g}} ± {y_sd[0]:.3g}<br>"
                                   f"{base}: {lab_opt}<extra></extra>"),
                ), row=row_pos, col=1)

        if suggest_df is not None and (base in suggest_df.columns):
            labs_sug = suggest_df[base].astype(str).tolist()
            keep_mask = np.array([lab in label_to_idx for lab in labs_sug], dtype=bool)
            if keep_mask.any():
                xs = [label_to_idx[lab] for lab in labs_sug if lab in label_to_idx]
                y_sug = suggest_df.loc[keep_mask, "pred_target_mean"].values
                fig.add_trace(go.Scattergl(
                    x=np.array(xs, dtype=float), y=y_sug, mode="markers",
                    marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"),
                    name="suggested", legendgroup="suggested", showlegend=show_legend,
                    # second literal is NOT an f-string, so plotly placeholders use single braces
                    hovertemplate=(f"{base}: %{{text}}<br>"
                                   "predicted: %{y:.3g}<extra>suggested</extra>"),
                    text=[labels[int(i)] for i in xs],
                ), row=row_pos, col=1)

        # axes: categorical ticks
        fig.update_xaxes(
            tickmode="array",
            tickvals=x_pos.tolist(),
            ticktext=labels,
            row=row_pos, col=1,
        )
        fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1)
        _set_yaxis_range(fig, row=row_pos, col=1,
                         y0=y0_plot, y1=y1_plot,
                         log=use_log_scale_for_target_y, eps=log_y_epsilon)
        fig.update_xaxes(title_text=base, row=row_pos, col=1)

        tidy_rows.extend(
            {
                "feature": base,
                "x_display": str(lab),
                "x_internal": float("nan"),
                "target_conditional_mean": float(mu_i),
                "target_conditional_sd": float(sd_i),
                "success_probability": float(p_i),
            }
            for lab, mu_i, sd_i, p_i in zip(labels, mu_vec, sd_vec, p_vec)
        )

    for row_pos, (kind, key) in enumerate(panels, start=1):
        show_legend = (row_pos == 1)  # legend entries only from the first panel
        if kind == "num":
            _numeric_panel(row_pos, key, show_legend)
        else:
            _categorical_panel(row_pos, key, show_legend)

    # title w/ constraints summary
    def _fmt_c(v) -> str:
        if isinstance(v, slice):
            a = "" if v.start is None else f"{float(v.start):g}"
            b = "" if v.stop is None else f"{float(v.stop):g}"
            return f"[{a},{b}]"
        if isinstance(v, (list, tuple, np.ndarray)):
            try:
                return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]"
            except (TypeError, ValueError):
                return "[" + ",".join(map(str, v)) + "]"
        try:
            return f"{float(v):g}"
        except (TypeError, ValueError):
            return str(v)

    parts = [f"1D partial dependence of expected {tgt_col}"]
    if kw_num_raw:
        parts.append(", ".join(f"{name}={_fmt_c(v)}" for name, v in kw_num_raw.items()))
    if cat_fixed:
        parts.append(", ".join(f"{b}={lab}" for b, lab in cat_fixed.items()))
    if cat_allowed:
        parts.append(", ".join(f"{b}∈{{{', '.join(v)}}}" for b, v in cat_allowed.items()))
    title = " — ".join(parts)

    width = width if (width and width > 0) else 1200
    height = height if (height and height > 0) else figure_height_per_row_px * len(panels)

    fig.update_layout(
        height=height,
        width=width,
        template="simple_white",
        title=title,
        legend_title_text="",
    )

    if output:
        output = Path(output)
        output.parent.mkdir(parents=True, exist_ok=True)
        fig.write_html(str(output), include_plotlyjs="cdn")
    if csv_out:
        csv_out = Path(csv_out)
        csv_out.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(tidy_rows).to_csv(str(csv_out), index=False)
    if show:
        fig.show("browser")

    return fig

1453 

1454 

1455# ============================================================================= 

1456# Helpers: dataset → predictors & featurization 

1457# ============================================================================= 

1458def _build_predictors(ds: xr.Dataset): 

1459 """Reconstruct fast GP predictors from the artifact using shared helpers.""" 

1460 # Training matrices / targets 

1461 Xn_all = ds["Xn_train"].values.astype(float) # (N, p) 

1462 y_success = ds["y_success"].values.astype(float) # (N,) 

1463 Xn_ok = ds["Xn_success_only"].values.astype(float) # (Ns, p) 

1464 y_loss_centered = ds["y_loss_centered"].values.astype(float) 

1465 

1466 # Compatibility: conditional_loss_mean may be a var or an attr 

1467 cond_mean = ( 

1468 float(ds["conditional_loss_mean"].values) 

1469 if "conditional_loss_mean" in ds 

1470 else float(ds.attrs.get("conditional_loss_mean", 0.0)) 

1471 ) 

1472 

1473 # Success head MAP params 

1474 ell_s = ds["map_success_ell"].values.astype(float) # (p,) 

1475 eta_s = float(ds["map_success_eta"].values) 

1476 sigma_s = float(ds["map_success_sigma"].values) 

1477 beta0_s = float(ds["map_success_beta0"].values) 

1478 

1479 # Loss head MAP params 

1480 ell_l = ds["map_loss_ell"].values.astype(float) # (p,) 

1481 eta_l = float(ds["map_loss_eta"].values) 

1482 sigma_l = float(ds["map_loss_sigma"].values) 

1483 mean_c = float(ds["map_loss_mean_const"].values) 

1484 

1485 # --- Cholesky precomputations (success) --- 

1486 K_s = kernel_m52_ard(Xn_all, Xn_all, ell_s, eta_s) + (sigma_s**2) * np.eye(Xn_all.shape[0]) 

1487 L_s = np.linalg.cholesky(add_jitter(K_s)) 

1488 alpha_s = solve_chol(L_s, (y_success - beta0_s)) 

1489 

1490 # --- Cholesky precomputations (loss | success) --- 

1491 K_l = kernel_m52_ard(Xn_ok, Xn_ok, ell_l, eta_l) + (sigma_l**2) * np.eye(Xn_ok.shape[0]) 

1492 L_l = np.linalg.cholesky(add_jitter(K_l)) 

1493 alpha_l = solve_chol(L_l, (y_loss_centered - mean_c)) 

1494 

1495 def predict_success_probability(Xn: np.ndarray) -> np.ndarray: 

1496 Ks = kernel_m52_ard(Xn, Xn_all, ell_s, eta_s) 

1497 mu = beta0_s + Ks @ alpha_s 

1498 return np.clip(mu, 0.0, 1.0) 

1499 

1500 def predict_conditional_target( 

1501 Xn: np.ndarray, 

1502 include_observation_noise: bool = True 

1503 ): 

1504 Kl = kernel_m52_ard(Xn, Xn_ok, ell_l, eta_l) 

1505 mu_centered = mean_c + Kl @ alpha_l 

1506 mu = mu_centered + cond_mean 

1507 

1508 # diag predictive variance 

1509 v = solve_lower(L_l, Kl.T) # (Ns, Nt) 

1510 var = kernel_diag_m52(Xn, ell_l, eta_l) - np.sum(v * v, axis=0) 

1511 var = np.maximum(var, 1e-12) 

1512 if include_observation_noise: 

1513 var = var + sigma_l**2 

1514 sd = np.sqrt(var) 

1515 return mu, sd 

1516 

1517 return predict_success_probability, predict_conditional_target 

1518 

1519 

1520def _raw_dataframe_from_dataset(ds: xr.Dataset) -> pd.DataFrame: 

1521 """Collect raw columns from the artifact into a DataFrame for plotting.""" 

1522 cols = {} 

1523 for name in ds.data_vars: 

1524 # include only row-aligned arrays 

1525 da = ds[name] 

1526 if "row" in da.dims and len(da.dims) == 1 and da.sizes["row"] == ds.sizes["row"]: 

1527 cols[name] = da.values 

1528 # Ensure trial_id exists for hover 

1529 if "trial_id" not in cols: 

1530 cols["trial_id"] = np.arange(ds.sizes["row"], dtype=int) 

1531 return pd.DataFrame(cols) 

1532 

1533 

1534def _apply_fixed_to_base( 

1535 base_std: np.ndarray, 

1536 fixed: dict[str, float], 

1537 feature_names: list[str], 

1538 transforms: list[str], 

1539 X_mean: np.ndarray, 

1540 X_std: np.ndarray, 

1541) -> np.ndarray: 

1542 """Override base point in standardized space with fixed ORIGINAL values.""" 

1543 out = base_std.copy() 

1544 name_to_idx = {n: i for i, n in enumerate(feature_names)} 

1545 for k, v in fixed.items(): 

1546 if k not in name_to_idx: 

1547 raise KeyError(f"Fixed variable '{k}' is not a model feature.") 

1548 j = name_to_idx[k] 

1549 x_raw = _forward_transform(transforms[j], float(v)) 

1550 out[j] = (x_raw - X_mean[j]) / X_std[j] 

1551 return out 

1552 

1553 

1554def _denormalize_then_inverse_transform(j: int, x_std: np.ndarray, transforms, X_mean, X_std) -> np.ndarray: 

1555 x_raw = x_std * X_std[j] + X_mean[j] 

1556 return _inverse_transform(transforms[j], x_raw) 

1557 

1558 

1559def _forward_transform(tr: str, x: float | np.ndarray) -> np.ndarray: 

1560 if tr == "log10": 

1561 x = np.asarray(x, dtype=float) 

1562 return np.log10(np.maximum(x, 1e-12)) 

1563 return np.asarray(x, dtype=float) 

1564 

1565 

1566def _inverse_transform(tr: str, x: np.ndarray) -> np.ndarray: 

1567 if tr == "log10": 

1568 return 10.0 ** x 

1569 return x 

1570 

1571 

1572def _maybe_log_axis(fig: go.Figure, row: int, col: int, name: str, axis: str = "x", transforms: list[str] | None = None, j: int | None = None): 

1573 """Use log axis for features that were log10-transformed.""" 

1574 use_log = False 

1575 if transforms is not None and j is not None: 

1576 use_log = (transforms[j] == "log10") 

1577 else: 

1578 use_log = ("learning_rate" in name.lower() or name.lower() == "lr") 

1579 if use_log: 

1580 if axis == "x": 

1581 fig.update_xaxes(type="log", row=row, col=col) 

1582 else: 

1583 fig.update_yaxes(type="log", row=row, col=col) 

1584 

1585 

1586def _rgb_string_to_tuple(s: str) -> tuple[int, int, int]: 

1587 vals = s[s.find("(") + 1 : s.find(")")].split(",") 

1588 r, g, b = [int(float(v)) for v in vals[:3]] 

1589 return r, g, b 

1590 

1591 

1592def _rgb_to_rgba(rgb: str, alpha: float) -> str: 

1593 # expects "rgb(r,g,b)" or "rgba(r,g,b,a)" 

1594 try: 

1595 r, g, b = _rgb_string_to_tuple(rgb) 

1596 except Exception: 

1597 r, g, b = (31, 119, 180) 

1598 return f"rgba({r},{g},{b},{alpha:.3f})" 

1599 

1600 

1601def _add_low_success_shading_1d(fig: go.Figure, row_idx: int, x_vals: np.ndarray, p: np.ndarray, y0: float, y1: float): 

1602 xref = "x" if row_idx == 1 else f"x{row_idx}" 

1603 yref = "y" if row_idx == 1 else f"y{row_idx}" 

1604 

1605 def _spans(vals: np.ndarray, mask: np.ndarray): 

1606 m = mask.astype(int) 

1607 diff = np.diff(np.concatenate([[0], m, [0]])) 

1608 starts = np.where(diff == 1)[0] 

1609 ends = np.where(diff == -1)[0] - 1 

1610 return [(vals[s], vals[e]) for s, e in zip(starts, ends)] 

1611 

1612 for x0, x1 in _spans(x_vals, p < 0.5): 

1613 fig.add_shape(type="rect", x0=x0, x1=x1, y0=y0, y1=y1, xref=xref, yref=yref, 

1614 line=dict(width=0), fillcolor="rgba(128,128,128,0.25)", layer="below") 

1615 for x0, x1 in _spans(x_vals, p < 0.8): 

1616 fig.add_shape(type="rect", x0=x0, x1=x1, y0=y0, y1=y1, xref=xref, yref=yref, 

1617 line=dict(width=0), fillcolor="rgba(128,128,128,0.40)", layer="below") 

1618 

1619 

1620def _set_yaxis_range(fig, *, row: int, col: int, y0: float, y1: float, log: bool, eps: float = 1e-12): 

1621 """Update a subplot's Y axis to [y0, y1]. For log axes, the range is given in log10 units.""" 

1622 if log: 

1623 y0 = max(y0, eps) 

1624 y1 = max(y1, y0 * (1.0 + 1e-6)) 

1625 fig.update_yaxes(type="log", range=[np.log10(y0), np.log10(y1)], row=row, col=col) 

1626 else: 

1627 fig.update_yaxes(type="-", range=[y0, y1], row=row, col=col) 

1628 

1629 

1630def plot1d_at_optimum( 

1631 model: xr.Dataset | Path | str, 

1632 output: Path | None = None, 

1633 csv_out: Path | None = None, 

1634 grid_size: int = 300, 

1635 line_color: str = "rgb(31,119,180)", 

1636 band_alpha: float = 0.25, 

1637 figure_height_per_row_px: int = 320, 

1638 show: bool = False, 

1639 use_log_scale_for_target_y: bool = True, 

1640 log_y_epsilon: float = 1e-9, 

1641 suggest: int = 0, # optional overlay 

1642 width: int | None = None, 

1643 height: int | None = None, 

1644 seed: int | None = 42, 

1645 **kwargs, # constraints in ORIGINAL units (as in your plot1d) 

1646) -> go.Figure: 

1647 """ 

1648 1D partial-dependence panels anchored at the *optimal* hyperparameter setting: 

1649 - Compute x* = argmin/argmax mean posterior from opt.optimal(...) 

1650 - For each feature, sweep that feature; keep all *other* features fixed at x*. 

1651 Supports numeric constraints (scalars/slices/choices) and categorical bases. 

1652 """ 

1653 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model) 

1654 pred_success, pred_loss = _build_predictors(ds) 

1655 

1656 # --- metadata --- 

1657 feature_names = [str(n) for n in ds["feature"].values.tolist()] 

1658 transforms = [str(t) for t in ds["feature_transform"].values.tolist()] 

1659 X_mean = ds["feature_mean"].values.astype(float) 

1660 X_std = ds["feature_std"].values.astype(float) 

1661 

1662 df_raw = _raw_dataframe_from_dataset(ds) 

1663 Xn_train = ds["Xn_train"].values.astype(float) 

1664 n_rows, p = Xn_train.shape 

1665 

1666 # --- one-hot categorical groups --- 

1667 groups = opt._onehot_groups(feature_names) 

1668 bases = set(groups.keys()) 

1669 name_to_idx = {name: j for j, name in enumerate(feature_names)} 

1670 

1671 # --- canonicalize kwargs: numeric vs categorical (base) --- 

1672 idx_map = _canon_key_set(ds) # your helper: maps normalized names -> exact feature column 

1673 kw_num_raw: dict[str, object] = {} 

1674 kw_cat_raw: dict[str, object] = {} 

1675 for k, v in kwargs.items(): 

1676 if k in bases: 

1677 kw_cat_raw[k] = v 

1678 elif k in idx_map: 

1679 kw_num_raw[idx_map[k]] = v 

1680 else: 

1681 import re as _re 

1682 nk = _re.sub(r"[^a-z0-9]+", "", str(k).lower()) 

1683 if nk in idx_map: 

1684 kw_num_raw[idx_map[nk]] = v 

1685 

1686 # --- resolve categorical constraints: fixed vs allowed subset --- 

1687 cat_fixed: dict[str, str] = {} 

1688 cat_allowed: dict[str, list[str]] = {} 

1689 for base, val in kw_cat_raw.items(): 

1690 labels = groups[base]["labels"] 

1691 if isinstance(val, str): 

1692 if val not in labels: 

1693 raise ValueError(f"Unknown category for {base!r}: {val!r}. Choices: {labels}") 

1694 cat_fixed[base] = val 

1695 elif isinstance(val, (list, tuple, set)): 

1696 chosen = [x for x in val if isinstance(x, str) and x in labels] 

1697 if not chosen: 

1698 raise ValueError(f"No valid categories for {base!r} in {val!r}. Choices: {labels}") 

1699 cat_allowed[base] = list(dict.fromkeys(chosen)) 

1700 else: 

1701 raise ValueError(f"Categorical constraint for {base!r} must be a string or list/tuple of strings.") 

1702 

1703 # ---------- 1) Find the *optimal* base point (original units) ---------- 

1704 opt_df = opt.optimal(model, count=1, seed=seed, **kwargs) # uses your gradient-based optimal() 

1705 # We’ll use this row both for overlays and as the anchor point. 

1706 # Expect numeric feature columns and categorical base columns present. 

1707 x_opt_std = np.zeros(p, dtype=float) 

1708 

1709 # Fill numerics from optimal row (orig -> internal -> std) 

1710 def _to_std_single(j: int, x_orig: float) -> float: 

1711 xi = x_orig 

1712 if transforms[j] == "log10": 

1713 xi = np.log10(np.maximum(x_orig, 1e-300)) 

1714 return float((xi - X_mean[j]) / X_std[j]) 

1715 

1716 # Mark one-hot member names 

1717 onehot_members: set[str] = set() 

1718 for base, g in groups.items(): 

1719 onehot_members.update(g["members"]) 

1720 

1721 # numeric features (skip one-hot members) 

1722 for j, nm in enumerate(feature_names): 

1723 if nm in onehot_members: 

1724 continue 

1725 if nm in opt_df.columns: 

1726 x_opt_std[j] = _to_std_single(j, float(opt_df.iloc[0][nm])) 

1727 else: 

1728 # Fall back to dataset median if not present (rare) 

1729 x_opt_std[j] = float(np.median(Xn_train[:, j])) 

1730 

1731 # Categorical bases: set one-hot block to the optimal label (or fixed) 

1732 for base, g in groups.items(): 

1733 # priority: fixed in kwargs → else from optimal row → else keep current (median/std) 

1734 if base in cat_fixed: 

1735 label = cat_fixed[base] 

1736 elif base in opt_df.columns: 

1737 label = str(opt_df.iloc[0][base]) 

1738 else: 

1739 # fallback: most frequent label in data 

1740 if base in df_raw.columns: 

1741 label = str(df_raw[base].astype("string").mode(dropna=True).iloc[0]) 

1742 else: 

1743 label = g["labels"][0] 

1744 

1745 for lab in g["labels"]: 

1746 member_name = g["name_by_label"][lab] 

1747 j2 = name_to_idx[member_name] 

1748 raw = 1.0 if lab == label else 0.0 

1749 # raw (0/1) → standardized using the artifact stats 

1750 x_opt_std[j2] = (raw - X_mean[j2]) / X_std[j2] 

1751 

1752 # ---------- 2) Numeric constraints in STANDARDIZED space ---------- 

1753 def _orig_to_std(j: int, x, transforms, mu, sd): 

1754 x = np.asarray(x, dtype=float) 

1755 if transforms[j] == "log10": 

1756 x = np.where(x <= 0, np.nan, x) 

1757 x = np.log10(x) 

1758 return (x - mu[j]) / sd[j] 

1759 

1760 fixed_scalars: dict[int, float] = {} 

1761 range_windows: dict[int, tuple[float, float]] = {} 

1762 choice_values: dict[int, np.ndarray] = {} 

1763 

1764 for name, val in kw_num_raw.items(): 

1765 if name not in name_to_idx: 

1766 continue 

1767 j = name_to_idx[name] 

1768 if isinstance(val, slice): 

1769 lo = _orig_to_std(j, float(val.start), transforms, X_mean, X_std) 

1770 hi = _orig_to_std(j, float(val.stop), transforms, X_mean, X_std) 

1771 lo, hi = float(min(lo, hi)), float(max(lo, hi)) 

1772 range_windows[j] = (lo, hi) 

1773 elif isinstance(val, (list, tuple, np.ndarray)): 

1774 arr = _orig_to_std(j, np.asarray(val, dtype=float), transforms, X_mean, X_std) 

1775 choice_values[j] = np.asarray(arr, dtype=float) 

1776 else: 

1777 fixed_scalars[j] = float(_orig_to_std(j, float(val), transforms, X_mean, X_std)) 

1778 

1779 # apply numeric fixed overrides on the base point 

1780 for j, vstd in fixed_scalars.items(): 

1781 x_opt_std[j] = vstd 

1782 

1783 # ---------- 3) Panels: sweep ONE var at a time around x* ---------- 

1784 # numeric free = not one-hot member and not fixed via kwargs 

1785 free_numeric_idx = [ 

1786 j for j, nm in enumerate(feature_names) 

1787 if (nm not in onehot_members) and (j not in fixed_scalars) 

1788 ] 

1789 # categorical bases: sweep if not fixed; otherwise not shown 

1790 free_cat_bases = [b for b in bases if b not in cat_fixed] 

1791 

1792 panels: list[tuple[str, object]] = [("num", j) for j in free_numeric_idx] + [("cat", b) for b in free_cat_bases] 

1793 if not panels: 

1794 raise ValueError("All features are fixed at the optimum (or categoricals fixed); nothing to plot.") 

1795 

1796 # empirical 1–99% per feature (for default sweep range) 

1797 Xn_p01 = np.percentile(Xn_train, 1, axis=0) 

1798 Xn_p99 = np.percentile(Xn_train, 99, axis=0) 

1799 

1800 def _grid_1d(j: int, n: int) -> np.ndarray: 

1801 # default range in std space 

1802 lo, hi = float(Xn_p01[j]), float(Xn_p99[j]) 

1803 if j in range_windows: 

1804 lo = max(lo, range_windows[j][0]) 

1805 hi = min(hi, range_windows[j][1]) 

1806 if j in choice_values: 

1807 vals = np.asarray(choice_values[j], dtype=float) 

1808 vals = vals[(vals >= lo) & (vals <= hi)] 

1809 return np.unique(np.sort(vals)) if vals.size else np.array([x_opt_std[j]], dtype=float) 

1810 if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo: 

1811 lo, hi = x_opt_std[j] - 1.0, x_opt_std[j] + 1.0 

1812 return np.linspace(lo, hi, n) 

1813 

1814 # figure scaffold 

1815 subplot_titles = [feature_names[int(k)] if t == "num" else str(k) for t, k in panels] 

1816 fig = make_subplots(rows=len(panels), cols=1, shared_xaxes=False, subplot_titles=subplot_titles) 

1817 

1818 tgt_col = str(ds.attrs["target"]) 

1819 success_mask = ~pd.isna(df_raw[tgt_col]).to_numpy() 

1820 fail_mask = ~success_mask 

1821 losses_success = df_raw.loc[success_mask, tgt_col].to_numpy().astype(float) 

1822 trial_ids_success = df_raw.get("trial_id", pd.Series(np.arange(len(df_raw)))).to_numpy()[success_mask] 

1823 trial_ids_fail = df_raw.get("trial_id", pd.Series(np.arange(len(df_raw)))).to_numpy()[fail_mask] 

1824 band_fill_rgba = _rgb_to_rgba(line_color, band_alpha) 

1825 

1826 # optional overlay 

1827 suggest_df = opt.suggest(model, count=suggest, seed=seed, **kwargs) if (suggest and suggest > 0) else None 

1828 

1829 tidy_rows: list[dict] = [] 

1830 row_pos = 0 

1831 for kind, key in panels: 

1832 row_pos += 1 

1833 

1834 if kind == "num": 

1835 j = int(key) 

1836 grid = _grid_1d(j, grid_size) 

1837 Xn_grid = np.repeat(x_opt_std[None, :], len(grid), axis=0) 

1838 Xn_grid[:, j] = grid 

1839 

1840 p_grid = pred_success(Xn_grid) 

1841 mu_grid, sd_grid = pred_loss(Xn_grid, include_observation_noise=True) 

1842 

1843 x_internal = grid * X_std[j] + X_mean[j] 

1844 x_display = _inverse_transform(transforms[j], x_internal) 

1845 

1846 # y transform 

1847 if use_log_scale_for_target_y: 

1848 mu_plot = np.maximum(mu_grid, log_y_epsilon) 

1849 lo_plot = np.maximum(mu_grid - 2.0 * sd_grid, log_y_epsilon) 

1850 hi_plot = np.maximum(mu_grid + 2.0 * sd_grid, log_y_epsilon) 

1851 losses_s_plot = np.maximum(losses_success, log_y_epsilon) if losses_success.size else losses_success 

1852 else: 

1853 mu_plot = mu_grid 

1854 lo_plot = mu_grid - 2.0 * sd_grid 

1855 hi_plot = mu_grid + 2.0 * sd_grid 

1856 losses_s_plot = losses_success 

1857 

1858 y_arrays = [lo_plot, hi_plot] + ([losses_s_plot] if losses_s_plot.size else []) 

1859 y_low = float(np.nanmin([np.nanmin(a) for a in y_arrays])) 

1860 y_high = float(np.nanmax([np.nanmax(a) for a in y_arrays])) 

1861 pad = 0.05 * (y_high - y_low + 1e-12) 

1862 y0_plot = (y_low - pad) if not use_log_scale_for_target_y else max(y_low / 1.5, log_y_epsilon) 

1863 y1_tmp = (y_high + pad) if not use_log_scale_for_target_y else y_high * 1.2 

1864 y_failed_band = y1_tmp + (y_high - y_low + 1e-12) * (0.08 if not use_log_scale_for_target_y else 0.3) 

1865 if use_log_scale_for_target_y and y_failed_band <= log_y_epsilon: 

1866 y_failed_band = max(10.0 * log_y_epsilon, y_high * 2.0) 

1867 y1_plot = y_failed_band + (0.02 if not use_log_scale_for_target_y else 0.05) * (y_high - y_low + 1e-12) 

1868 

1869 _add_low_success_shading_1d(fig, row_pos, x_display, p_grid, y0_plot, y1_plot) 

1870 

1871 show_legend = (row_pos == 1) 

1872 # ±2σ band 

1873 fig.add_trace(go.Scatter(x=x_display, y=lo_plot, mode="lines", 

1874 line=dict(width=0, color=line_color), 

1875 name="±2σ", legendgroup="band", showlegend=False, hoverinfo="skip"), 

1876 row=row_pos, col=1) 

1877 fig.add_trace(go.Scatter(x=x_display, y=hi_plot, mode="lines", fill="tonexty", 

1878 line=dict(width=0, color=line_color), fillcolor=band_fill_rgba, 

1879 name="±2σ", legendgroup="band", showlegend=show_legend, 

1880 hovertemplate="E[target|success]: %{y:.3f}<extra>±2σ</extra>"), 

1881 row=row_pos, col=1) 

1882 # mean 

1883 fig.add_trace(go.Scatter(x=x_display, y=mu_plot, mode="lines", 

1884 line=dict(width=2, color=line_color), 

1885 name="E[target|success]", legendgroup="mean", showlegend=show_legend, 

1886 hovertemplate=f"{feature_names[j]}: %{{x:.6g}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"), 

1887 row=row_pos, col=1) 

1888 

1889 # experimental points at y 

1890 if feature_names[j] in df_raw.columns: 

1891 x_data_all = df_raw[feature_names[j]].to_numpy().astype(float) 

1892 else: 

1893 full_vals = feature_raw_from_artifact_or_reconstruct(ds, j, feature_names[j], transforms[j]).astype(float) 

1894 x_data_all = full_vals 

1895 

1896 x_succ = x_data_all[success_mask] 

1897 if x_succ.size: 

1898 fig.add_trace(go.Scattergl( 

1899 x=x_succ, y=losses_s_plot, mode="markers", 

1900 marker=dict(size=5, color="black", line=dict(width=0)), 

1901 name="data (success)", legendgroup="data_s", showlegend=show_legend, 

1902 hovertemplate=("trial_id: %{customdata}<br>" 

1903 f"{feature_names[j]}: %{{x:.6g}}<br>" 

1904 f"{tgt_col}: %{{y:.4f}}<extra></extra>"), 

1905 customdata=trial_ids_success 

1906 ), row=row_pos, col=1) 

1907 

1908 x_fail = x_data_all[fail_mask] 

1909 if x_fail.size: 

1910 y_fail_plot = np.full_like(x_fail, y_failed_band, dtype=float) 

1911 fig.add_trace(go.Scattergl( 

1912 x=x_fail, y=y_fail_plot, mode="markers", 

1913 marker=dict(size=6, color="red", line=dict(color="black", width=0.8)), 

1914 name="data (failed)", legendgroup="data_f", showlegend=show_legend, 

1915 hovertemplate=("trial_id: %{customdata}<br>" 

1916 f"{feature_names[j]}: %{{x:.6g}}<br>" 

1917 "status: failed (NaN target)<extra></extra>"), 

1918 customdata=trial_ids_fail 

1919 ), row=row_pos, col=1) 

1920 

1921 # overlays: optimal (single point) and suggested (optional many) 

1922 x_opt_disp = None 

1923 if feature_names[j] in opt_df.columns: 

1924 x_opt_disp = float(opt_df.iloc[0][feature_names[j]]) 

1925 y_opt = float(opt_df.iloc[0]["pred_target_mean"]) 

1926 y_opt_sd = float(opt_df.iloc[0].get("pred_target_sd", np.nan)) 

1927 fig.add_trace(go.Scattergl( 

1928 x=[x_opt_disp], y=[y_opt], mode="markers", 

1929 marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"), 

1930 name="optimal", legendgroup="optimal", showlegend=show_legend, 

1931 hovertemplate=(f"predicted: %{{y:.3g}}" 

1932 + ("" if np.isnan(y_opt_sd) else f" ± {y_opt_sd:.3g}") 

1933 + f"<br>{feature_names[j]}: %{{x:.6g}}<extra></extra>") 

1934 ), row=row_pos, col=1) 

1935 

1936 if suggest and (suggest_df is not None) and (feature_names[j] in suggest_df.columns): 

1937 xs = suggest_df[feature_names[j]].values.astype(float) 

1938 ys = suggest_df["pred_target_mean"].values.astype(float) 

1939 ysd = suggest_df.get("pred_target_sd", pd.Series([np.nan]*len(suggest_df))).values 

1940 fig.add_trace(go.Scattergl( 

1941 x=xs, y=ys, mode="markers", 

1942 marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"), 

1943 name="suggested", legendgroup="suggested", showlegend=show_legend, 

1944 hovertemplate=("predicted: %{y:.3g}" 

1945 + (" ± %{customdata:.3g}" if not np.isnan(ysd).all() else "") 

1946 + f"<br>{feature_names[j]}: %{{x:.6g}}<extra>suggested</extra>"), 

1947 customdata=ysd 

1948 ), row=row_pos, col=1) 

1949 

1950 # axes + ranges 

1951 _maybe_log_axis(fig, row_pos, 1, feature_names[j], axis="x", transforms=transforms, j=j) 

1952 fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1) 

1953 _set_yaxis_range(fig, row=row_pos, col=1, 

1954 y0=y0_plot, y1=y1_plot, 

1955 log=use_log_scale_for_target_y, eps=log_y_epsilon) 

1956 fig.update_xaxes(title_text=feature_names[j], row=row_pos, col=1) 

1957 

1958 # If a constraint limited the sweep, respect it on the displayed axis 

1959 def _std_to_orig(val_std: float) -> float: 

1960 vi = val_std * X_std[j] + X_mean[j] 

1961 return float(_inverse_transform(transforms[j], np.array([vi]))[0]) 

1962 

1963 if j in range_windows: 

1964 lo_std, hi_std = range_windows[j] 

1965 x_min_override = min(_std_to_orig(lo_std), _std_to_orig(hi_std)) 

1966 x_max_override = max(_std_to_orig(lo_std), _std_to_orig(hi_std)) 

1967 span = (x_max_override - x_min_override) or 1.0 

1968 pad = 0.02 * span 

1969 fig.update_xaxes(range=[x_min_override - pad, x_max_override + pad], row=row_pos, col=1) 

1970 elif j in choice_values and choice_values[j].size: 

1971 ints = choice_values[j] * X_std[j] + X_mean[j] 

1972 origs = _inverse_transform(transforms[j], ints) 

1973 span = float(np.max(origs) - np.min(origs)) or 1.0 

1974 pad = 0.05 * span 

1975 fig.update_xaxes(range=[float(np.min(origs) - pad), float(np.max(origs) + pad)], row=row_pos, col=1) 

1976 

1977 # tidy rows 

1978 for xd, xi, mu_i, sd_i, p_i in zip(x_display, x_internal, mu_grid, sd_grid, p_grid): 

1979 tidy_rows.append({ 

1980 "feature": feature_names[j], 

1981 "x_display": float(xd), 

1982 "x_internal": float(xi), 

1983 "target_conditional_mean": float(mu_i), 

1984 "target_conditional_sd": float(sd_i), 

1985 "success_probability": float(p_i), 

1986 }) 

1987 

1988 else: 

1989 base = str(key) 

1990 labels_all = groups[base]["labels"] 

1991 labels = cat_allowed.get(base, labels_all) 

1992 

1993 # Evaluate each label with numerics and other bases fixed at x_opt_std 

1994 Xn_grid = np.repeat(x_opt_std[None, :], len(labels), axis=0) 

1995 for r, lab in enumerate(labels): 

1996 for lab2 in labels_all: 

1997 member_name = groups[base]["name_by_label"][lab2] 

1998 j2 = name_to_idx[member_name] 

1999 raw_val = 1.0 if (lab2 == lab) else 0.0 

2000 Xn_grid[r, j2] = (raw_val - X_mean[j2]) / X_std[j2] 

2001 

2002 p_vec = pred_success(Xn_grid) 

2003 mu_vec, sd_vec = pred_loss(Xn_grid, include_observation_noise=True) 

2004 

2005 # y transform 

2006 if use_log_scale_for_target_y: 

2007 mu_plot = np.maximum(mu_vec, log_y_epsilon) 

2008 lo_plot = np.maximum(mu_vec - 2.0 * sd_vec, log_y_epsilon) 

2009 hi_plot = np.maximum(mu_vec + 2.0 * sd_vec, log_y_epsilon) 

2010 losses_s_plot = np.maximum(df_raw.loc[success_mask, tgt_col].to_numpy().astype(float), log_y_epsilon) if success_mask.any() else np.array([]) 

2011 else: 

2012 mu_plot = mu_vec 

2013 lo_plot = mu_vec - 2.0 * sd_vec 

2014 hi_plot = mu_vec + 2.0 * sd_vec 

2015 losses_s_plot = df_raw.loc[success_mask, tgt_col].to_numpy().astype(float) if success_mask.any() else np.array([]) 

2016 

2017 y_arrays = [lo_plot, hi_plot] + ([losses_s_plot] if losses_s_plot.size else []) 

2018 y_low = float(np.nanmin([np.nanmin(a) for a in y_arrays])) if y_arrays else 0.0 

2019 y_high = float(np.nanmax([np.nanmax(a) for a in y_arrays])) if y_arrays else 1.0 

2020 pad = 0.05 * (y_high - y_low + 1e-12) 

2021 y0_plot = (y_low - pad) if not use_log_scale_for_target_y else max(y_low / 1.5, log_y_epsilon) 

2022 y1_tmp = (y_high + pad) if not use_log_scale_for_target_y else y_high * 1.2 

2023 y_failed_band = y1_tmp + (y_high - y_low + 1e-12) * (0.08 if not use_log_scale_for_target_y else 0.3) 

2024 if use_log_scale_for_target_y and y_failed_band <= log_y_epsilon: 

2025 y_failed_band = max(10.0 * log_y_epsilon, y_high * 2.0) 

2026 y1_plot = y_failed_band + (0.02 if not use_log_scale_for_target_y else 0.05) * (y_high - y_low + 1e-12) 

2027 

2028 # x = 0..K-1 with tick labels 

2029 x_pos = np.arange(len(labels), dtype=float) 

2030 

2031 # grey out infeasible (p<thr) 

2032 def _shade_for_thresh(thr: float, alpha: float): 

2033 for k_i, p_i in enumerate(p_vec): 

2034 if p_i < thr: 

2035 fig.add_shape( 

2036 type="rect", 

2037 xref=f"x{'' if row_pos==1 else row_pos}", 

2038 yref=f"y{'' if row_pos==1 else row_pos}", 

2039 x0=k_i - 0.5, x1=k_i + 0.5, 

2040 y0=y0_plot, y1=y1_plot, 

2041 line=dict(width=0), 

2042 fillcolor=f"rgba(128,128,128,{alpha})", 

2043 layer="below", 

2044 row=row_pos, col=1 

2045 ) 

2046 _shade_for_thresh(0.8, 0.40) 

2047 _shade_for_thresh(0.5, 0.25) 

2048 

2049 show_legend = (row_pos == 1) 

2050 fig.add_trace(go.Scatter( 

2051 x=x_pos, y=mu_plot, mode="lines+markers", 

2052 line=dict(width=2, color=line_color), 

2053 marker=dict(size=7, color=line_color), 

2054 error_y=dict(type="data", array=(hi_plot - mu_plot), arrayminus=(mu_plot - lo_plot), visible=True), 

2055 name="E[target|success]", legendgroup="mean", showlegend=show_legend, 

2056 hovertemplate=(f"{base}: %{{text}}<br>E[target|success]: %{{y:.3f}}<extra></extra>"), 

2057 text=labels 

2058 ), row=row_pos, col=1) 

2059 

2060 # overlay optimal point for this base (single label at x*=opt) 

2061 if base in opt_df.columns: 

2062 lab_opt = str(opt_df.iloc[0][base]) 

2063 if lab_opt in labels: 

2064 xi = float(labels.index(lab_opt)) 

2065 y_opt = float(opt_df.iloc[0]["pred_target_mean"]) 

2066 y_opt_sd = float(opt_df.iloc[0].get("pred_target_sd", np.nan)) 

2067 fig.add_trace(go.Scattergl( 

2068 x=[xi], y=[y_opt], mode="markers", 

2069 marker=dict(size=10, color="yellow", line=dict(color="black", width=1.5), symbol="x"), 

2070 name="optimal", legendgroup="optimal", showlegend=show_legend, 

2071 hovertemplate=(f"predicted: %{{y:.3g}}" 

2072 + ("" if np.isnan(y_opt_sd) else f" ± {y_opt_sd:.3g}") 

2073 + f"<br>{base}: {lab_opt}<extra></extra>") 

2074 ), row=row_pos, col=1) 

2075 

2076 # overlay suggestions (optional) 

2077 if suggest and (suggest_df is not None) and (base in suggest_df.columns): 

2078 labs_sug = suggest_df[base].astype(str).tolist() 

2079 xs = [labels.index(l) for l in labs_sug if l in labels] 

2080 if xs: 

2081 keep_mask = [l in labels for l in labs_sug] 

2082 y_sug = suggest_df.loc[keep_mask, "pred_target_mean"].values 

2083 fig.add_trace(go.Scattergl( 

2084 x=np.array(xs, dtype=float), y=y_sug, mode="markers", 

2085 marker=dict(size=9, color="cyan", line=dict(color="black", width=1.2), symbol="star"), 

2086 name="suggested", legendgroup="suggested", showlegend=show_legend, 

2087 hovertemplate=(f"{base}: %{{text}}<br>" 

2088 "predicted: %{{y:.3g}}<extra>suggested</extra>"), 

2089 text=[labels[int(i)] for i in xs] 

2090 ), row=row_pos, col=1) 

2091 

2092 fig.update_xaxes( 

2093 tickmode="array", 

2094 tickvals=x_pos.tolist(), 

2095 ticktext=labels, 

2096 title_text=base, 

2097 row=row_pos, col=1 

2098 ) 

2099 fig.update_yaxes(title_text=f"{tgt_col}", row=row_pos, col=1) 

2100 _set_yaxis_range(fig, row=row_pos, col=1, 

2101 y0=y0_plot, y1=y1_plot, 

2102 log=use_log_scale_for_target_y, eps=log_y_epsilon) 

2103 

2104 # tidy rows 

2105 for lab, mu_i, sd_i, p_i in zip(labels, mu_vec, sd_vec, p_vec): 

2106 tidy_rows.append({ 

2107 "feature": base, 

2108 "x_display": str(lab), 

2109 "x_internal": float("nan"), 

2110 "target_conditional_mean": float(mu_i), 

2111 "target_conditional_sd": float(sd_i), 

2112 "success_probability": float(p_i), 

2113 }) 

2114 

2115 # ---- layout & IO ---- 

2116 parts = [f"1D PD at optimal setting of all other hyperparameters ({ds.attrs.get('target', 'target')})"] 

2117 if kw_num_raw: 

2118 def _fmt_c(v): 

2119 if isinstance(v, slice): 

2120 a = "" if v.start is None else f"{v.start:g}" 

2121 b = "" if v.stop is None else f"{v.stop:g}" 

2122 return f"[{a},{b}]" 

2123 if isinstance(v, (list, tuple, np.ndarray)): 

2124 try: 

2125 return "[" + ",".join(f"{float(x):g}" for x in np.asarray(v).tolist()) + "]" 

2126 except Exception: 

2127 return "[" + ",".join(map(str, v)) + "]" 

2128 try: 

2129 return f"{float(v):g}" 

2130 except Exception: 

2131 return str(v) 

2132 parts.append(", ".join(f"{k}={_fmt_c(v)}" for k, v in kw_num_raw.items())) 

2133 if cat_fixed: 

2134 parts.append(", ".join(f"{b}={lab}" for b, lab in cat_fixed.items())) 

2135 title = " — ".join(parts) 

2136 

2137 width = width if (width and width > 0) else 1200 

2138 height = height if (height and height > 0) else figure_height_per_row_px * len(panels) 

2139 fig.update_layout(height=height, width=width, template="simple_white", title=title, legend_title_text="") 

2140 

2141 if output: 

2142 output = Path(output); output.parent.mkdir(parents=True, exist_ok=True) 

2143 fig.write_html(str(output), include_plotlyjs="cdn") 

2144 if csv_out: 

2145 csv_out = Path(csv_out); csv_out.parent.mkdir(parents=True, exist_ok=True) 

2146 pd.DataFrame(tidy_rows).to_csv(str(csv_out), index=False) 

2147 if show: 

2148 fig.show("browser") 

2149 return fig