Coverage for psyop/opt.py: 41.45%

1228 statements  

coverage.py v7.10.6, created at 2025-09-10 06:02 +0000

1 # opt.py

2 # -*- coding: utf-8 -*-

3

4 from pathlib import Path

5 from typing import Callable, Any

6 import re

7

8 import numpy as np

9 import pandas as pd

10 import xarray as xr

11 import hashlib

12 from scipy.special import ndtr # Φ(z), vectorized

13

14 from .util import get_rng, df_to_table

15 from .model import (

16 kernel_diag_m52,

17 kernel_m52_ard,

18 add_jitter,

19 solve_chol,

20 solve_lower,

21 )

22 from .model import feature_raw_from_artifact_or_reconstruct

23

24 from rich.console import Console

25 from rich.table import Table

26

27 console = Console()

28

29 _ONEHOT_RE = re.compile(r"^(?P<base>[^=]+)=(?P<label>.+)$")

30 

31 

32 def _pretty_conditioned_on(

33 fixed_norm_numeric: dict | None = None, 

34 cat_fixed_label: dict | None = None, 

35) -> str: 

36 """ 

37 Combine numeric fixed constraints (already normalized to model space) 

38 with categorical fixed choices into a single human-readable string. 

39 

40 Examples: 

41 - fixed_norm_numeric = {"epochs": 12.0, "batch_size": 32} 

42 - cat_fixed_label = {"language": "Linear B"} 

43 

44 Returns: 

45 "epochs=12, batch_size=32, language=Linear B" 

46 (ordering is deterministic: keys sorted within each group) 

47 """ 

48 fixed_norm_numeric = fixed_norm_numeric or {} 

49 cat_fixed_label = cat_fixed_label or {} 

50 

51 parts = [] 

52 

53 # Prefer the project-standard formatter if present. 

54 try: 

55 if fixed_norm_numeric: 

56 txt = _fixed_as_string(fixed_norm_numeric) # e.g. "epochs=12, batch_size=32" 

57 if txt: 

58 parts.append(txt) 

59 except Exception: 

60 # Fallback: simple k=v with general formatting. 

61 if fixed_norm_numeric: 

62 items = [] 

63 for k, v in sorted(fixed_norm_numeric.items()): 

64 try: 

65 items.append(f"{k}={float(v):.6g}") 

66 except Exception: 

67 items.append(f"{k}={v}") 

68 parts.append(", ".join(items)) 

69 

70 # Append categorical fixed choices as "base=Label" 

71 if cat_fixed_label: 

72 cat_txt = ", ".join(f"{b}={lab}" for b, lab in sorted(cat_fixed_label.items())) 

73 if cat_txt: 

74 parts.append(cat_txt) 

75 

76 return ", ".join(p for p in parts if p) 

77 

78 

79 def _split_constraints_for_numeric_and_categorical(

80 feature_names: list[str], 

81 kwargs: dict[str, object], 

82): 

83 """ 

84 Split user constraints into: 

85 - numeric: user_fixed, user_ranges, user_choices_num (by feature name) 

86 - categorical: cat_fixed_label (base->label), cat_allowed (base->set(labels)) 

87 - and return one-hot groups 

88 

89 Interpretation rules:

90 * For a categorical base key (e.g. 'language'): 

91 - str -> fixed single label 

92 - list/tuple of str -> allowed label set 

93 * For a numeric feature key (non one-hot member): 

94 - number -> fixed 

95 - slice(lo,hi[,step]) -> range (lo,hi) inclusive on ends in post-filter 

96 - list/tuple of numbers -> finite choices 

97 - range(...) (python range) -> tuple of ints (choices) 

98 """ 

99 groups = _onehot_groups(feature_names) 

100 bases = set(groups.keys()) 

101 feature_set = set(feature_names) 

102 

103 user_fixed: dict[str, float] = {} 

104 user_ranges: dict[str, tuple[float, float]] = {} 

105 user_choices_num: dict[str, list[int | float]] = {} 

106 

107 cat_fixed_label: dict[str, str] = {} 

108 cat_allowed: dict[str, set[str]] = {} 

109 

110 # helper 

111 def _is_intlike(x) -> bool: 

112 try: 

113 return float(int(round(float(x)))) == float(x) 

114 except Exception: 

115 return False 

116 

117 for key, raw in (kwargs or {}).items(): 

118 # --- CATEGORICAL (by base key, not member name) --- 

119 if key in bases: 

120 labels = groups[key]["labels"] 

121 # fixed single label 

122 if isinstance(raw, str): 

123 if raw not in labels: 

124 raise ValueError(f"Unknown category for {key!r}: {raw!r}. Choices: {labels}") 

125 cat_fixed_label[key] = raw 

126 cat_allowed[key] = {raw} 

127 continue 

128 # list/tuple of labels (choices restriction) 

129 if isinstance(raw, (list, tuple, set)): 

130 chosen = [v for v in raw if isinstance(v, str) and (v in labels)] 

131 if not chosen: 

132 raise ValueError(f"No valid categories for {key!r} in {raw!r}. Choices: {labels}") 

133 cat_allowed[key] = set(chosen) 

134 continue 

135 # anything else -> ignore for cats 

136 continue 

137 

138 # --- NUMERIC (by feature name; skip one-hot member names) --- 

139 # If user accidentally passes member name 'language=Linear A', ignore here 

140 if key not in feature_set or _ONEHOT_RE.match(key): 

141 # Unknown or member-level keys are ignored at this stage 

142 continue 

143 

144 # python range -> tuple of ints 

145 if isinstance(raw, range): 

146 raw = tuple(raw) 

147 

148 # number -> fixed 

149 if isinstance(raw, (int, float, np.number)): 

150 val = float(raw) 

151 if np.isfinite(val): 

152 user_fixed[key] = val 

153 continue 

154 

155 # slice -> float range 

156 if isinstance(raw, slice): 

157 if raw.start is None or raw.stop is None: 

158 continue 

159 lo = float(raw.start); hi = float(raw.stop) 

160 if not (np.isfinite(lo) and np.isfinite(hi)): 

161 continue 

162 if lo > hi: 

163 lo, hi = hi, lo 

164 user_ranges[key] = (lo, hi) 

165 continue 

166 

167 # list/tuple -> numeric choices 

168 if isinstance(raw, (list, tuple)): 

169 if len(raw) == 0: 

170 continue 

171 # preserve ints if all int-like, else floats 

172 if all(_is_intlike(v) for v in raw): 

173 user_choices_num[key] = [int(round(float(v))) for v in raw] 

174 else: 

175 user_choices_num[key] = [float(v) for v in raw] 

176 continue 

177 

178 # otherwise: ignore 

179 

180 # Numeric fixed wins over its own range/choices 

181 for k in list(user_fixed.keys()): 

182 user_ranges.pop(k, None) 

183 user_choices_num.pop(k, None) 

184 

185 return groups, user_fixed, user_ranges, user_choices_num, cat_fixed_label, cat_allowed 

186 

187 
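# Illustrative sketch (not part of the original module): how the splitter above
# treats a mixed constraint dict, assuming feature_names holds the numeric
# features plus the one-hot members 'language=Linear A' / 'language=Linear B'.
#
#   feature_names = ["epochs", "learning_rate", "language=Linear A", "language=Linear B"]
#   _split_constraints_for_numeric_and_categorical(
#       feature_names,
#       {"epochs": 12, "learning_rate": slice(1e-5, 1e-3), "language": "Linear B"},
#   )
#   # -> user_fixed       == {"epochs": 12.0}
#   #    user_ranges      == {"learning_rate": (1e-05, 0.001)}
#   #    user_choices_num == {}
#   #    cat_fixed_label  == {"language": "Linear B"}
#   #    cat_allowed      == {"language": {"Linear B"}}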

188 def _detect_categorical_groups(feature_names: list[str]) -> dict[str, list[tuple[str, str]]]:

189 """ 

190 Detect one-hot groups: {"language": [("language=Linear A","Linear A"), ("language=Linear B","Linear B"), ...]} 

191 """ 

192 groups: dict[str, list[tuple[str, str]]] = {} 

193 for name in feature_names: 

194 m = _ONEHOT_RE.match(name) 

195 if not m: 

196 continue 

197 base = m.group("base") 

198 lab = m.group("label") 

199 groups.setdefault(base, []).append((name, lab)) 

200 # deterministic order 

201 for base in groups: 

202 groups[base].sort(key=lambda t: t[1]) 

203 return groups 

204 

205 def _project_categoricals_to_valid_onehot(df: pd.DataFrame, groups: dict[str, list[tuple[str, str]]]) -> pd.DataFrame:

206 """ 

207 For each categorical group ensure exactly one column is 1 and the rest 0 (argmax projection). 

208 Works whether columns are 0/1 already or arbitrary scores in [0,1]. 

209 """ 

210 for base, pairs in groups.items(): 

211 cols = [name for name, _ in pairs if name in df.columns] 

212 if len(cols) <= 1: 

213 continue 

214 sub = df[cols].to_numpy(dtype=float) 

215 # treat NaNs as -inf so they never win 

216 sub = np.where(np.isfinite(sub), sub, -np.inf) 

217 if sub.size == 0: 

218 continue 

219 idx = np.argmax(sub, axis=1) 

220 new = np.zeros_like(sub) 

221 new[np.arange(sub.shape[0]), idx] = 1.0 

222 df.loc[:, cols] = new 

223 return df 

224 

225 

226 def _apply_categorical_constraints(df: pd.DataFrame,

227 groups: dict[str, list[tuple[str, str]]], 

228 fixed_str: dict[str, str], 

229 allowed_strs: dict[str, list[str]]) -> pd.DataFrame: 

230 """ 

231 Filter rows by categorical constraints expressed on the base names, e.g. 

232 fixed_str = {"language": "Linear B"} 

233 allowed_strs = {"language": ["Linear A", "Linear B"]} 

234 Operates on one-hot columns, so call BEFORE collapsing to string columns. 

235 """ 

236 mask = np.ones(len(df), dtype=bool) 

237 for base, val in (fixed_str or {}).items(): 

238 if base not in groups: 

239 continue 

240 cols = {label: name for name, label in groups[base] if name in df.columns} 

241 want = cols.get(val) 

242 if want is None: 

243 # no matching one-hot column — drop all rows 

244 mask &= False 

245 else: 

246 mask &= (df[want] >= 0.5) # after projection, exactly 1 column is 1 

247 for base, vals in (allowed_strs or {}).items(): 

248 if base not in groups: 

249 continue 

250 cols = {label: name for name, label in groups[base] if name in df.columns} 

251 want_cols = [cols[v] for v in vals if v in cols] 

252 if want_cols: 

253 mask &= (df[want_cols].sum(axis=1) >= 0.5) 

254 else: 

255 mask &= False 

256 return df.loc[mask].reset_index(drop=True) 

257 

258 

259 def _onehot_groups(feature_names: list[str]) -> dict[str, dict]:

260 """ 

261 Detect one-hot groups among feature names like 'language=Linear A'. 

262 Returns: 

263 { 

264 base: { 

265 "labels": [label1, ...], 

266 "members": [(feat_name, label), ...], 

267 "name_by_label": {label: feat_name} 

268 }, 

269 ... 

270 } 

271 """ 

272 groups: dict[str, dict] = {} 

273 for name in feature_names: 

274 m = _ONEHOT_RE.match(name) 

275 if not m: 

276 continue 

277 base = m.group("base") 

278 label = m.group("label") 

279 g = groups.setdefault(base, {"labels": [], "members": [], "name_by_label": {}}) 

280 g["labels"].append(label) 

281 g["members"].append((name, label)) 

282 g["name_by_label"][label] = name 

283 # stable order for labels 

284 for g in groups.values(): 

285 # keep insertion order from feature_names, but ensure uniqueness 

286 seen = set() 

287 uniq = [] 

288 for lab in g["labels"]: 

289 if lab not in seen: 

290 uniq.append(lab); seen.add(lab) 

291 g["labels"] = uniq 

292 return groups 

293 
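# Illustrative sketch (not part of the original module): the structure returned
# by _onehot_groups for a feature list with one categorical base.
#
#   _onehot_groups(["epochs", "language=Linear A", "language=Linear B"])
#   # -> {"language": {"labels": ["Linear A", "Linear B"],
#   #                  "members": [("language=Linear A", "Linear A"),
#   #                              ("language=Linear B", "Linear B")],
#   #                  "name_by_label": {"Linear A": "language=Linear A",
#   #                                    "Linear B": "language=Linear B"}}}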

294 

295 

296 def _numeric_specs_only(search_specs: dict, groups: dict) -> dict:

297 """ 

298 Return a copy of search_specs with one-hot member feature names removed. 

299 `groups` is the output of _onehot_groups(feature_names). 

300 """ 

301 if not groups: 

302 return dict(search_specs) 

303 

304 onehot_member_names = set() 

305 for g in groups.values(): 

306 onehot_member_names.update(g["name_by_label"].values()) 

307 

308 return {k: v for k, v in search_specs.items() if k not in onehot_member_names} 

309 

310 

311 def _assert_valid_onehot(df: pd.DataFrame, groups: dict[str, dict], where: str = "") -> None:

312 """ 

313 Assert every one-hot block has exactly one '1' per row (no NaNs). 

314 Prints a small diagnostic if not. 

315 """ 

316 for base, g in groups.items(): 

317 member_cols = [g["name_by_label"][lab] for lab in g["labels"] if g["name_by_label"][lab] in df.columns] 

318 if not member_cols: 

319 print(f"[onehot] {where}: base={base} has no member columns present") 

320 continue 

321 

322 block = df[member_cols].to_numpy() 

323 nonfinite_mask = ~np.isfinite(block) 

324 sums = np.nan_to_num(block, nan=0.0, posinf=0.0, neginf=0.0).sum(axis=1) 

325 

326 bad = np.where(nonfinite_mask.any(axis=1) | (sums != 1))[0] 

327 if bad.size: 

328 print(f"[BUG onehot] {where}: base={base}, rows with invalid one-hot: {bad[:20].tolist()} (showing first 20)") 

329 print("member_cols:", member_cols) 

330 print(df.iloc[bad[:5]][member_cols]) # show a few bad rows 

331 raise RuntimeError(f"Invalid one-hot block for base={base} at {where}") 

332 

333 def _get_float_attr(obj, names, default=0.0):

334 for n in names: 

335 if hasattr(obj, n): 

336 v = getattr(obj, n) 

337 # skip boolean flags like mean_only 

338 if isinstance(v, (bool, np.bool_)): 

339 continue 

340 try: 

341 return float(v) 

342 except Exception: 

343 pass 

344 return float(default) 

345 

346 

347 

348 import itertools

349 import numpy as np

350 import pandas as pd

351 from pathlib import Path

352

353 # ---------- small utils (reuse in this file) ----------

354 def _orig_to_std(j: int, x, transforms, mu, sd):

355 arr = np.asarray(x, dtype=float) 

356 if transforms[j] == "log10": 

357 arr = np.where(arr <= 0, np.nan, arr) 

358 arr = np.log10(arr) 

359 return (arr - mu[j]) / sd[j] 

360 

361 def _std_to_orig(j: int, arr, transforms, mu, sd):

362 x = np.asarray(arr, float) * sd[j] + mu[j] 

363 if transforms[j] == "log10": 

364 x = np.power(10.0, x) 

365 return x 

366 
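# Illustrative sketch (not part of the original module): the two helpers above
# are inverses. For a feature j with transforms[j] == "log10", mu[j] == 0.0 and
# sd[j] == 1.0, an original value of 0.001 standardizes to log10(0.001) == -3.0,
# and _std_to_orig maps -3.0 back to 0.001 (up to floating-point error):
#
#   z = _orig_to_std(j, 0.001, transforms, mu, sd)   # -> -3.0
#   x = _std_to_orig(j, z, transforms, mu, sd)       # -> 0.001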

367 def _groups_from_feature_names(feature_names: list[str]) -> dict:

368 # Same one-hot grouping logic as _onehot_groups, except 'members' holds the full feature names rather than (name, label) pairs.

369 groups = {} 

370 for nm in feature_names: 

371 if "=" in nm: 

372 base, lab = nm.split("=", 1) 

373 g = groups.setdefault(base, {"labels": [], "name_by_label": {}, "members": []}) 

374 g["labels"].append(lab) 

375 g["name_by_label"][lab] = nm 

376 g["members"].append(nm) 

377 # Stable order 

378 for b in groups: 

379 labs = list(dict.fromkeys(groups[b]["labels"])) 

380 groups[b]["labels"] = labs 

381 groups[b]["members"] = [groups[b]["name_by_label"][lab] for lab in labs] 

382 return groups 

383 

384 def _pick_attr(obj, names, allow_none=False):

385 """Return the first present attribute in names without truth-testing arrays.""" 

386 for n in names: 

387 if hasattr(obj, n): 

388 v = getattr(obj, n) 

389 if (v is not None) or allow_none: 

390 return v 

391 return None 

392 

393 # ---------- analytic GP (mean + grad) ----------

394 class _GPMarginal:

395 def __init__(self, Xtr, ytr, ell, eta, sigma, mean_const): 

396 self.X = np.asarray(Xtr, float) # (N,p) 

397 self.y = np.asarray(ytr, float) # (N,) 

398 self.ell = np.asarray(ell, float) # (p,) 

399 self.eta = float(eta) 

400 self.sigma = float(sigma) 

401 self.m = float(mean_const) 

402 K = kernel_m52_ard(self.X, self.X, self.ell, self.eta) 

403 K[np.diag_indices_from(K)] += self.sigma**2 

404 L = np.linalg.cholesky(add_jitter(K)) 

405 self.L = L 

406 self.alpha = solve_chol(L, (self.y - self.m)) 

407 # Back-compat aliases: older code paths refer to X_train/Xtr for the

408 # training inputs, ls for the lengthscales, and m0/mean_const for the constant mean.

409 self.X_train = self.X

410 self.Xtr = self.X_train

411 self.ls = self.ell

412 self.m0 = float(mean_const)

413 self.mean_const = self.m0

415 

416 def sd_at(self, x: np.ndarray, include_observation_noise: bool = True) -> float: 

417 """ 

418 Predictive standard deviation at a single standardized point x. 

419 """ 

420 x = np.asarray(x, float).reshape(1, -1) # (1, p) 

421 Ks = kernel_m52_ard(x, self.Xtr, self.ls, self.eta) # (1, N) 

422 v = solve_lower(self.L, Ks.T) # (N, 1) 

423 kss = kernel_diag_m52(x, self.ls, self.eta)[0] # scalar diag K(x,x) = eta^2 

424 var = float(kss - np.sum(v * v)) 

425 if include_observation_noise: 

426 var += float(self.sigma ** 2) 

427 var = max(var, 1e-12) 

428 return float(np.sqrt(var)) 

429 
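# Note on sd_at (descriptive, matching the code above): `var` is the usual GP
# predictive variance
#   var(x) = k(x, x) - k_*(x)^T (K + sigma^2 I)^{-1} k_*(x)   [+ sigma^2 observation noise],
# computed via v = L^{-1} k_*(x), so that sum(v * v) equals the quadratic form.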

430 def _k_and_grad(self, x): 

431 """k(x, X), ∂k/∂x (p-dimensional gradient aggregated over train points).""" 

432 x = np.asarray(x, float).reshape(1, -1) # (1,p) 

433 X = self.X 

434 ell = self.ell 

435 eta = self.eta 

436 

437 # distances in lengthscale space 

438 D = (x[:, None, :] - X[None, :, :]) / ell[None, None, :] # (1,N,p) 

439 r2 = np.sum(D*D, axis=2) # (1,N) 

440 r = np.sqrt(np.maximum(r2, 0.0)) # (1,N) 

441 sqrt5_r = np.sqrt(5.0) * r 

442 # kernel 

443 k = (eta**2) * (1.0 + sqrt5_r + (5.0/3.0)*r2) * np.exp(-sqrt5_r) # (1,N) 

444 

445 # grad wrt x: -(5η^2/3) e^{-√5 r} (1 + √5 r) * (x - xi)/ell^2 

446 # handle r=0 safely -> derivative is 0 

447 coef = -(5.0 * (eta**2) / 3.0) * np.exp(-sqrt5_r) * (1.0 + sqrt5_r) # (1,N) 

448 S = (x[:, None, :] - X[None, :, :]) / (ell[None, None, :]**2) # (1,N,p) 

449 grad = np.sum(coef[:, :, None] * S, axis=1) # (1,p) 

450 

451 return k.ravel(), grad.ravel() 

452 

453 def mean_and_grad(self, x: np.ndarray): 

454 # --- resolve training matrix 

455 Xtr = _pick_attr(self, ["X_train", "Xtr", "X"]) 

456 if Xtr is None: 

457 raise AttributeError("GPMarginal: training inputs not found (tried X_train, Xtr, X).") 

458 Xtr = np.asarray(Xtr, float) 

459 

460 # --- resolve hyperparams / vectors 

461 ell = _pick_attr(self, ["ell", "ls"]) 

462 if ell is None: 

463 raise AttributeError("GPMarginal: lengthscales not found (tried ell, ls).") 

464 ell = np.asarray(ell, float) 

465 

466 eta = _get_float_attr(self, ["eta"]) 

467 alpha = _pick_attr(self, ["alpha", "alpha_vec"]) 

468 if alpha is None: 

469 raise AttributeError("GPMarginal: alpha not found (tried alpha, alpha_vec).") 

470 alpha = np.asarray(alpha, float).ravel() 

471 

472 # mean constant (name differs across versions) 

473 m0 = _get_float_attr(self, ["mean_const", "m0", "beta0", "mean_c", "mean"], default=0.0) 

474 

475 # --- mean 

476 Ks = kernel_m52_ard(x[None, :], Xtr, ell, eta).ravel() # (N_train,) 

477 mu = float(m0 + Ks @ alpha) 

478 

479 # --- gradient wrt x (shape (p,)) 

480 grad_k = _grad_k_m52_ard_wrt_x(x, Xtr, ell, eta) # (N_train, p) 

481 grad_mu = grad_k.T @ alpha # (p,) 

482 

483 return mu, grad_mu 

484 

485 

486 def mean_only(self, X): 

487 Ks = kernel_m52_ard(X, self.X, self.ell, self.eta) 

488 return self.m + Ks @ self.alpha 

489 

490 

491 def _grad_k_m52_ard_wrt_x(x: np.ndarray, Xtr: np.ndarray, ls: np.ndarray, eta: float) -> np.ndarray:

492 """ 

493 ∂k(x, Xtr_i)/∂x for Matérn 5/2 ARD. 

494 Returns (N_train, p) — one row per training point. 

495 """ 

496 x = np.asarray(x, float).reshape(1, -1) # (1, p) 

497 Xtr = np.asarray(Xtr, float) # (N, p) 

498 ls = np.asarray(ls, float).reshape(1, -1) # (1, p) 

499 

500 diff = x - Xtr # (N, p) 

501 z = diff / ls # (N, p) 

502 r = np.sqrt(np.sum(z*z, axis=1)) # (N,) 

503 

504 sr5 = np.sqrt(5.0) 

505 coef = -(5.0 * (eta**2) / 3.0) * np.exp(-sr5 * r) * (1.0 + sr5 * r) # (N,) 

506 grad = coef[:, None] * (diff / (ls*ls)) # (N, p) 

507 return grad 

508 
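# Illustrative sketch (not part of the original module): a finite-difference
# check of the analytic Matérn-5/2 gradient above, handy when editing the
# kernel code. It assumes kernel_m52_ard(X1, X2, ell, eta) returns an (n1, n2)
# array, as it is used elsewhere in this file.
def _check_grad_k_m52_ard(x, Xtr, ls, eta, eps=1e-6):
    """Return the max abs difference between analytic and central-difference gradients."""
    x = np.asarray(x, float)
    Xtr = np.asarray(Xtr, float)
    num = np.zeros((Xtr.shape[0], x.size))
    for d in range(x.size):
        e = np.zeros_like(x)
        e[d] = eps
        kp = kernel_m52_ard((x + e)[None, :], Xtr, ls, eta).ravel()
        km = kernel_m52_ard((x - e)[None, :], Xtr, ls, eta).ravel()
        num[:, d] = (kp - km) / (2.0 * eps)  # central difference per dimension
    ana = _grad_k_m52_ard_wrt_x(x, Xtr, ls, eta)
    return float(np.max(np.abs(num - ana)))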

509 

510 def rng_for_dataset(ds, seed=None):

511 if isinstance(seed, np.random.Generator): 

512 return seed 

513 

514 # Hash something stable from the dataset; Xn_train is fine. 

515 x = np.ascontiguousarray(ds["Xn_train"].values.astype(np.float64)) 

516 digest64 = int.from_bytes(hashlib.sha256(x.tobytes()).digest()[:8], "big") # 64 bits 

517 

518 if seed is None: 

519 mixed = np.uint64(digest64) # dataset-deterministic 

520 else: 

521 mixed = np.uint64(seed) ^ np.uint64(digest64) # mix user seed with dataset hash 

522 

523 return np.random.default_rng(int(mixed)) 

524 
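# Note on rng_for_dataset (descriptive): the generator is seeded from a SHA-256
# hash of Xn_train (optionally XOR-ed with the user seed), so repeated calls on
# the same artifact reproduce the same draws, while different artifacts get
# different streams. Passing an existing np.random.Generator bypasses the hashing.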

525 def suggest(

526 model: xr.Dataset | Path | str, 

527 count: int = 10, 

528 output: Path | None = None, 

529 repulsion: float = 0.34, # repulsion radius & weight 

530 explore: float = 0.5, # probability to optimize EI (explore) 

531 success_threshold: float = 0.8, 

532 softmax_temp: float | None = 0.2, # τ for EI softmax; None/0 => greedy EI 

533 n_starts: int = 32, 

534 max_iters: int = 200, 

535 penalty_lambda: float = 1.0, 

536 penalty_beta: float = 10.0, 

537 direction: str | None = None, # defaults to model's 

538 seed: int | np.random.Generator | None = 42, 

539 **kwargs, # constraints in ORIGINAL units 

540) -> pd.DataFrame: 

541 import itertools, math 

542 import numpy as np, pandas as pd 

543 

544 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model) 

545 rng = rng_for_dataset(ds, seed) # dataset-aware determinism 

546 

547 # --- metadata 

548 feature_names = [str(n) for n in ds["feature"].values.tolist()] 

549 transforms = [str(t) for t in ds["feature_transform"].values.tolist()] 

550 mu_f = ds["feature_mean"].values.astype(float) 

551 sd_f = ds["feature_std"].values.astype(float) 

552 p = len(feature_names) 

553 name_to_idx = {nm: j for j, nm in enumerate(feature_names)} 

554 groups = _groups_from_feature_names(feature_names) 

555 

556 # --- GP heads 

557 gp_s = _GPMarginal( 

558 Xtr=ds["Xn_train"].values.astype(float), 

559 ytr=ds["y_success"].values.astype(float), 

560 ell=ds["map_success_ell"].values.astype(float), 

561 eta=float(ds["map_success_eta"].values), 

562 sigma=float(ds["map_success_sigma"].values), 

563 mean_const=float(ds["map_success_beta0"].values), 

564 ) 

565 cond_mean = float(ds["conditional_loss_mean"].values) if "conditional_loss_mean" in ds else 0.0 

566 gp_l = _GPMarginal( 

567 Xtr=ds["Xn_success_only"].values.astype(float), 

568 ytr=ds["y_loss_centered"].values.astype(float), 

569 ell=ds["map_loss_ell"].values.astype(float), 

570 eta=float(ds["map_loss_eta"].values), 

571 sigma=float(ds["map_loss_sigma"].values), 

572 mean_const=float(ds["map_loss_mean_const"].values), 

573 ) 

574 

575 # --- direction & EI baseline 

576 if direction is None: 

577 direction = str(ds.attrs.get("direction", "min")) 

578 flip = -1.0 if direction == "max" else 1.0 

579 best_feasible = _best_feasible_observed(ds, direction) 

580 

581 # --- constraints -> bounds (std space) 

582 cat_allowed: dict[str, list[str]] = {b: list(g["labels"]) for b, g in groups.items()} 

583 cat_fixed: dict[str, str] = {} 

584 fixed_num_std: dict[int, float] = {} 

585 range_num_std: dict[int, tuple[float, float]] = {} 

586 choice_num: dict[int, np.ndarray] = {} 

587 

588 def canon_key(k: str) -> str: 

589 import re as _re 

590 raw = str(k) 

591 stripped = _re.sub(r"[^a-z0-9]+", "", raw.lower()) 

592 if raw in name_to_idx: return raw 

593 for base in groups.keys(): 

594 if stripped == _re.sub(r"[^a-z0-9]+", "", base.lower()): 

595 return base 

596 return raw 

597 

598 for k, v in (kwargs or {}).items(): 

599 ck = canon_key(k) 

600 if ck in groups: # categorical base 

601 labels = groups[ck]["labels"] 

602 if isinstance(v, str): 

603 if v not in labels: 

604 raise ValueError(f"Unknown category for {ck}: {v}. Choices: {labels}") 

605 cat_fixed[ck] = v 

606 else: 

607 L = [x for x in (list(v) if isinstance(v, (list, tuple, set)) else [v]) if isinstance(x, str) and x in labels] 

608 if not L: 

609 raise ValueError(f"No valid categories for {ck} in {v}. Choices: {labels}") 

610 if len(L) == 1: 

611 cat_fixed[ck] = L[0] 

612 else: 

613 cat_allowed[ck] = L 

614 elif ck in name_to_idx: # numeric 

615 j = name_to_idx[ck] 

616 if isinstance(v, range): 

617 v = tuple(v) 

618 if isinstance(v, slice): 

619 lo = _orig_to_std(j, v.start, transforms, mu_f, sd_f) 

620 hi = _orig_to_std(j, v.stop, transforms, mu_f, sd_f) 

621 lo, hi = float(np.nanmin([lo, hi])), float(np.nanmax([lo, hi])) 

622 range_num_std[j] = (lo, hi) 

623 elif isinstance(v, (list, tuple, np.ndarray)): 

624 arr = _orig_to_std(j, np.asarray(v, float), transforms, mu_f, sd_f) 

625 choice_num[j] = np.asarray(arr, float) 

626 else: 

627 fixed_num_std[j] = float(_orig_to_std(j, float(v), transforms, mu_f, sd_f)) 

628 else: 

629 raise ValueError(f"Unknown constraint key: {k!r}") 

630 

631 Xn = ds["Xn_train"].values.astype(float) 

632 p01 = np.percentile(Xn, 1, axis=0) 

633 p99 = np.percentile(Xn, 99, axis=0) 

634 wide_lo = np.minimum(p01 - 1.0, -3.0) # allow outside training range 

635 wide_hi = np.maximum(p99 + 1.0, 3.0) 

636 

637 bounds: list[tuple[float, float] | None] = [None]*p 

638 for j in range(p): 

639 if j in fixed_num_std: 

640 val = fixed_num_std[j] 

641 bounds[j] = (val, val) 

642 elif j in range_num_std: 

643 bounds[j] = range_num_std[j] 

644 elif j in choice_num: 

645 lo = float(np.nanmin(choice_num[j])); hi = float(np.nanmax(choice_num[j])) 

646 bounds[j] = (lo, hi) 

647 else: 

648 bounds[j] = (float(wide_lo[j]), float(wide_hi[j])) 

649 

650 # --- helpers 

651 onehot_members = {m for g in groups.values() for m in g["members"]} 

652 numeric_idx = [j for j, nm in enumerate(feature_names) if nm not in onehot_members] 

653 num_bounds = [bounds[j] for j in numeric_idx] 

654 

655 def apply_onehot(vec_std: np.ndarray, base: str, label: str): 

656 for lab in groups[base]["labels"]: 

657 member_name = groups[base]["name_by_label"][lab] 

658 j = name_to_idx[member_name] 

659 raw = 1.0 if lab == label else 0.0 

660 vec_std[j] = _orig_to_std(j, raw, transforms, mu_f, sd_f) 

661 

662 # exploitation objective (μ + soft penalty), with gradient 

663 def obj_grad_exploit(x_full: np.ndarray): 

664 mu_l, g_l = gp_l.mean_and_grad(x_full) 

665 mu = mu_l + cond_mean 

666 mu_p, g_p = gp_s.mean_and_grad(x_full) 

667 z = success_threshold - mu_p 

668 sig = 1.0 / (1.0 + np.exp(-penalty_beta * z)) 

669 penalty = penalty_lambda * (np.log1p(np.exp(penalty_beta * z)) / penalty_beta) 

670 grad_pen = - penalty_lambda * sig * g_p 

671 J = flip * mu + penalty 

672 gJ = flip * g_l + grad_pen 

673 return float(J), gJ.astype(float) 

674 

675 # exploration objective: -EI with feasibility gate (no analytic grad) 

676 def obj_scalar_explore(x_full: np.ndarray) -> float: 

677 mu_l = gp_l.mean_only(x_full[None, :])[0] 

678 mu = float(mu_l + cond_mean) 

679 sd = float(gp_l.sd_at(x_full, include_observation_noise=True)) 

680 ps = float(gp_s.mean_only(x_full[None, :])[0]) 

681 mu_signed, best_signed = _maybe_flip_for_direction(np.array([mu]), float(best_feasible), direction) 

682 gate = 1.0 / (1.0 + np.exp(-penalty_beta * (ps - success_threshold))) 

683 ei = float(_expected_improvement_minimize(mu_signed, np.array([sd]), best_signed)[0]) * gate 

684 return -ei # minimize 

685 
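# Sketch of the exploration score above (assuming _expected_improvement_minimize
# implements the standard closed-form EI for minimisation):
#   z  = (best - mu) / sd
#   EI = (best - mu) * Phi(z) + sd * phi(z)
# and the sigmoid gate downweights EI where the predicted success probability
# falls below `success_threshold`.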

686 # numeric grad by central differences 

687 def _numeric_grad(fun_scalar, x_num: np.ndarray, eps: float = 1e-4) -> np.ndarray: 

688 g = np.zeros_like(x_num, dtype=float) 

689 for i in range(x_num.size): 

690 e = np.zeros_like(x_num); e[i] = eps 

691 g[i] = (fun_scalar(x_num + e) - fun_scalar(x_num - e)) / (2.0 * eps) 

692 return g 

693 

694 def sample_start(): 

695 x = np.zeros(p, float) 

696 for j in numeric_idx: 

697 lo, hi = num_bounds[numeric_idx.index(j)] 

698 x[j] = lo if lo == hi else rng.uniform(lo, hi) 

699 for j, choices in choice_num.items(): 

700 x[j] = choices[np.argmin(np.abs(choices - x[j]))] 

701 for j, v in fixed_num_std.items(): 

702 x[j] = v 

703 return x 

704 

705 # categorical combos 

706 cat_bases = list(groups.keys()) 

707 combo_space = [] 

708 for b in cat_bases: 

709 combo_space.append([cat_fixed[b]] if b in cat_fixed else cat_allowed[b]) 

710 all_label_combos = list(itertools.product(*combo_space)) if combo_space else [()] 

711 

712 # repulsion 

713 rep_sigma2 = float(repulsion) ** 2 

714 rep_weight = float(repulsion) 

715 

716 def _is_dup(xa: np.ndarray, xb: np.ndarray, tol=1e-3) -> bool: 

717 # allow >1 per combo when there are no numeric free dims 

718 if xa.size == 0 and xb.size == 0: 

719 return False 

720 return bool(np.linalg.norm(xa - xb) < tol) 

721 

722 def _accept_row(template, best_xnum, labels, labels_t, accepted_combo, accepted_global, rows): 

723 # compose full point 

724 x_full = template.copy() 

725 x_full[numeric_idx] = best_xnum 

726 for j, choices in choice_num.items(): 

727 x_full[j] = float(choices[np.argmin(np.abs(choices - x_full[j]))]) 

728 for idx in numeric_idx: 

729 lo, hi = num_bounds[numeric_idx.index(idx)] 

730 x_full[idx] = float(np.clip(x_full[idx], lo, hi)) 

731 

732 x_num_std = x_full[numeric_idx].copy() 

733 # dedupe (combo + global) 

734 if any(_is_dup(x_num_std, prev) for prev in accepted_combo): 

735 return False 

736 if any((labels_t == labt) and _is_dup(x_num_std, prev) for labt, prev in accepted_global): 

737 return False 

738 

739 # accept 

740 accepted_combo.append(x_num_std) 

741 accepted_global.append((labels_t, x_num_std)) 

742 

743 mu_l, _ = gp_l.mean_and_grad(x_full); mu = float(mu_l + cond_mean) 

744 ps, _ = gp_s.mean_and_grad(x_full); ps = float(np.clip(ps, 0.0, 1.0)) 

745 sd_opt = float(gp_l.sd_at(x_full, include_observation_noise=True)) 

746 

747 row = { 

748 "pred_p_success": ps, 

749 "pred_target_mean": mu, 

750 "pred_target_sd": sd_opt, 

751 } 

752 onehot_members_local = {m for g in groups.values() for m in g["members"]} 

753 for j, nm in enumerate(feature_names): 

754 if nm in onehot_members_local: 

755 continue 

756 row[nm] = float(_std_to_orig(j, x_full[j], transforms, mu_f, sd_f)) 

757 for b, lab in zip(cat_bases, labels): 

758 row[b] = lab 

759 

760 rows.append(row) 

761 return True 

762 

763 def _optimize_take(template, accepted_combo, use_explore): 

764 # inner objective in numeric subspace 

765 def f_g_only_num(x_num: np.ndarray): 

766 x_full = template.copy() 

767 x_full[numeric_idx] = x_num 

768 

769 def add_repulsion(J: float, g_num: np.ndarray | None): 

770 nonlocal accepted_combo 

771 if accepted_combo and rep_sigma2 > 0.0 and rep_weight > 0.0: 

772 for xk in accepted_combo: 

773 d = x_num - xk 

774 r2 = float(d @ d) 

775 w = math.exp(-0.5 * r2 / rep_sigma2) 

776 J += rep_weight * w 

777 if g_num is not None: 

778 g_num += rep_weight * w * (-d / rep_sigma2) 

779 return J, g_num 

780 

781 if not use_explore: 

782 J, g = obj_grad_exploit(x_full) 

783 J, g_num = add_repulsion(J, g[numeric_idx]) 

784 return float(J), g_num 

785 

786 # exploration branch: -EI, numerical grad 

787 def scalar_for_grad(xn: np.ndarray) -> float: 

788 x_tmp = template.copy() 

789 x_tmp[numeric_idx] = xn 

790 J = obj_scalar_explore(x_tmp) 

791 # include repulsion inside scalar for finite-diff consistency 

792 if accepted_combo and rep_sigma2 > 0.0 and rep_weight > 0.0: 

793 for xk in accepted_combo: 

794 d = xn - xk 

795 r2 = float(d @ d) 

796 w = math.exp(-0.5 * r2 / rep_sigma2) 

797 J += rep_weight * w 

798 return float(J) 

799 

800 J = scalar_for_grad(x_num) 

801 g_num = _numeric_grad(scalar_for_grad, x_num, eps=1e-4) 

802 return float(J), g_num 

803 

804 # collect best from multistarts 

805 from scipy.optimize import fmin_l_bfgs_b 

806 starts = [sample_start()[numeric_idx] for _ in range(n_starts)] 

807 if starts and starts[0].size == 0: 

808 # no numeric dims free → just return a zero-length vector 

809 return np.zeros((0,), float) 

810 

811 best_val = None 

812 best_xnum = None 

813 explore_candidates: list[tuple[np.ndarray, float]] = [] 

814 for x0 in starts: 

815 xopt, fval, _ = fmin_l_bfgs_b( 

816 func=lambda x: f_g_only_num(x), 

817 x0=x0, 

818 fprime=None, 

819 bounds=num_bounds, 

820 maxiter=max_iters, 

821 ) 

822 fval = float(fval) 

823 if not use_explore: 

824 if (best_val is None) or (fval < best_val): 

825 best_val = fval 

826 best_xnum = xopt 

827 else: 

828 # candidate scored by gated EI (no repulsion in score) 

829 x_tmp = template.copy() 

830 x_tmp[numeric_idx] = xopt 

831 gated_ei = -obj_scalar_explore(x_tmp) 

832 if not any(np.linalg.norm(xopt - c[0]) < 1e-3 for c in explore_candidates): 

833 explore_candidates.append((xopt, float(gated_ei))) 

834 

835 if use_explore: 

836 if not explore_candidates: 

837 return None 

838 if softmax_temp and softmax_temp > 0.0: 

839 eis = np.array([ei for _, ei in explore_candidates], dtype=float) 

840 z = eis - np.max(eis) 

841 probs = np.exp(z / float(softmax_temp)) 

842 probs = probs / probs.sum() 

843 idx = rng.choice(len(explore_candidates), p=probs) 

844 return explore_candidates[idx][0] 

845 # greedy EI 

846 idx = int(np.argmax([ei for _, ei in explore_candidates])) 

847 return explore_candidates[idx][0] 

848 return best_xnum 

849 

850 # ---------------- Core loop with dynamic allocation ---------------- 

851 rows: list[dict] = [] 

852 accepted_global: list[tuple[tuple[str, ...], np.ndarray]] = [] 

853 all_label_combos = all_label_combos or [()] 

854 

855 n_combos = max(1, len(all_label_combos)) 

856 for combo_idx, labels in enumerate(all_label_combos): 

857 if len(rows) >= count: 

858 break 

859 labels_t = tuple(labels) if labels else tuple() 

860 

861 template = np.zeros(p, float) 

862 for b, lab in zip(cat_bases, labels): 

863 apply_onehot(template, b, lab) 

864 

865 accepted_combo: list[np.ndarray] = [] 

866 

867 remain_total = count - len(rows) 

868 remain_combos = max(1, n_combos - combo_idx) 

869 k_each = max(1, math.ceil(remain_total / remain_combos)) 

870 

871 takes = 0 

872 while (takes < k_each) and (len(rows) < count): 

873 use_explore = (rng.random() < float(explore)) 

874 best_xnum = _optimize_take(template, accepted_combo, use_explore) 

875 if best_xnum is None: 

876 # try the opposite mode once 

877 best_xnum = _optimize_take(template, accepted_combo, not use_explore) 

878 if best_xnum is None: 

879 break # give up this take for this combo 

880 ok = _accept_row(template, best_xnum, labels, labels_t, accepted_combo, accepted_global, rows) 

881 if ok: 

882 takes += 1 

883 

884 # ---------------- Refill loop if still short ---------------- 

885 if len(rows) < count: 

886 # relax repulsion and try a few refill rounds with fresh starts 

887 for _refill in range(3): 

888 if len(rows) >= count: 

889 break 

890 rep_sigma2 *= 0.7 

891 rep_weight *= 0.7 

892 for combo_idx, labels in enumerate(all_label_combos): 

893 if len(rows) >= count: 

894 break 

895 labels_t = tuple(labels) if labels else tuple() 

896 template = np.zeros(p, float) 

897 for b, lab in zip(cat_bases, labels): 

898 apply_onehot(template, b, lab) 

899 # start with an empty combo-accepted set to avoid over-repelling 

900 accepted_combo = [] 

901 remain_total = count - len(rows) 

902 remain_combos = max(1, n_combos - combo_idx) 

903 k_each = max(1, math.ceil(remain_total / remain_combos)) 

904 takes = 0 

905 while (takes < k_each) and (len(rows) < count): 

906 use_explore = (rng.random() < float(explore)) 

907 best_xnum = _optimize_take(template, accepted_combo, use_explore) 

908 if best_xnum is None: 

909 best_xnum = _optimize_take(template, accepted_combo, not use_explore) 

910 if best_xnum is None: 

911 break 

912 ok = _accept_row(template, best_xnum, labels, labels_t, accepted_combo, accepted_global, rows) 

913 if ok: 

914 takes += 1 

915 

916 # ---------------- Last-resort fill with random projections ---------------- 

917 # Only used if optimization couldn't find enough unique points but space likely allows more. 

918 if len(rows) < count: 

919 tries = 0 

920 max_tries = max(200, 20 * (count - len(rows))) 

921 while (len(rows) < count) and (tries < max_tries): 

922 tries += 1 

923 # random labels 

924 labels = [] 

925 for b in cat_bases: 

926 pool = [cat_fixed[b]] if b in cat_fixed else cat_allowed[b] 

927 labels.append(pool[int(rng.integers(0, len(pool)))]) 

928 labels_t = tuple(labels) if labels else tuple() 

929 # template 

930 template = np.zeros(p, float) 

931 for b, lab in zip(cat_bases, labels): 

932 apply_onehot(template, b, lab) 

933 # random numeric in bounds 

934 x_num = np.zeros(len(numeric_idx), float) 

935 for ii, j in enumerate(numeric_idx): 

936 lo, hi = num_bounds[ii] 

937 x_num[ii] = lo if lo == hi else rng.uniform(lo, hi) 

938 # accept (weak dedupe by numeric subspace) 

939 accepted_combo = [] # local (empty) so only global dedupe applies 

940 _accept_row(template, x_num, labels, labels_t, accepted_combo, accepted_global, rows) 

941 

942 # ---------------- Assemble & rank ---------------- 

943 if not rows: 

944 raise ValueError("No solutions produced; check constraints.") 

945 

946 df = pd.DataFrame(rows) 

947 asc_mu = (direction != "max") 

948 df = df.sort_values(["pred_p_success", "pred_target_mean"], 

949 ascending=[False, asc_mu], 

950 kind="mergesort").reset_index(drop=True) 

951 # trim if we overshot (the loop should produce exactly `count`, but keep the guard)

952 if len(df) > count: 

953 df = df.head(count) 

954 df["rank"] = np.arange(1, len(df) + 1) 

955 

956 if output: 

957 output = Path(output) 

958 output.parent.mkdir(parents=True, exist_ok=True) 

959 df.to_csv(output, index=False) 

960 

961 try: 

962 console.print(f"\n[bold]Top {len(df)} suggested candidates:[/]") 

963 console.print(df_to_table(df)) # type: ignore[arg-type] 

964 except Exception: 

965 pass 

966 return df 

967 
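# Illustrative usage (the artifact path and feature names here are hypothetical;
# pass your own model's feature names as keyword constraints in ORIGINAL units):
#
#   df = suggest("model.nc", count=5,
#                epochs=slice(5, 50),          # inclusive range
#                batch_size=(16, 32, 64),      # finite choices
#                language="Linear B")          # fix a categorical base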

968 

969 def _collapse_onehot_to_categorical(df: pd.DataFrame, groups: dict[str, dict]) -> pd.DataFrame:

970 """ 

971 Collapse one-hot blocks (e.g. language=Linear A, language=Linear B) into a single 

972 categorical column 'language'. Leaves <NA> only if a row is ambiguous (sum!=1). 

973 """ 

974 out = df.copy() 

975 

976 for base, g in groups.items(): 

977 # column order must match label order 

978 labels = list(g["labels"]) 

979 member_cols = [g["name_by_label"][lab] for lab in labels if g["name_by_label"][lab] in out.columns] 

980 if not member_cols: 

981 continue 

982 

983 # robust numeric block: NaN→0, float for safe sums/argmax 

984 block = out[member_cols].to_numpy(dtype=float) 

985 block = np.nan_to_num(block, nan=0.0, posinf=0.0, neginf=0.0) 

986 

987 row_sums = block.sum(axis=1) 

988 argmax = np.argmax(block, axis=1) 

989 

990 # exactly-one-hot per row (tolerant to tiny fp wiggle) 

991 valid = np.isfinite(row_sums) & (np.abs(row_sums - 1.0) <= 1e-9) 

992 

993 chosen = np.full(len(out), None, dtype=object) 

994 if valid.any(): 

995 lab_arr = np.array(labels, dtype=object) 

996 chosen[valid] = lab_arr[argmax[valid]] 

997 

998 # write the categorical column with proper alignment 

999 out[base] = pd.Series(chosen, index=out.index, dtype="string") 

1000 

1001 # drop the one-hot members 

1002 out.drop(columns=[c for c in member_cols if c in out.columns], inplace=True) 

1003 

1004 return out 

1005 
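# Illustrative sketch (not part of the original module): with one-hot columns
# "language=Linear A" = [1, 0] and "language=Linear B" = [0, 1], the collapsed
# frame carries a single string column "language" = ["Linear A", "Linear B"];
# a row whose block does not sum to exactly 1 collapses to <NA>.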

1006 

1007 def _inject_onehot_groups(

1008 cand_df: pd.DataFrame, 

1009 groups: dict[str, dict], 

1010 rng: np.random.Generator, 

1011 cat_fixed_label: dict[str, str], 

1012 cat_allowed: dict[str, set[str]], 

1013) -> pd.DataFrame: 

1014 """ 

1015 Ensure each one-hot block has exactly one '1' per row (or a fixed label), 

1016 by initializing member columns to 0 then writing the chosen label as 1. 

1017 """ 

1018 out = cand_df.copy() 

1019 n = len(out) 

1020 

1021 for base, g in groups.items(): 

1022 labels = g["labels"] 

1023 member_cols = [g["name_by_label"][lab] for lab in labels] 

1024 

1025 # Create/overwrite member columns with zeros to avoid NaNs 

1026 for col in member_cols: 

1027 out[col] = 0 

1028 

1029 # Allowed labels for this base 

1030 allowed = list(cat_allowed.get(base, set(labels))) 

1031 if not allowed: 

1032 allowed = labels 

1033 

1034 # Choose a label per row 

1035 if base in cat_fixed_label: 

1036 chosen = np.full(n, cat_fixed_label[base], dtype=object) 

1037 else: 

1038 idx = rng.integers(0, len(allowed), size=n) 

1039 chosen = np.array([allowed[i] for i in idx], dtype=object) 

1040 

1041 # Set one-hot = 1 for the chosen label, keep others at 0 

1042 for lab, col in zip(labels, member_cols): 

1043 out.loc[chosen == lab, col] = 1 

1044 

1045 # Enforce integer dtype (clean) 

1046 out[member_cols] = out[member_cols].astype(int) 

1047 

1048 return out 

1049 

1050 

1051 def _postfilter_numeric_constraints(

1052 df: pd.DataFrame, 

1053 user_fixed_num: dict, 

1054 user_ranges_num: dict, 

1055 user_choices_num: dict, 

1056) -> pd.DataFrame: 

1057 """ 

1058 Keep rows satisfying numeric constraints (fixed / ranges / choices). 

1059 Nonexistent columns are ignored. 

1060 """ 

1061 if df.empty: 

1062 return df 

1063 

1064 mask = np.ones(len(df), dtype=bool) 

1065 

1066 # ranges: inclusive 

1067 for k, (lo, hi) in user_ranges_num.items(): 

1068 if k in df.columns: 

1069 mask &= (df[k] >= lo) & (df[k] <= hi) 

1070 

1071 # finite numeric choices 

1072 for k, vals in user_choices_num.items(): 

1073 if k in df.columns: 

1074 mask &= df[k].isin(vals) 

1075 

1076 # fixed values (tolerate tiny float error) 

1077 for k, val in user_fixed_num.items(): 

1078 if k in df.columns: 

1079 col = df[k] 

1080 if pd.api.types.is_integer_dtype(col.dtype): 

1081 mask &= (col == int(round(val))) 

1082 else: 

1083 mask &= np.isfinite(col) & (np.abs(col - float(val)) <= 1e-12) 

1084 

1085 return df.loc[mask].reset_index(drop=True) 

1086 

1087 

1088 def optimal(

1089 model: xr.Dataset | Path | str, 

1090 output: Path | None = None, 

1091 count: int = 10, # ignored (we always return 1) 

1092 n_draws: int = 0, # ignored (mean-only optimizer) 

1093 success_threshold: float = 0.8, 

1094 seed: int | np.random.Generator | None = 42, 

1095 **kwargs, # constraints in ORIGINAL units 

1096) -> pd.DataFrame: 

1097 """ 

1098 Best single candidate by optimizing the GP *mean* posterior under constraints, 

1099 using L-BFGS-B on standardized features. Like `suggest` but returns 1 row. 

1100 

1101 Objective (we minimize): 

1102 J(x) = flip * μ_loss(x) + λ * softplus( threshold - p_success(x) ) 

1103 

1104 Notes: 

1105 • Categorical bases handled by enumeration over allowed labels. 

1106 • Numeric choices are projected to nearest allowed value after optimize. 

1107 • `count` and `n_draws` are ignored (kept for API compatibility). 

1108 """ 

1109 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model) 

1110 

1111 rng = rng_for_dataset(ds, seed) 

1112 

1113 # --- metadata 

1114 feature_names = [str(n) for n in ds["feature"].values.tolist()] 

1115 transforms = [str(t) for t in ds["feature_transform"].values.tolist()] 

1116 mu_f = ds["feature_mean"].values.astype(float) 

1117 sd_f = ds["feature_std"].values.astype(float) 

1118 p = len(feature_names) 

1119 name_to_idx = {nm: j for j, nm in enumerate(feature_names)} 

1120 groups = _groups_from_feature_names(feature_names) # {base:{labels, name_by_label, members}} 

1121 

1122 # --- GP heads (shared helper class) 

1123 gp_s = _GPMarginal( 

1124 Xtr=ds["Xn_train"].values.astype(float), 

1125 ytr=ds["y_success"].values.astype(float), 

1126 ell=ds["map_success_ell"].values.astype(float), 

1127 eta=float(ds["map_success_eta"].values), 

1128 sigma=float(ds["map_success_sigma"].values), 

1129 mean_const=float(ds["map_success_beta0"].values), 

1130 ) 

1131 cond_mean = float(ds["conditional_loss_mean"].values) if "conditional_loss_mean" in ds else 0.0 

1132 gp_l = _GPMarginal( 

1133 Xtr=ds["Xn_success_only"].values.astype(float), 

1134 ytr=ds["y_loss_centered"].values.astype(float), 

1135 ell=ds["map_loss_ell"].values.astype(float), 

1136 eta=float(ds["map_loss_eta"].values), 

1137 sigma=float(ds["map_loss_sigma"].values), 

1138 mean_const=float(ds["map_loss_mean_const"].values), 

1139 ) 

1140 

1141 # --- direction 

1142 direction = str(ds.attrs.get("direction", "min")) 

1143 flip = -1.0 if direction == "max" else 1.0 

1144 

1145 # --- parse constraints (numeric vs categorical) 

1146 cat_allowed: dict[str, list[str]] = {} 

1147 cat_fixed: dict[str, str] = {} 

1148 fixed_num_std: dict[int, float] = {} 

1149 range_num_std: dict[int, tuple[float, float]] = {} 

1150 choice_num: dict[int, np.ndarray] = {} 

1151 

1152 # default allowed = all labels for each base 

1153 for b, g in groups.items(): 

1154 cat_allowed[b] = list(g["labels"]) 

1155 

1156 import re 

1157 def canon_key(k: str) -> str: 

1158 raw = str(k) 

1159 stripped = re.sub(r"[^a-z0-9]+", "", raw.lower()) 

1160 if raw in name_to_idx: 

1161 return raw 

1162 for base in groups.keys(): 

1163 if stripped == re.sub(r"[^a-z0-9]+", "", base.lower()): 

1164 return base 

1165 return raw 

1166 

1167 for k, v in (kwargs or {}).items(): 

1168 ck = canon_key(k) 

1169 if ck in groups: 

1170 labels = groups[ck]["labels"] 

1171 if isinstance(v, str): 

1172 if v not in labels: 

1173 raise ValueError(f"Unknown category for {ck}: {v}. Choices: {labels}") 

1174 cat_fixed[ck] = v 

1175 else: 

1176 L = [x for x in (list(v) if isinstance(v, (list, tuple, set)) else [v]) 

1177 if isinstance(x, str) and x in labels] 

1178 if not L: 

1179 raise ValueError(f"No valid categories for {ck} in {v}. Choices: {labels}") 

1180 if len(L) == 1: 

1181 cat_fixed[ck] = L[0] 

1182 else: 

1183 cat_allowed[ck] = L 

1184 elif ck in name_to_idx: 

1185 j = name_to_idx[ck] 

1186 if isinstance(v, range): # as in `suggest`: python range -> tuple of ints (numeric choices)

v = tuple(v)

if isinstance(v, slice):

1187 lo = _orig_to_std(j, v.start, transforms, mu_f, sd_f) 

1188 hi = _orig_to_std(j, v.stop, transforms, mu_f, sd_f) 

1189 lo, hi = float(np.nanmin([lo, hi])), float(np.nanmax([lo, hi])) 

1190 range_num_std[j] = (lo, hi) 

1191 elif isinstance(v, (list, tuple, np.ndarray)): 

1192 arr = _orig_to_std(j, np.asarray(v, float), transforms, mu_f, sd_f) 

1193 choice_num[j] = np.asarray(arr, float) 

1194 else: 

1195 fixed_num_std[j] = float(_orig_to_std(j, float(v), transforms, mu_f, sd_f)) 

1196 else: 

1197 raise ValueError(f"Unknown constraint key: {k!r}") 

1198 

1199 # --- numeric bounds (std space); allow outside training range 

1200 Xn = ds["Xn_train"].values.astype(float) 

1201 p01 = np.percentile(Xn, 1, axis=0) 

1202 p99 = np.percentile(Xn, 99, axis=0) 

1203 wide_lo = np.minimum(p01 - 1.0, -3.0) 

1204 wide_hi = np.maximum(p99 + 1.0, 3.0) 

1205 

1206 bounds: list[tuple[float, float] | None] = [None]*p 

1207 for j in range(p): 

1208 if j in fixed_num_std: 

1209 v = fixed_num_std[j] 

1210 bounds[j] = (v, v) 

1211 elif j in range_num_std: 

1212 bounds[j] = range_num_std[j] 

1213 elif j in choice_num: 

1214 lo = float(np.nanmin(choice_num[j])); hi = float(np.nanmax(choice_num[j])) 

1215 bounds[j] = (lo, hi) 

1216 else: 

1217 bounds[j] = (float(wide_lo[j]), float(wide_hi[j])) 

1218 

1219 # --- helpers 

1220 def apply_onehot(vec_std: np.ndarray, base: str, label: str): 

1221 for lab in groups[base]["labels"]: 

1222 member_name = groups[base]["name_by_label"][lab] 

1223 j = name_to_idx[member_name] 

1224 raw = 1.0 if lab == label else 0.0 

1225 vec_std[j] = _orig_to_std(j, raw, transforms, mu_f, sd_f) 

1226 

1227 penalty_lambda = 1.0 

1228 penalty_beta = 10.0 

1229 

1230 def obj_grad(x_std_full: np.ndarray) -> tuple[float, np.ndarray]: 

1231 # mean + grad for loss head (centered); add back cond_mean 

1232 mu_l, g_l = gp_l.mean_and_grad(x_std_full) 

1233 mu = mu_l + cond_mean 

1234 # success head (smooth, not clipped) 

1235 mu_p, g_p = gp_s.mean_and_grad(x_std_full) 

1236 # softplus penalty for p<thr 

1237 z = success_threshold - mu_p 

1238 sig = 1.0 / (1.0 + np.exp(-penalty_beta * z)) 

1239 penalty = penalty_lambda * (np.log1p(np.exp(penalty_beta * z)) / penalty_beta) 

1240 grad_pen = - penalty_lambda * sig * g_p 

1241 J = float(flip * mu + penalty) 

1242 gJ = (flip * g_l + grad_pen).astype(float) 

1243 return J, gJ 

1244 

1245 # exclude one-hot members from numeric optimization 

1246 onehot_members = {m for g in groups.values() for m in g["members"]} 

1247 numeric_idx = [j for j, nm in enumerate(feature_names) if nm not in onehot_members] 

1248 num_bounds = [bounds[j] for j in numeric_idx] 

1249 

1250 # uniform starts inside numeric bounds (std space) 

1251 def sample_start() -> np.ndarray: 

1252 x = np.zeros(p, float) 

1253 for j in numeric_idx: 

1254 lo, hi = num_bounds[numeric_idx.index(j)] 

1255 x[j] = lo if lo == hi else rng.uniform(lo, hi) 

1256 for j, choices in choice_num.items(): 

1257 x[j] = choices[np.argmin(np.abs(choices - x[j]))] 

1258 for j, v in fixed_num_std.items(): 

1259 x[j] = v 

1260 return x 

1261 

1262 # enumerate categorical combos (fixed → single) 

1263 cat_bases = list(groups.keys()) 

1264 combo_space = [] 

1265 for b in cat_bases: 

1266 combo_space.append([cat_fixed[b]] if b in cat_fixed else cat_allowed[b]) 

1267 all_label_combos = list(itertools.product(*combo_space)) if combo_space else [()] 

1268 

1269 # --- optimize (multi-start) and keep the single best over all combos 

1270 from scipy.optimize import fmin_l_bfgs_b 

1271 n_starts = 32 

1272 max_iters = 200 

1273 

1274 best_global_val: float | None = None 

1275 best_global_x: np.ndarray | None = None 

1276 best_global_labels: tuple[str, ...] | tuple = tuple() 

1277 

1278 for labels in all_label_combos: 

1279 template = np.zeros(p, float) 

1280 for b, lab in zip(cat_bases, labels): 

1281 apply_onehot(template, b, lab) 

1282 

1283 # numeric-only wrapper 

1284 def f_g_only_num(x_num: np.ndarray): 

1285 x_full = template.copy() 

1286 x_full[numeric_idx] = x_num 

1287 J, g = obj_grad(x_full) 

1288 return J, g[numeric_idx] 

1289 

1290 # multi-starts 

1291 starts = [] 

1292 for _ in range(n_starts): 

1293 s = sample_start() 

1294 for b, lab in zip(cat_bases, labels): 

1295 apply_onehot(s, b, lab) 

1296 starts.append(s[numeric_idx]) 

1297 

1298 # pick best for this combo 

1299 best_val = None 

1300 best_xnum = None 

1301 for x0 in starts: 

1302 xopt, fval, _ = fmin_l_bfgs_b( 

1303 func=lambda x: f_g_only_num(x), 

1304 x0=x0, 

1305 fprime=None, 

1306 bounds=num_bounds, 

1307 maxiter=max_iters, 

1308 ) 

1309 fval = float(fval) 

1310 if (best_val is None) or (fval < best_val): 

1311 best_val = fval 

1312 best_xnum = xopt 

1313 

1314 if best_xnum is None: 

1315 continue 

1316 

1317 # assemble full point, project choices, clip to bounds 

1318 x_full = template.copy() 

1319 x_full[numeric_idx] = best_xnum 

1320 for j, choices in choice_num.items(): 

1321 x_full[j] = float(choices[np.argmin(np.abs(choices - x_full[j]))]) 

1322 for idx in numeric_idx: 

1323 lo, hi = num_bounds[numeric_idx.index(idx)] 

1324 x_full[idx] = float(np.clip(x_full[idx], lo, hi)) 

1325 

1326 if (best_global_val is None) or (best_val < best_global_val): 

1327 best_global_val = float(best_val) 

1328 best_global_x = x_full.copy() 

1329 best_global_labels = tuple(labels) if labels else tuple() 

1330 

1331 if best_global_x is None: 

1332 raise ValueError("No feasible optimum produced; check/relax constraints.") 

1333 

1334 # --- build single-row DataFrame in ORIGINAL units 

1335 x_opt = best_global_x 

1336 mu_l_opt, _ = gp_l.mean_and_grad(x_opt) 

1337 mu_opt = float(mu_l_opt + cond_mean) 

1338 p_opt, _ = gp_s.mean_and_grad(x_opt) 

1339 p_opt = float(np.clip(p_opt, 0.0, 1.0)) 

1340 

1341 sd_opt = gp_l.sd_at(x_opt, include_observation_noise=True) 

1342 

1343 onehot_members = {m for g in groups.values() for m in g["members"]} 

1344 

1345 row: dict[str, object] = { 

1346 "pred_p_success": p_opt, 

1347 "pred_target_mean": mu_opt, 

1348 "pred_target_sd": float(sd_opt), 

1349 "rank": 1, 

1350 } 

1351 # numerics in original units (drop one-hot members) 

1352 for j, nm in enumerate(feature_names): 

1353 if nm in onehot_members: 

1354 continue 

1355 row[nm] = float(_std_to_orig(j, x_opt[j], transforms, mu_f, sd_f)) 

1356 # categorical base columns 

1357 for b, lab in zip(cat_bases, best_global_labels): 

1358 row[b] = lab 

1359 

1360 df = pd.DataFrame([row]) 

1361 

1362 if output: 

1363 output = Path(output) 

1364 output.parent.mkdir(parents=True, exist_ok=True) 

1365 df.to_csv(output, index=False) 

1366 

1367 try: 

1368 console.print(f"\n[bold]Optimal candidate (mean posterior):[/]") 

1369 console.print(df_to_table(df)) # type: ignore[arg-type] 

1370 except Exception: 

1371 pass 

1372 return df 

1373 
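# Illustrative usage (the artifact path and feature names are hypothetical):
#
#   best = optimal("model.nc", success_threshold=0.9,
#                  epochs=slice(5, 50),
#                  batch_size=(16, 32, 64),
#                  language=("Linear A", "Linear B"))
#   # -> single-row DataFrame with pred_p_success, pred_target_mean,
#   #    pred_target_sd, rank, the numeric features in original units, and a
#   #    'language' column holding the winning label.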

1374 

1375 def optimal_old(

1376 model: xr.Dataset | Path | str, 

1377 output: Path | None = None, 

1378 count: int = 10, 

1379 n_draws: int = 0, 

1380 success_threshold: float = 0.8, 

1381 seed: int | np.random.Generator | None = 42, 

1382 **kwargs, 

1383) -> pd.DataFrame: 

1384 """ 

1385 Rank candidates by probability of being the best feasible optimum (min/max), 

1386 honoring numeric *and* categorical constraints. 

1387 

1388 Constraints (original units): 

1389 - number (int/float): fixed value, e.g. epochs=20 

1390 - slice(lo, hi): inclusive float range, e.g. learning_rate=slice(1e-5, 1e-3) 

1391 - list/tuple: finite numeric choices, e.g. batch_size=(16, 32, 64) 

1392 - range(...): converted to tuple of ints (choices) 

1393 - categorical base, e.g. language="Linear B" or language=("Linear A","Linear B") 

1394 (use the *base* name; model stores one-hot members internally) 

1395 """ 

1396 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model) 

1397 pred_success, pred_loss = _build_predictors(ds) 

1398 

1399 if output: 

1400 output = Path(output) 

1401 output.parent.mkdir(parents=True, exist_ok=True) 

1402 

1403 # --- model metadata 

1404 feature_names = list(map(str, ds["feature"].values.tolist())) 

1405 transforms = list(map(str, ds["feature_transform"].values.tolist())) 

1406 feat_mean = ds["feature_mean"].values.astype(float) 

1407 feat_std = ds["feature_std"].values.astype(float) 

1408 

1409 # --- detect categorical one-hot groups from feature names 

1410 groups = _onehot_groups(feature_names) # { base: {"labels":[...], "name_by_label":{label->member}, "members":[...]} } 

1411 

1412 # --- infer numeric search specs from data (includes one-hot members but we’ll drop them below) 

1413 specs_full = _infer_search_specs(ds, feature_names, transforms) 

1414 

1415 # --- split user kwargs into numeric vs categorical constraints 

1416 (groups, # same structure as above (returned for convenience) 

1417 user_fixed_num, # {numeric_feature: value} 

1418 user_ranges_num, # {numeric_feature: (lo, hi)} 

1419 user_choices_num, # {numeric_feature: [choices]} 

1420 cat_fixed_label, # {base: "Label"} (fixed single label) 

1421 cat_allowed) = _split_constraints_for_numeric_and_categorical(feature_names, kwargs) 

1422 

1423 # numeric fixed beats numeric ranges/choices 

1424 for k in list(user_fixed_num.keys()): 

1425 user_ranges_num.pop(k, None) 

1426 user_choices_num.pop(k, None) 

1427 

1428 # --- keep only *numeric* specs (drop one-hot members) 

1429 numeric_specs = _numeric_specs_only(specs_full, groups) 

1430 

1431 # apply numeric bounds/choices, normalize numeric fixed 

1432 _apply_user_bounds(numeric_specs, user_ranges_num, user_choices_num) 

1433 fixed_norm_num = _normalize_fixed(user_fixed_num, numeric_specs) 

1434 

1435 # --- EI baseline: best feasible observed target 

1436 direction = str(ds.attrs.get("direction", "min")) 

1437 best_feasible = _best_feasible_observed(ds, direction) 

1438 flip = -1.0 if direction == "max" else 1.0 

1439 

1440 # --- sample candidate pool 

1441 rng = get_rng(seed) 

1442 target_pool = max(4000, count * 200) # make sure MC has enough variety 

1443 

1444 def _sample_pool(n: int) -> pd.DataFrame: 

1445 # sample numerics 

1446 base_num = _sample_candidates(numeric_specs, n=n, rng=rng, fixed=fixed_norm_num) 

1447 # inject legal one-hot blocks for categoricals 

1448 with_cats = _inject_onehot_groups(base_num, groups, rng, cat_fixed_label, cat_allowed) 

1449 # hard filter numerics (ranges/choices/fixed) 

1450 filtered = _postfilter_numeric_constraints(with_cats, user_fixed_num, user_ranges_num, user_choices_num) 

1451 return filtered 

1452 

1453 cand_df = _sample_pool(target_pool) 

1454 # if tight constraints reduce pool too much, try a few refills 

1455 attempts = 0 

1456 while len(cand_df) < max(count * 50, 1000) and attempts < 6: 

1457 extra = _sample_pool(target_pool) 

1458 if not extra.empty: 

1459 cand_df = pd.concat([cand_df, extra], ignore_index=True).drop_duplicates() 

1460 attempts += 1 

1461 

1462 if cand_df.empty: 

1463 raise ValueError("No candidates satisfy the provided constraints; relax the ranges or choices.") 

1464 

1465 # --- predictions in model space (use full feature order incl. one-hot members) 

1466 Xn_cands = _original_df_to_standardized(cand_df[feature_names], feature_names, transforms, feat_mean, feat_std) 

1467 p = pred_success(Xn_cands) 

1468 mu, sd = pred_loss(Xn_cands, include_observation_noise=True) 

1469 sd = np.maximum(sd, 1e-12) 

1470 

1471 # --- optional feasibility filter 

1472 keep = p >= float(success_threshold) 

1473 if not np.any(keep): 

1474 keep = np.ones_like(p, dtype=bool) 

1475 

1476 cand_df = cand_df.loc[keep].reset_index(drop=True) 

1477 Xn_cands = Xn_cands[keep] 

1478 p = p[keep]; mu = mu[keep]; sd = sd[keep] 

1479 N = len(cand_df) 

1480 if N == 0: 

1481 raise ValueError("All sampled candidates were filtered out by success_threshold.") 

1482 

1483 # --- mean-only mode when n_draws == 0 

1484 if int(n_draws) <= 0: 

1485 result = cand_df.copy() 

1486 result["pred_p_success"] = p 

1487 result["pred_target_mean"] = mu 

1488 result["pred_target_sd"] = sd 

1489 # keep columns for API parity 

1490 result["wins"] = 0 

1491 result["n_draws_effective"] = 0 

1492 result["prob_best_feasible"] = 0.0 

1493 result["conditioned_on"] = _pretty_conditioned_on( 

1494 fixed_norm_numeric=fixed_norm_num, 

1495 cat_fixed_label=cat_fixed_label, 

1496 ) 

1497 

1498 # Direction-aware sort by μ, then lower σ, then higher p 

1499 if str(ds.attrs.get("direction", "min")) == "max": 

1500 sort_cols = ["pred_target_mean", "pred_target_sd", "pred_p_success"] 

1501 ascending = [False, True, False] 

1502 else: # "min" 

1503 sort_cols = ["pred_target_mean", "pred_target_sd", "pred_p_success"] 

1504 ascending = [True, True, False] 

1505 

1506 result_sorted = result.sort_values( 

1507 sort_cols, ascending=ascending, kind="mergesort" 

1508 ).reset_index(drop=True) 

1509 result_sorted["rank_prob_best"] = np.arange(1, len(result_sorted) + 1) 

1510 

1511 top = result_sorted.head(count).reset_index(drop=True) 

1512 # collapse one-hot → single categorical columns (e.g., 'language') 

1513 top_view = _collapse_onehot_to_categorical(top, groups) 

1514 

1515 if output: 

1516 top_view.to_csv(output, index=False) 

1517 

1518 console.print(f"\n[bold]Top {len(top_view)} optimal solutions (mean-only, n_draws=0):[/]") 

1519 console.print(df_to_table(top_view)) 

1520 return top_view 

1521 

1522 # --- Monte Carlo winner-take-all over feasible draws 

1523 Z = mu[:, None] + sd[:, None] * rng.standard_normal((N, n_draws)) 

1524 success_mask = rng.random((N, n_draws)) < p[:, None] 

1525 feasible_draw = success_mask.any(axis=0) 

1526 if not feasible_draw.any(): 

1527 # fallback: deterministic sort (rare) 

1528 result = cand_df.copy() 

1529 result["pred_p_success"] = p 

1530 result["pred_target_mean"] = mu 

1531 result["pred_target_sd"] = sd 

1532 result["prob_best_feasible"] = 0.0 

1533 result["wins"] = 0 

1534 result["n_draws_effective"] = 0 

1535 # prettify conditioning (numeric fixed + categorical fixed) 

1536 result["conditioned_on"] = _pretty_conditioned_on( 

1537 fixed_norm_numeric=fixed_norm_num, 

1538 cat_fixed_label=cat_fixed_label, 

1539 ) 

1540 result_sorted = result.sort_values( 

1541 ["pred_target_mean", "pred_target_sd", "pred_p_success"], 

1542 ascending=[flip > 0, True, False],  # direction-aware: descending mean when maximizing 

1543 kind="mergesort", 

1544 ).reset_index(drop=True) 

1545 result_sorted["rank_prob_best"] = np.arange(1, len(result_sorted) + 1) 

1546 top = result_sorted.head(count).reset_index(drop=True) 

1547 # collapse one-hot → single categorical columns for output 

1548 top_view = _collapse_onehot_to_categorical(top, groups) 

1549 if output: 

1550 top_view.to_csv(output, index=False) 

1551 console.print(f"\n[bold]Top {len(top_view)} optimal solutions:[/]") 

1552 console.print(df_to_table(top_view)) 

1553 return top_view 

1554 

1555 Z_eff = np.where(success_mask, flip * Z, np.inf)  # infeasible draws stay +inf for both directions 

1556 Zf = Z_eff[:, feasible_draw] 

1557 

1558 winner_idx = np.argmin(Zf, axis=0) 

1559 counts = np.bincount(winner_idx, minlength=N) 

1560 n_eff = int(feasible_draw.sum()) 

1561 prob_best = counts / float(n_eff) 

1562 

1563 result = cand_df.copy() 

1564 result["pred_p_success"] = p 

1565 result["pred_target_mean"] = mu 

1566 result["pred_target_sd"] = sd 

1567 result["wins"] = counts 

1568 result["n_draws_effective"] = n_eff 

1569 result["prob_best_feasible"] = prob_best 

1570 result["conditioned_on"] = _pretty_conditioned_on( 

1571 fixed_norm_numeric=fixed_norm_num, 

1572 cat_fixed_label=cat_fixed_label, 

1573 ) 

1574 

1575 result_sorted = result.sort_values( 

1576 ["prob_best_feasible", "pred_p_success", "pred_target_mean", "pred_target_sd"], 

1577 ascending=[False, False, True, True], 

1578 kind="mergesort", 

1579 ).reset_index(drop=True) 

1580 result_sorted["rank_prob_best"] = np.arange(1, len(result_sorted) + 1) 

1581 

1582 top = result_sorted.head(count).reset_index(drop=True) 

1583 # collapse one-hot → single categorical columns (e.g. 'language') 

1584 top_view = _collapse_onehot_to_categorical(top, groups) 

1585 

1586 if output: 

1587 top_view.to_csv(output, index=False) 

1588 

1589 console.print(f"\n[bold]Top {len(top_view)} optimal solutions:[/]") 

1590 console.print(df_to_table(top_view)) 

1591 return top_view 

1592 

1593 

1594 

1595# ============================================================================= 

1596# Predictors reconstructed from artifact (no PyMC at runtime) 

1597# ============================================================================= 

1598 

1599def _build_predictors(ds: xr.Dataset) -> tuple[ 

1600 Callable[[np.ndarray], np.ndarray], 

1601 Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]] 

1602]: 

1603 """Return (predict_success_probability, predict_conditional_target).""" 

1604 Xn_all = ds["Xn_train"].values.astype(float) 

1605 y_success = ds["y_success"].values.astype(float)  # used below to fit the success head 

1606 Xn_ok = ds["Xn_success_only"].values.astype(float) 

1607 y_loss_centered = ds["y_loss_centered"].values.astype(float) 

1608 

1609 # Success head MAP 

1610 ell_s = ds["map_success_ell"].values.astype(float) 

1611 eta_s = float(ds["map_success_eta"].values) 

1612 sigma_s = float(ds["map_success_sigma"].values) 

1613 beta0_s = float(ds["map_success_beta0"].values) 

1614 

1615 # Loss head MAP 

1616 ell_l = ds["map_loss_ell"].values.astype(float) 

1617 eta_l = float(ds["map_loss_eta"].values) 

1618 sigma_l = float(ds["map_loss_sigma"].values) 

1619 mean_c = float(ds["map_loss_mean_const"].values) 

1620 cond_mean = float(ds["conditional_loss_mean"].values) 

1621 

1622 # Cholesky precomputations 

1623 K_s = kernel_m52_ard(Xn_all, Xn_all, ell_s, eta_s) + (sigma_s**2) * np.eye(Xn_all.shape[0]) 

1624 L_s = np.linalg.cholesky(add_jitter(K_s)) 

1625 alpha_s = solve_chol(L_s, (y_success - beta0_s)) 

1626 

1627 K_l = kernel_m52_ard(Xn_ok, Xn_ok, ell_l, eta_l) + (sigma_l**2) * np.eye(Xn_ok.shape[0]) 

1628 L_l = np.linalg.cholesky(add_jitter(K_l)) 

1629 alpha_l = solve_chol(L_l, (y_loss_centered - mean_c)) 

1630 

1631 def predict_success_probability(Xn: np.ndarray) -> np.ndarray: 

1632 Ks = kernel_m52_ard(Xn, Xn_all, ell_s, eta_s) 

1633 mu = beta0_s + Ks @ alpha_s 

1634 return np.clip(mu, 0.0, 1.0) 

1635 

1636 def predict_conditional_target(Xn: np.ndarray, include_observation_noise: bool = True) -> tuple[np.ndarray, np.ndarray]: 

1637 Kl = kernel_m52_ard(Xn, Xn_ok, ell_l, eta_l) 

1638 mu_c = mean_c + Kl @ alpha_l 

1639 mu = mu_c + cond_mean 

1640 v = solve_lower(L_l, Kl.T) 

1641 var = kernel_diag_m52(Xn, ell_l, eta_l) - np.sum(v * v, axis=0) 

1642 var = np.maximum(var, 1e-12) 

1643 if include_observation_noise: 

1644 var = var + sigma_l**2 

1645 sd = np.sqrt(var) 

1646 return mu, sd 

1647 

1648 return predict_success_probability, predict_conditional_target 

1649 

1650 

1651# ============================================================================= 

1652# Search space, conditioning, and featurization 

1653# ============================================================================= 

1654 

1655def _infer_search_specs( 

1656 ds: xr.Dataset, 

1657 feature_names: list[str], 

1658 transforms: list[str], 

1659 pad_frac: float = 0.10, 

1660) -> dict[str, dict]: 

1661 """ 

1662 Build per-feature search specs from the *original-unit* columns present in the artifact. 

1663 Returns dict: name -> spec, where spec is one of: 

1664 {"type":"float", "lo":float, "hi":float} 

1665 {"type":"int", "lo":int, "hi":int, "step":int (optional)} 

1666 {"type":"choice","choices": list[int|float], "dtype":"int"|"float"} 

1667 """ 

1668 specs: dict[str, dict] = {} 

1669 

1670 df_raw = pd.DataFrame({k: ds[k].values for k in ds.data_vars if ds[k].dims == ("row",)}) 

1671 # prefer top-level columns if present 

1672 for j, name in enumerate(feature_names): 

1673 if name in df_raw.columns: 

1674 vals = pd.to_numeric(pd.Series(df_raw[name]), errors="coerce").dropna().to_numpy() 

1675 else: 

1676 # fallback: reconstruct original units from the standardized arrays if needed 

1677 # (the artifact normally stores raw columns, so this path is rarely used) 

1678 try: 

1679 base_vals = ds[name].values # raw per-row column, if present 

1680 except KeyError: 

1681 # Not stored as a data_var (e.g., one-hot feature); reconstruct from Xn_train 

1682 # j is the feature index in feature_names; transforms[j] is 'identity' or 'log10' 

1683 base_vals = feature_raw_from_artifact_or_reconstruct(ds, j, name, transforms[j]) 

1684 

1685 vals = pd.to_numeric(pd.Series(base_vals), errors="coerce").dropna().to_numpy() 

1686 

1687 

1688 if vals.size == 0: 

1689 # degenerate column; fall back to [0,1] 

1690 specs[name] = {"type": "float", "lo": 0.0, "hi": 1.0} 

1691 continue 

1692 

1693 # detect integer-ish 

1694 intish = np.all(np.isfinite(vals)) and np.allclose(vals, np.round(vals)) 

1695 

1696 # robust bounds with padding 

1697 p1, p99 = np.percentile(vals, [1, 99]) 

1698 span = max(p99 - p1, 1e-12) 

1699 lo = p1 - pad_frac * span 

1700 hi = p99 + pad_frac * span 

1701 

1702 if intish: 

1703 lo_i = int(np.floor(lo)) 

1704 hi_i = int(np.ceil(hi)) 

1705 specs[name] = {"type": "int", "lo": lo_i, "hi": hi_i} 

1706 else: 

1707 specs[name] = {"type": "float", "lo": float(lo), "hi": float(hi)} 

1708 return specs 

1709 

1710 

1711def _normalize_fixed( 

1712 fixed_raw: dict[str, object], 

1713 specs: dict[str, dict], 

1714) -> dict[str, object]: 

1715 """ 

1716 Normalize user constraints to sanitized forms within inferred bounds. 

1717 Keeps the *shape*: 

1718 - number (int/float) -> fixed (clipped to [lo,hi]) 

1719 - slice(lo, hi) -> float range (clipped to [lo,hi]) 

1720 - list/tuple -> finite choices (filtered to within [lo,hi], cast to int for int specs) 

1721 Returns a dict usable directly by _sample_candidates. 

1722 """ 

1723 fixed_norm: dict[str, object] = {} 

1724 

1725 for name, val in (fixed_raw or {}).items(): 

1726 if name not in specs: 

1727 # unknown feature already warned upstream; skip silently here 

1728 continue 

1729 

1730 sp = specs[name] 

1731 typ = sp["type"] 

1732 

1733 # helper clamps 

1734 def _clip_float(x: float) -> float: 

1735 return float(np.clip(x, sp["lo"], sp["hi"])) 

1736 

1737 def _clip_int(x: int) -> int: 

1738 lo, hi = int(sp.get("lo", x)), int(sp.get("hi", x)) 

1739 return int(np.clip(int(round(x)), lo, hi)) 

1740 

1741 # numeric fixed 

1742 if isinstance(val, (int, float, np.number)): 

1743 if typ == "int": 

1744 fixed_norm[name] = _clip_int(int(round(val))) 

1745 elif typ == "choice" and sp.get("dtype") == "int": 

1746 fixed_norm[name] = _clip_int(int(round(val))) 

1747 else: 

1748 fixed_norm[name] = _clip_float(float(val)) 

1749 continue 

1750 

1751 # float range via slice(lo, hi) 

1752 if isinstance(val, slice): 

1753 lo = float(val.start) 

1754 hi = float(val.stop) 

1755 if lo > hi: 

1756 lo, hi = hi, lo 

1757 if typ in ("float", "choice") and sp.get("dtype") != "int": 

1758 lo_c = _clip_float(lo); hi_c = _clip_float(hi) 

1759 if lo_c > hi_c: lo_c, hi_c = hi_c, lo_c 

1760 fixed_norm[name] = slice(lo_c, hi_c) 

1761 else: 

1762 # int spec: convert to inclusive integer tuple 

1763 lo_i = _clip_int(int(np.floor(lo))) 

1764 hi_i = _clip_int(int(np.ceil(hi))) 

1765 choices = tuple(range(lo_i, hi_i + 1)) 

1766 fixed_norm[name] = choices 

1767 continue 

1768 

1769 # choices via list/tuple 

1770 if isinstance(val, (list, tuple)): 

1771 if typ in ("int",) or (typ == "choice" and sp.get("dtype") == "int"): 

1772 vv = [ _clip_int(int(round(x))) for x in val ] 

1773 # de-dup and sort 

1774 vv = sorted(set(vv)) 

1775 if not vv: 

1776 # fallback to center 

1777 center = _clip_int(int(np.round((sp["lo"] + sp["hi"]) / 2))) 

1778 vv = [center] 

1779 fixed_norm[name] = tuple(vv) 

1780 else: 

1781 vv = [ _clip_float(float(x)) for x in val ] 

1782 vv = sorted(set(vv)) 

1783 if not vv: 

1784 center = _clip_float((sp["lo"] + sp["hi"]) / 2.0) 

1785 vv = [center] 

1786 # keep list/tuple shape (tuple preferred) 

1787 fixed_norm[name] = tuple(vv) 

1788 continue 

1789 

1790 # otherwise: ignore incompatible type 

1791 # (you could raise here if you prefer a hard failure) 

1792 return fixed_norm 

1793 

1794 

1795def _sample_candidates( 

1796 specs: dict[str, dict], 

1797 n: int, 

1798 rng: np.random.Generator, 

1799 fixed: dict[str, object] | None = None, 

1800) -> pd.DataFrame: 

1801 """ 

1802 Sample n candidates in ORIGINAL units given search specs and optional fixed constraints. 

1803 """ 

1804 fixed = fixed or {} 

1805 cols: dict[str, np.ndarray] = {} 

1806 

1807 for name, sp in specs.items(): 

1808 typ = sp["type"] 

1809 

1810 # If fixed: honor numeric / slice / choices shape 

1811 if name in fixed: 

1812 val = fixed[name] 

1813 

1814 # numeric: constant column 

1815 if isinstance(val, (int, float, np.number)): 

1816 cols[name] = np.full(n, val, dtype=float) 

1817 

1818 # float range slice 

1819 elif isinstance(val, slice): 

1820 lo = float(val.start); hi = float(val.stop) 

1821 if lo > hi: lo, hi = hi, lo 

1822 cols[name] = rng.uniform(lo, hi, size=n) 

1823 

1824 # choices: list/tuple -> sample from set 

1825 elif isinstance(val, (list, tuple)): 

1826 arr = np.array(val, dtype=float) 

1827 if arr.size == 0: 

1828 # fallback to center of spec 

1829 if typ == "int": 

1830 center = int(np.round((sp["lo"] + sp["hi"]) / 2)) 

1831 arr = np.array([center], dtype=float) 

1832 else: 

1833 center = (sp["lo"] + sp["hi"]) / 2.0 

1834 arr = np.array([center], dtype=float) 

1835 idx = rng.integers(0, len(arr), size=n) 

1836 cols[name] = arr[idx] 

1837 

1838 else: 

1839 # unknown fixed type; fallback to spec sampling 

1840 if typ == "choice": 

1841 choices = np.asarray(sp["choices"], dtype=float) 

1842 idx = rng.integers(0, len(choices), size=n) 

1843 cols[name] = choices[idx] 

1844 elif typ == "int": 

1845 cols[name] = rng.integers(int(sp["lo"]), int(sp["hi"]) + 1, size=n).astype(float) 

1846 else: 

1847 cols[name] = rng.uniform(sp["lo"], sp["hi"], size=n) 

1848 

1849 else: 

1850 # Not fixed: sample from spec 

1851 if typ == "choice": 

1852 choices = np.asarray(sp["choices"], dtype=float) 

1853 idx = rng.integers(0, len(choices), size=n) 

1854 cols[name] = choices[idx] 

1855 elif typ == "int": 

1856 cols[name] = rng.integers(int(sp["lo"]), int(sp["hi"]) + 1, size=n).astype(float) 

1857 else: 

1858 cols[name] = rng.uniform(sp["lo"], sp["hi"], size=n) 

1859 

1860 df = pd.DataFrame(cols) 

1861 # ensure integer columns are ints if the spec says so (pretty output) 

1862 for name, sp in specs.items(): 

1863 if sp["type"] == "int" or (sp["type"] == "choice" and sp.get("dtype") == "int"): 

1864 df[name] = df[name].round().astype(int) 

1865 return df 

1866 

1867 

1868def _original_df_to_standardized( 

1869 df: pd.DataFrame, 

1870 feature_names: list[str], 

1871 transforms: list[str], 

1872 feat_mean: np.ndarray, 

1873 feat_std: np.ndarray, 

1874) -> np.ndarray: 

1875 cols = [] 

1876 for j, name in enumerate(feature_names): 

1877 x = df[name].to_numpy().astype(float) 

1878 tr = transforms[j] 

1879 if tr == "log10": 

1880 x = np.where(x <= 0, np.nan, x) 

1881 x = np.log10(x) 

1882 cols.append((x - feat_mean[j]) / feat_std[j]) 

1883 return np.column_stack(cols).astype(float) 

1884 

1885 

1886# ============================================================================= 

1887# Acquisition functions & utilities 

1888# ============================================================================= 

1889 

1890def _expected_improvement_minimize(mu: np.ndarray, sd: np.ndarray, best_y: float) -> np.ndarray: 

1891 sd = np.maximum(sd, 1e-12) 

1892 z = (best_y - mu) / sd 

1893 Phi = ndtr(z) 

1894 phi = np.exp(-0.5 * z * z) / np.sqrt(2.0 * np.pi) 

1895 return sd * (z * Phi + phi) 

1896 

1897 

1898def _constrained_EI(mu: np.ndarray, sd: np.ndarray, p_success: np.ndarray, best_y: float, 

1899 p_threshold: float = 0.8, softness: float = 0.05) -> np.ndarray: 

1900 ei = _expected_improvement_minimize(mu, sd, best_y) 

1901 s = 1.0 / (1.0 + np.exp(-(p_success - p_threshold) / max(softness, 1e-6))) 

1902 return ei * s 

1903 

1904 

1905def _exploration_score(sd_loss: np.ndarray, p_success: np.ndarray, 

1906 w_sd: float = 1.0, w_boundary: float = 0.5) -> np.ndarray: 

1907 return w_sd * sd_loss + w_boundary * (p_success * (1.0 - p_success)) 

1908 

1909 

1910def _novelty_score(Xn_cands: np.ndarray, Xn_seen: np.ndarray) -> np.ndarray: 

1911 m = Xn_cands.shape[0] 

1912 batch = 1024 

1913 out = np.empty(m, dtype=float) 

1914 for i in range(0, m, batch): 

1915 sl = slice(i, min(i + batch, m)) 

1916 diff = Xn_cands[sl, None, :] - Xn_seen[None, :, :] 

1917 d = np.linalg.norm(diff, axis=2) 

1918 out[sl] = np.min(d, axis=1) 

1919 return out 

1920 

1921 

1922def _maybe_flip_for_direction(mu: np.ndarray, best_y: float, direction: str) -> tuple[np.ndarray, float]: 

1923 if direction == "max": 

1924 return -mu, -best_y 

1925 return mu, best_y 

1926 

1927 

1928def _best_feasible_observed(ds: xr.Dataset, direction: str) -> float: 

1929 y_ok = ds["y_loss_success"].values.astype(float) 

1930 if y_ok.size == 0: 

1931 return np.inf if direction != "max" else -np.inf 

1932 if direction == "max": 

1933 return float(np.nanmax(y_ok)) 

1934 return float(np.nanmin(y_ok)) 

1935 

1936 

1937def _is_number(x) -> bool: 

1938 return isinstance(x, (int, float, np.integer, np.floating)) 

1939 

1940 

1941def _fmt_num(x) -> str: 

1942 try: 

1943 return f"{float(x):.6g}" 

1944 except Exception: 

1945 return str(x) 

1946 

1947 

1948def _fixed_as_string(fixed: dict) -> str: 

1949 """ 

1950 Human-readable constraints: 

1951 - number -> k=12 or k=0.00123 

1952 - slice -> k=lo:hi (inclusive; None shows as -inf/inf) 

1953 - list/tuple -> k=[v1, v2, ...] 

1954 - range -> k=[start, stop, step] (rare; usually normalized earlier) 

1955 - other scalars (str/bool) -> k=value 

1956 Keys are sorted for stability. 

1957 """ 

1958 parts: list[str] = [] 

1959 for k in sorted(fixed.keys()): 

1960 v = fixed[k] 

1961 if isinstance(v, slice): 

1962 a = "-inf" if v.start is None else _fmt_num(v.start) 

1963 b = "inf" if v.stop is None else _fmt_num(v.stop) 

1964 parts.append(f"{k}={a}:{b}") 

1965 elif isinstance(v, range): 

1966 parts.append(f"{k}=[{', '.join(_fmt_num(u) for u in (v.start, v.stop, v.step))}]") 

1967 elif isinstance(v, (list, tuple, np.ndarray)): 

1968 elems = ", ".join(_fmt_num(u) if _is_number(u) else str(u) for u in v) 

1969 parts.append(f"{k}=[{elems}]") 

1970 elif _is_number(v): 

1971 parts.append(f"{k}={_fmt_num(v)}") 

1972 else: 

1973 # fallback for str/bool/other scalars 

1974 parts.append(f"{k}={v}") 

1975 return ", ".join(parts) 

1976 

1977 

1978def _apply_user_bounds( 

1979 specs: dict[str, dict[str, Any]], 

1980 ranges: dict[str, tuple[float, float]], 

1981 choices: dict[str, list[float]], 

1982) -> None: 

1983 """ 

1984 Mutate `specs` with user-provided bounds/choices. 

1985 """ 

1986 for name, (lo, hi) in ranges.items(): 

1987 if name not in specs: 

1988 continue 

1989 sp = specs[name] 

1990 # Specs use the "type"/"lo"/"hi" keys produced by _infer_search_specs. 

1991 if sp.get("type") == "choice": 

1992 # A user-supplied range overrides an existing choice list. 

1993 sp["type"] = "float" 

1994 sp["lo"] = float(lo) 

1995 sp["hi"] = float(hi) 

1996 sp.pop("choices", None) 

1997 

1998 for name, opts in choices.items(): 

1999 if name not in specs: 

2000 continue 

2001 sp = specs[name] 

2002 # Switch to a finite-choice spec and store the list 

2003 sp["type"] = "choice" 

2004 # Cast to ints when every value is numerically an integer 

2005 if all(abs(v - round(v)) < 1e-12 for v in opts): 

2006 sp["choices"], sp["dtype"] = [int(round(v)) for v in opts], "int" 

2007 else: 

2008 sp["choices"], sp["dtype"] = [float(v) for v in opts], "float" 

2009 # Drop bounds (not used for choice specs) 

2010 sp.pop("lo", None) 

2011 sp.pop("hi", None)