Coverage for psyop/opt.py: 41.45%
1228 statements
coverage.py v7.10.6, created at 2025-09-10 06:02 +0000
1# opt.py
2# -*- coding: utf-8 -*-
4from pathlib import Path
5from typing import Callable, Any
6import re
8import numpy as np
9import pandas as pd
10import xarray as xr
11import hashlib
12from scipy.special import ndtr # Φ(z), vectorized
14from .util import get_rng, df_to_table
15from .model import (
16 kernel_diag_m52,
17 kernel_m52_ard,
18 add_jitter,
19 solve_chol,
20 solve_lower,
21)
22from .model import feature_raw_from_artifact_or_reconstruct
24from rich.console import Console
25from rich.table import Table
27console = Console()
29_ONEHOT_RE = re.compile(r"^(?P<base>[^=]+)=(?P<label>.+)$")
32def _pretty_conditioned_on(
33 fixed_norm_numeric: dict | None = None,
34 cat_fixed_label: dict | None = None,
35) -> str:
36 """
37 Combine numeric fixed constraints (already normalized to model space)
38 with categorical fixed choices into a single human-readable string.
40 Examples:
41 - fixed_norm_numeric = {"epochs": 12.0, "batch_size": 32}
42 - cat_fixed_label = {"language": "Linear B"}
44 Returns:
45 "epochs=12, batch_size=32, language=Linear B"
46 (ordering is deterministic: keys sorted within each group)
47 """
48 fixed_norm_numeric = fixed_norm_numeric or {}
49 cat_fixed_label = cat_fixed_label or {}
51 parts = []
53 # Prefer the project-standard formatter if present.
54 try:
55 if fixed_norm_numeric:
56 txt = _fixed_as_string(fixed_norm_numeric) # e.g. "epochs=12, batch_size=32"
57 if txt:
58 parts.append(txt)
59 except Exception:
60 # Fallback: simple k=v with general formatting.
61 if fixed_norm_numeric:
62 items = []
63 for k, v in sorted(fixed_norm_numeric.items()):
64 try:
65 items.append(f"{k}={float(v):.6g}")
66 except Exception:
67 items.append(f"{k}={v}")
68 parts.append(", ".join(items))
70 # Append categorical fixed choices as "base=Label"
71 if cat_fixed_label:
72 cat_txt = ", ".join(f"{b}={lab}" for b, lab in sorted(cat_fixed_label.items()))
73 if cat_txt:
74 parts.append(cat_txt)
76 return ", ".join(p for p in parts if p)
79def _split_constraints_for_numeric_and_categorical(
80 feature_names: list[str],
81 kwargs: dict[str, object],
82):
83 """
84 Split user constraints into:
85 - numeric: user_fixed, user_ranges, user_choices_num (by feature name)
86 - categorical: cat_fixed_label (base->label), cat_allowed (base->set(labels))
87 - and return one-hot groups
 89 Interpretation rules:
90 * For a categorical base key (e.g. 'language'):
91 - str -> fixed single label
92 - list/tuple of str -> allowed label set
93 * For a numeric feature key (non one-hot member):
94 - number -> fixed
95 - slice(lo,hi[,step]) -> range (lo,hi) inclusive on ends in post-filter
96 - list/tuple of numbers -> finite choices
97 - range(...) (python range) -> tuple of ints (choices)
98 """
99 groups = _onehot_groups(feature_names)
100 bases = set(groups.keys())
101 feature_set = set(feature_names)
103 user_fixed: dict[str, float] = {}
104 user_ranges: dict[str, tuple[float, float]] = {}
105 user_choices_num: dict[str, list[int | float]] = {}
107 cat_fixed_label: dict[str, str] = {}
108 cat_allowed: dict[str, set[str]] = {}
110 # helper
111 def _is_intlike(x) -> bool:
112 try:
113 return float(int(round(float(x)))) == float(x)
114 except Exception:
115 return False
117 for key, raw in (kwargs or {}).items():
118 # --- CATEGORICAL (by base key, not member name) ---
119 if key in bases:
120 labels = groups[key]["labels"]
121 # fixed single label
122 if isinstance(raw, str):
123 if raw not in labels:
124 raise ValueError(f"Unknown category for {key!r}: {raw!r}. Choices: {labels}")
125 cat_fixed_label[key] = raw
126 cat_allowed[key] = {raw}
127 continue
128 # list/tuple of labels (choices restriction)
129 if isinstance(raw, (list, tuple, set)):
130 chosen = [v for v in raw if isinstance(v, str) and (v in labels)]
131 if not chosen:
132 raise ValueError(f"No valid categories for {key!r} in {raw!r}. Choices: {labels}")
133 cat_allowed[key] = set(chosen)
134 continue
135 # anything else -> ignore for cats
136 continue
138 # --- NUMERIC (by feature name; skip one-hot member names) ---
139 # If user accidentally passes member name 'language=Linear A', ignore here
140 if key not in feature_set or _ONEHOT_RE.match(key):
141 # Unknown or member-level keys are ignored at this stage
142 continue
144 # python range -> tuple of ints
145 if isinstance(raw, range):
146 raw = tuple(raw)
148 # number -> fixed
149 if isinstance(raw, (int, float, np.number)):
150 val = float(raw)
151 if np.isfinite(val):
152 user_fixed[key] = val
153 continue
155 # slice -> float range
156 if isinstance(raw, slice):
157 if raw.start is None or raw.stop is None:
158 continue
159 lo = float(raw.start); hi = float(raw.stop)
160 if not (np.isfinite(lo) and np.isfinite(hi)):
161 continue
162 if lo > hi:
163 lo, hi = hi, lo
164 user_ranges[key] = (lo, hi)
165 continue
167 # list/tuple -> numeric choices
168 if isinstance(raw, (list, tuple)):
169 if len(raw) == 0:
170 continue
171 # preserve ints if all int-like, else floats
172 if all(_is_intlike(v) for v in raw):
173 user_choices_num[key] = [int(round(float(v))) for v in raw]
174 else:
175 user_choices_num[key] = [float(v) for v in raw]
176 continue
178 # otherwise: ignore
180 # Numeric fixed wins over its own range/choices
181 for k in list(user_fixed.keys()):
182 user_ranges.pop(k, None)
183 user_choices_num.pop(k, None)
185 return groups, user_fixed, user_ranges, user_choices_num, cat_fixed_label, cat_allowed
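# Illustrative example (not from the report; the feature names are hypothetical):
# with feature_names = ["epochs", "batch_size", "language=Linear A", "language=Linear B"],
#   _split_constraints_for_numeric_and_categorical(
#       feature_names,
#       {"epochs": slice(5, 50), "batch_size": (16, 32), "language": "Linear B"},
#   )
# returns user_ranges = {"epochs": (5.0, 50.0)},
#         user_choices_num = {"batch_size": [16, 32]},
#         cat_fixed_label = {"language": "Linear B"},
#         cat_allowed = {"language": {"Linear B"}}.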
188def _detect_categorical_groups(feature_names: list[str]) -> dict[str, list[tuple[str, str]]]:
189 """
190 Detect one-hot groups: {"language": [("language=Linear A","Linear A"), ("language=Linear B","Linear B"), ...]}
191 """
192 groups: dict[str, list[tuple[str, str]]] = {}
193 for name in feature_names:
194 m = _ONEHOT_RE.match(name)
195 if not m:
196 continue
197 base = m.group("base")
198 lab = m.group("label")
199 groups.setdefault(base, []).append((name, lab))
200 # deterministic order
201 for base in groups:
202 groups[base].sort(key=lambda t: t[1])
203 return groups
205def _project_categoricals_to_valid_onehot(df: pd.DataFrame, groups: dict[str, list[tuple[str, str]]]) -> pd.DataFrame:
206 """
207 For each categorical group ensure exactly one column is 1 and the rest 0 (argmax projection).
208 Works whether columns are 0/1 already or arbitrary scores in [0,1].
209 """
210 for base, pairs in groups.items():
211 cols = [name for name, _ in pairs if name in df.columns]
212 if len(cols) <= 1:
213 continue
214 sub = df[cols].to_numpy(dtype=float)
215 # treat NaNs as -inf so they never win
216 sub = np.where(np.isfinite(sub), sub, -np.inf)
217 if sub.size == 0:
218 continue
219 idx = np.argmax(sub, axis=1)
220 new = np.zeros_like(sub)
221 new[np.arange(sub.shape[0]), idx] = 1.0
222 df.loc[:, cols] = new
223 return df
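# Illustrative example (hypothetical values): for columns "language=Linear A" / "language=Linear B",
# a row of scores [0.2, 0.7] is projected to [0.0, 1.0]; a row of [NaN, 0.1] also becomes [0.0, 1.0]
# because NaNs are treated as -inf before the argmax.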
226def _apply_categorical_constraints(df: pd.DataFrame,
227 groups: dict[str, list[tuple[str, str]]],
228 fixed_str: dict[str, str],
229 allowed_strs: dict[str, list[str]]) -> pd.DataFrame:
230 """
231 Filter rows by categorical constraints expressed on the base names, e.g.
232 fixed_str = {"language": "Linear B"}
233 allowed_strs = {"language": ["Linear A", "Linear B"]}
234 Operates on one-hot columns, so call BEFORE collapsing to string columns.
235 """
236 mask = np.ones(len(df), dtype=bool)
237 for base, val in (fixed_str or {}).items():
238 if base not in groups:
239 continue
240 cols = {label: name for name, label in groups[base] if name in df.columns}
241 want = cols.get(val)
242 if want is None:
243 # no matching one-hot column — drop all rows
244 mask &= False
245 else:
246 mask &= (df[want] >= 0.5) # after projection, exactly 1 column is 1
247 for base, vals in (allowed_strs or {}).items():
248 if base not in groups:
249 continue
250 cols = {label: name for name, label in groups[base] if name in df.columns}
251 want_cols = [cols[v] for v in vals if v in cols]
252 if want_cols:
253 mask &= (df[want_cols].sum(axis=1) >= 0.5)
254 else:
255 mask &= False
256 return df.loc[mask].reset_index(drop=True)
259def _onehot_groups(feature_names: list[str]) -> dict[str, dict]:
260 """
261 Detect one-hot groups among feature names like 'language=Linear A'.
262 Returns:
263 {
264 base: {
265 "labels": [label1, ...],
266 "members": [(feat_name, label), ...],
267 "name_by_label": {label: feat_name}
268 },
269 ...
270 }
271 """
272 groups: dict[str, dict] = {}
273 for name in feature_names:
274 m = _ONEHOT_RE.match(name)
275 if not m:
276 continue
277 base = m.group("base")
278 label = m.group("label")
279 g = groups.setdefault(base, {"labels": [], "members": [], "name_by_label": {}})
280 g["labels"].append(label)
281 g["members"].append((name, label))
282 g["name_by_label"][label] = name
283 # stable order for labels
284 for g in groups.values():
285 # keep insertion order from feature_names, but ensure uniqueness
286 seen = set()
287 uniq = []
288 for lab in g["labels"]:
289 if lab not in seen:
290 uniq.append(lab); seen.add(lab)
291 g["labels"] = uniq
292 return groups
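# Illustrative example: feature names ["lr", "language=Linear A", "language=Linear B"] yield
# {"language": {"labels": ["Linear A", "Linear B"],
#               "members": [("language=Linear A", "Linear A"), ("language=Linear B", "Linear B")],
#               "name_by_label": {"Linear A": "language=Linear A", "Linear B": "language=Linear B"}}}
# ("lr" contains no '=' so it contributes no group).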
296def _numeric_specs_only(search_specs: dict, groups: dict) -> dict:
297 """
298 Return a copy of search_specs with one-hot member feature names removed.
299 `groups` is the output of _onehot_groups(feature_names).
300 """
301 if not groups:
302 return dict(search_specs)
304 onehot_member_names = set()
305 for g in groups.values():
306 onehot_member_names.update(g["name_by_label"].values())
308 return {k: v for k, v in search_specs.items() if k not in onehot_member_names}
311def _assert_valid_onehot(df: pd.DataFrame, groups: dict[str, dict], where: str = "") -> None:
312 """
313 Assert every one-hot block has exactly one '1' per row (no NaNs).
314 Prints a small diagnostic if not.
315 """
316 for base, g in groups.items():
317 member_cols = [g["name_by_label"][lab] for lab in g["labels"] if g["name_by_label"][lab] in df.columns]
318 if not member_cols:
319 print(f"[onehot] {where}: base={base} has no member columns present")
320 continue
322 block = df[member_cols].to_numpy()
323 nonfinite_mask = ~np.isfinite(block)
324 sums = np.nan_to_num(block, nan=0.0, posinf=0.0, neginf=0.0).sum(axis=1)
326 bad = np.where(nonfinite_mask.any(axis=1) | (sums != 1))[0]
327 if bad.size:
328 print(f"[BUG onehot] {where}: base={base}, rows with invalid one-hot: {bad[:20].tolist()} (showing first 20)")
329 print("member_cols:", member_cols)
330 print(df.iloc[bad[:5]][member_cols]) # show a few bad rows
331 raise RuntimeError(f"Invalid one-hot block for base={base} at {where}")
333def _get_float_attr(obj, names, default=0.0):
334 for n in names:
335 if hasattr(obj, n):
336 v = getattr(obj, n)
337 # skip boolean flags like mean_only
338 if isinstance(v, (bool, np.bool_)):
339 continue
340 try:
341 return float(v)
342 except Exception:
343 pass
344 return float(default)
348 import itertools  # numpy, pandas and Path are already imported at the top of the module
353# ---------- small utils (reuse in this file) ----------
354def _orig_to_std(j: int, x, transforms, mu, sd):
355 arr = np.asarray(x, dtype=float)
356 if transforms[j] == "log10":
357 arr = np.where(arr <= 0, np.nan, arr)
358 arr = np.log10(arr)
359 return (arr - mu[j]) / sd[j]
361def _std_to_orig(j: int, arr, transforms, mu, sd):
362 x = np.asarray(arr, float) * sd[j] + mu[j]
363 if transforms[j] == "log10":
364 x = np.power(10.0, x)
365 return x
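# Note: _orig_to_std and _std_to_orig are inverses per feature j. For a log10-transformed feature
# with mu[j]=0 and sd[j]=1 (hypothetical values), x=100.0 maps to 2.0 in model space and back to
# 100.0; non-positive inputs map to NaN under log10 by construction.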
367def _groups_from_feature_names(feature_names: list[str]) -> dict:
368 # same grouping logic you already use elsewhere
369 groups = {}
370 for nm in feature_names:
371 if "=" in nm:
372 base, lab = nm.split("=", 1)
373 g = groups.setdefault(base, {"labels": [], "name_by_label": {}, "members": []})
374 g["labels"].append(lab)
375 g["name_by_label"][lab] = nm
376 g["members"].append(nm)
377 # Stable order
378 for b in groups:
379 labs = list(dict.fromkeys(groups[b]["labels"]))
380 groups[b]["labels"] = labs
381 groups[b]["members"] = [groups[b]["name_by_label"][lab] for lab in labs]
382 return groups
384def _pick_attr(obj, names, allow_none=False):
385 """Return the first present attribute in names without truth-testing arrays."""
386 for n in names:
387 if hasattr(obj, n):
388 v = getattr(obj, n)
389 if (v is not None) or allow_none:
390 return v
391 return None
393# ---------- analytic GP (mean + grad) ----------
394class _GPMarginal:
395 def __init__(self, Xtr, ytr, ell, eta, sigma, mean_const):
396 self.X = np.asarray(Xtr, float) # (N,p)
397 self.y = np.asarray(ytr, float) # (N,)
398 self.ell = np.asarray(ell, float) # (p,)
399 self.eta = float(eta)
400 self.sigma = float(sigma)
401 self.m = float(mean_const)
402 K = kernel_m52_ard(self.X, self.X, self.ell, self.eta)
403 K[np.diag_indices_from(K)] += self.sigma**2
404 L = np.linalg.cholesky(add_jitter(K))
405 self.L = L
406 self.alpha = solve_chol(L, (self.y - self.m))
409 # Back-compat aliases: older call sites use X_train/Xtr for X, ls for ell, and m0/mean_const for m.
410 self.X_train = self.Xtr = self.X
411 self.ls = self.ell
414 self.m0 = self.mean_const = self.m
416 def sd_at(self, x: np.ndarray, include_observation_noise: bool = True) -> float:
417 """
418 Predictive standard deviation at a single standardized point x.
419 """
420 x = np.asarray(x, float).reshape(1, -1) # (1, p)
421 Ks = kernel_m52_ard(x, self.Xtr, self.ls, self.eta) # (1, N)
422 v = solve_lower(self.L, Ks.T) # (N, 1)
423 kss = kernel_diag_m52(x, self.ls, self.eta)[0] # scalar diag K(x,x) = eta^2
424 var = float(kss - np.sum(v * v))
425 if include_observation_noise:
426 var += float(self.sigma ** 2)
427 var = max(var, 1e-12)
428 return float(np.sqrt(var))
430 def _k_and_grad(self, x):
431 """k(x, X), ∂k/∂x (p-dimensional gradient aggregated over train points)."""
432 x = np.asarray(x, float).reshape(1, -1) # (1,p)
433 X = self.X
434 ell = self.ell
435 eta = self.eta
437 # distances in lengthscale space
438 D = (x[:, None, :] - X[None, :, :]) / ell[None, None, :] # (1,N,p)
439 r2 = np.sum(D*D, axis=2) # (1,N)
440 r = np.sqrt(np.maximum(r2, 0.0)) # (1,N)
441 sqrt5_r = np.sqrt(5.0) * r
442 # kernel
443 k = (eta**2) * (1.0 + sqrt5_r + (5.0/3.0)*r2) * np.exp(-sqrt5_r) # (1,N)
445 # grad wrt x: -(5η^2/3) e^{-√5 r} (1 + √5 r) * (x - xi)/ell^2
446 # handle r=0 safely -> derivative is 0
447 coef = -(5.0 * (eta**2) / 3.0) * np.exp(-sqrt5_r) * (1.0 + sqrt5_r) # (1,N)
448 S = (x[:, None, :] - X[None, :, :]) / (ell[None, None, :]**2) # (1,N,p)
449 grad = np.sum(coef[:, :, None] * S, axis=1) # (1,p)
451 return k.ravel(), grad.ravel()
453 def mean_and_grad(self, x: np.ndarray):
454 # --- resolve training matrix
455 Xtr = _pick_attr(self, ["X_train", "Xtr", "X"])
456 if Xtr is None:
457 raise AttributeError("GPMarginal: training inputs not found (tried X_train, Xtr, X).")
458 Xtr = np.asarray(Xtr, float)
460 # --- resolve hyperparams / vectors
461 ell = _pick_attr(self, ["ell", "ls"])
462 if ell is None:
463 raise AttributeError("GPMarginal: lengthscales not found (tried ell, ls).")
464 ell = np.asarray(ell, float)
466 eta = _get_float_attr(self, ["eta"])
467 alpha = _pick_attr(self, ["alpha", "alpha_vec"])
468 if alpha is None:
469 raise AttributeError("GPMarginal: alpha not found (tried alpha, alpha_vec).")
470 alpha = np.asarray(alpha, float).ravel()
472 # mean constant (name differs across versions)
473 m0 = _get_float_attr(self, ["mean_const", "m0", "beta0", "mean_c", "mean"], default=0.0)
475 # --- mean
476 Ks = kernel_m52_ard(x[None, :], Xtr, ell, eta).ravel() # (N_train,)
477 mu = float(m0 + Ks @ alpha)
479 # --- gradient wrt x (shape (p,))
480 grad_k = _grad_k_m52_ard_wrt_x(x, Xtr, ell, eta) # (N_train, p)
481 grad_mu = grad_k.T @ alpha # (p,)
483 return mu, grad_mu
486 def mean_only(self, X):
487 Ks = kernel_m52_ard(X, self.X, self.ell, self.eta)
488 return self.m + Ks @ self.alpha
491def _grad_k_m52_ard_wrt_x(x: np.ndarray, Xtr: np.ndarray, ls: np.ndarray, eta: float) -> np.ndarray:
492 """
493 ∂k(x, Xtr_i)/∂x for Matérn 5/2 ARD.
494 Returns (N_train, p) — one row per training point.
495 """
496 x = np.asarray(x, float).reshape(1, -1) # (1, p)
497 Xtr = np.asarray(Xtr, float) # (N, p)
498 ls = np.asarray(ls, float).reshape(1, -1) # (1, p)
500 diff = x - Xtr # (N, p)
501 z = diff / ls # (N, p)
502 r = np.sqrt(np.sum(z*z, axis=1)) # (N,)
504 sr5 = np.sqrt(5.0)
505 coef = -(5.0 * (eta**2) / 3.0) * np.exp(-sr5 * r) * (1.0 + sr5 * r) # (N,)
506 grad = coef[:, None] * (diff / (ls*ls)) # (N, p)
507 return grad
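# Sanity-check sketch (an assumption of this note: run manually, not part of the module). The
# analytic gradient above can be compared against central differences of the kernel value:
#
#     def _fd_check(x, Xtr, ls, eta, eps=1e-6):
#         g = _grad_k_m52_ard_wrt_x(x, Xtr, ls, eta)           # (N, p) analytic
#         num = np.zeros_like(g)
#         for d in range(x.size):
#             e = np.zeros_like(x); e[d] = eps
#             kp = kernel_m52_ard((x + e)[None, :], Xtr, ls, eta).ravel()
#             km = kernel_m52_ard((x - e)[None, :], Xtr, ls, eta).ravel()
#             num[:, d] = (kp - km) / (2.0 * eps)               # finite-difference column d
#         return np.max(np.abs(g - num))                        # should be ~1e-6 or smaller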
510def rng_for_dataset(ds, seed=None):
511 if isinstance(seed, np.random.Generator):
512 return seed
514 # Hash something stable from the dataset; Xn_train is fine.
515 x = np.ascontiguousarray(ds["Xn_train"].values.astype(np.float64))
516 digest64 = int.from_bytes(hashlib.sha256(x.tobytes()).digest()[:8], "big") # 64 bits
518 if seed is None:
519 mixed = np.uint64(digest64) # dataset-deterministic
520 else:
521 mixed = np.uint64(seed) ^ np.uint64(digest64) # mix user seed with dataset hash
523 return np.random.default_rng(int(mixed))
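# Behaviour note: a np.random.Generator passed as `seed` is returned unchanged; with seed=None the
# stream depends only on the dataset (same Xn_train gives the same generator); an integer seed is
# XOR-mixed with the dataset hash, so rng_for_dataset(ds, 42) is reproducible per dataset but
# differs across datasets.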
525def suggest(
526 model: xr.Dataset | Path | str,
527 count: int = 10,
528 output: Path | None = None,
529 repulsion: float = 0.34, # repulsion radius & weight
530 explore: float = 0.5, # probability to optimize EI (explore)
531 success_threshold: float = 0.8,
532 softmax_temp: float | None = 0.2, # τ for EI softmax; None/0 => greedy EI
533 n_starts: int = 32,
534 max_iters: int = 200,
535 penalty_lambda: float = 1.0,
536 penalty_beta: float = 10.0,
537 direction: str | None = None, # defaults to model's
538 seed: int | np.random.Generator | None = 42,
539 **kwargs, # constraints in ORIGINAL units
540) -> pd.DataFrame:
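    """
    Propose `count` candidate configurations from a fitted artifact.

    For each allowed categorical combination, multistart L-BFGS-B either minimizes a
    feasibility-penalized posterior mean (exploit) or maximizes a success-gated expected
    improvement (explore, chosen with probability `explore`), with a Gaussian repulsion
    term so accepted points spread out. Constraints are passed in ORIGINAL units via
    **kwargs: a number fixes a feature, slice(lo, hi) bounds it, a list/tuple gives finite
    choices, and a categorical base (e.g. language="Linear B") fixes or restricts labels.
    """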
541 import math  # itertools, numpy and pandas are already imported at module scope
544 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
545 rng = rng_for_dataset(ds, seed) # dataset-aware determinism
547 # --- metadata
548 feature_names = [str(n) for n in ds["feature"].values.tolist()]
549 transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
550 mu_f = ds["feature_mean"].values.astype(float)
551 sd_f = ds["feature_std"].values.astype(float)
552 p = len(feature_names)
553 name_to_idx = {nm: j for j, nm in enumerate(feature_names)}
554 groups = _groups_from_feature_names(feature_names)
556 # --- GP heads
557 gp_s = _GPMarginal(
558 Xtr=ds["Xn_train"].values.astype(float),
559 ytr=ds["y_success"].values.astype(float),
560 ell=ds["map_success_ell"].values.astype(float),
561 eta=float(ds["map_success_eta"].values),
562 sigma=float(ds["map_success_sigma"].values),
563 mean_const=float(ds["map_success_beta0"].values),
564 )
565 cond_mean = float(ds["conditional_loss_mean"].values) if "conditional_loss_mean" in ds else 0.0
566 gp_l = _GPMarginal(
567 Xtr=ds["Xn_success_only"].values.astype(float),
568 ytr=ds["y_loss_centered"].values.astype(float),
569 ell=ds["map_loss_ell"].values.astype(float),
570 eta=float(ds["map_loss_eta"].values),
571 sigma=float(ds["map_loss_sigma"].values),
572 mean_const=float(ds["map_loss_mean_const"].values),
573 )
575 # --- direction & EI baseline
576 if direction is None:
577 direction = str(ds.attrs.get("direction", "min"))
578 flip = -1.0 if direction == "max" else 1.0
579 best_feasible = _best_feasible_observed(ds, direction)
581 # --- constraints -> bounds (std space)
582 cat_allowed: dict[str, list[str]] = {b: list(g["labels"]) for b, g in groups.items()}
583 cat_fixed: dict[str, str] = {}
584 fixed_num_std: dict[int, float] = {}
585 range_num_std: dict[int, tuple[float, float]] = {}
586 choice_num: dict[int, np.ndarray] = {}
588 def canon_key(k: str) -> str:
589 import re as _re
590 raw = str(k)
591 stripped = _re.sub(r"[^a-z0-9]+", "", raw.lower())
592 if raw in name_to_idx: return raw
593 for base in groups.keys():
594 if stripped == _re.sub(r"[^a-z0-9]+", "", base.lower()):
595 return base
596 return raw
598 for k, v in (kwargs or {}).items():
599 ck = canon_key(k)
600 if ck in groups: # categorical base
601 labels = groups[ck]["labels"]
602 if isinstance(v, str):
603 if v not in labels:
604 raise ValueError(f"Unknown category for {ck}: {v}. Choices: {labels}")
605 cat_fixed[ck] = v
606 else:
607 L = [x for x in (list(v) if isinstance(v, (list, tuple, set)) else [v]) if isinstance(x, str) and x in labels]
608 if not L:
609 raise ValueError(f"No valid categories for {ck} in {v}. Choices: {labels}")
610 if len(L) == 1:
611 cat_fixed[ck] = L[0]
612 else:
613 cat_allowed[ck] = L
614 elif ck in name_to_idx: # numeric
615 j = name_to_idx[ck]
616 if isinstance(v, range):
617 v = tuple(v)
618 if isinstance(v, slice):
619 lo = _orig_to_std(j, v.start, transforms, mu_f, sd_f)
620 hi = _orig_to_std(j, v.stop, transforms, mu_f, sd_f)
621 lo, hi = float(np.nanmin([lo, hi])), float(np.nanmax([lo, hi]))
622 range_num_std[j] = (lo, hi)
623 elif isinstance(v, (list, tuple, np.ndarray)):
624 arr = _orig_to_std(j, np.asarray(v, float), transforms, mu_f, sd_f)
625 choice_num[j] = np.asarray(arr, float)
626 else:
627 fixed_num_std[j] = float(_orig_to_std(j, float(v), transforms, mu_f, sd_f))
628 else:
629 raise ValueError(f"Unknown constraint key: {k!r}")
631 Xn = ds["Xn_train"].values.astype(float)
632 p01 = np.percentile(Xn, 1, axis=0)
633 p99 = np.percentile(Xn, 99, axis=0)
634 wide_lo = np.minimum(p01 - 1.0, -3.0) # allow outside training range
635 wide_hi = np.maximum(p99 + 1.0, 3.0)
637 bounds: list[tuple[float, float] | None] = [None]*p
638 for j in range(p):
639 if j in fixed_num_std:
640 val = fixed_num_std[j]
641 bounds[j] = (val, val)
642 elif j in range_num_std:
643 bounds[j] = range_num_std[j]
644 elif j in choice_num:
645 lo = float(np.nanmin(choice_num[j])); hi = float(np.nanmax(choice_num[j]))
646 bounds[j] = (lo, hi)
647 else:
648 bounds[j] = (float(wide_lo[j]), float(wide_hi[j]))
650 # --- helpers
651 onehot_members = {m for g in groups.values() for m in g["members"]}
652 numeric_idx = [j for j, nm in enumerate(feature_names) if nm not in onehot_members]
653 num_bounds = [bounds[j] for j in numeric_idx]
655 def apply_onehot(vec_std: np.ndarray, base: str, label: str):
656 for lab in groups[base]["labels"]:
657 member_name = groups[base]["name_by_label"][lab]
658 j = name_to_idx[member_name]
659 raw = 1.0 if lab == label else 0.0
660 vec_std[j] = _orig_to_std(j, raw, transforms, mu_f, sd_f)
662 # exploitation objective (μ + soft penalty), with gradient
663 def obj_grad_exploit(x_full: np.ndarray):
664 mu_l, g_l = gp_l.mean_and_grad(x_full)
665 mu = mu_l + cond_mean
666 mu_p, g_p = gp_s.mean_and_grad(x_full)
667 z = success_threshold - mu_p
668 sig = 1.0 / (1.0 + np.exp(-penalty_beta * z))
669 penalty = penalty_lambda * (np.log1p(np.exp(penalty_beta * z)) / penalty_beta)
670 grad_pen = - penalty_lambda * sig * g_p
671 J = flip * mu + penalty
672 gJ = flip * g_l + grad_pen
673 return float(J), gJ.astype(float)
675 # exploration objective: -EI with feasibility gate (no analytic grad)
676 def obj_scalar_explore(x_full: np.ndarray) -> float:
677 mu_l = gp_l.mean_only(x_full[None, :])[0]
678 mu = float(mu_l + cond_mean)
679 sd = float(gp_l.sd_at(x_full, include_observation_noise=True))
680 ps = float(gp_s.mean_only(x_full[None, :])[0])
681 mu_signed, best_signed = _maybe_flip_for_direction(np.array([mu]), float(best_feasible), direction)
682 gate = 1.0 / (1.0 + np.exp(-penalty_beta * (ps - success_threshold)))
683 ei = float(_expected_improvement_minimize(mu_signed, np.array([sd]), best_signed)[0]) * gate
684 return -ei # minimize
686 # numeric grad by central differences
687 def _numeric_grad(fun_scalar, x_num: np.ndarray, eps: float = 1e-4) -> np.ndarray:
688 g = np.zeros_like(x_num, dtype=float)
689 for i in range(x_num.size):
690 e = np.zeros_like(x_num); e[i] = eps
691 g[i] = (fun_scalar(x_num + e) - fun_scalar(x_num - e)) / (2.0 * eps)
692 return g
694 def sample_start():
695 x = np.zeros(p, float)
696 for j in numeric_idx:
697 lo, hi = num_bounds[numeric_idx.index(j)]
698 x[j] = lo if lo == hi else rng.uniform(lo, hi)
699 for j, choices in choice_num.items():
700 x[j] = choices[np.argmin(np.abs(choices - x[j]))]
701 for j, v in fixed_num_std.items():
702 x[j] = v
703 return x
705 # categorical combos
706 cat_bases = list(groups.keys())
707 combo_space = []
708 for b in cat_bases:
709 combo_space.append([cat_fixed[b]] if b in cat_fixed else cat_allowed[b])
710 all_label_combos = list(itertools.product(*combo_space)) if combo_space else [()]
712 # repulsion
713 rep_sigma2 = float(repulsion) ** 2
714 rep_weight = float(repulsion)
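    # Each accepted point x_k in the current combo adds a penalty
    # rep_weight * exp(-||x - x_k||^2 / (2 * rep_sigma2)) to the objective (see add_repulsion
    # below), pushing later takes away from points already accepted.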
716 def _is_dup(xa: np.ndarray, xb: np.ndarray, tol=1e-3) -> bool:
717 # allow >1 per combo when there are no numeric free dims
718 if xa.size == 0 and xb.size == 0:
719 return False
720 return bool(np.linalg.norm(xa - xb) < tol)
722 def _accept_row(template, best_xnum, labels, labels_t, accepted_combo, accepted_global, rows):
723 # compose full point
724 x_full = template.copy()
725 x_full[numeric_idx] = best_xnum
726 for j, choices in choice_num.items():
727 x_full[j] = float(choices[np.argmin(np.abs(choices - x_full[j]))])
728 for idx in numeric_idx:
729 lo, hi = num_bounds[numeric_idx.index(idx)]
730 x_full[idx] = float(np.clip(x_full[idx], lo, hi))
732 x_num_std = x_full[numeric_idx].copy()
733 # dedupe (combo + global)
734 if any(_is_dup(x_num_std, prev) for prev in accepted_combo):
735 return False
736 if any((labels_t == labt) and _is_dup(x_num_std, prev) for labt, prev in accepted_global):
737 return False
739 # accept
740 accepted_combo.append(x_num_std)
741 accepted_global.append((labels_t, x_num_std))
743 mu_l, _ = gp_l.mean_and_grad(x_full); mu = float(mu_l + cond_mean)
744 ps, _ = gp_s.mean_and_grad(x_full); ps = float(np.clip(ps, 0.0, 1.0))
745 sd_opt = float(gp_l.sd_at(x_full, include_observation_noise=True))
747 row = {
748 "pred_p_success": ps,
749 "pred_target_mean": mu,
750 "pred_target_sd": sd_opt,
751 }
752 onehot_members_local = {m for g in groups.values() for m in g["members"]}
753 for j, nm in enumerate(feature_names):
754 if nm in onehot_members_local:
755 continue
756 row[nm] = float(_std_to_orig(j, x_full[j], transforms, mu_f, sd_f))
757 for b, lab in zip(cat_bases, labels):
758 row[b] = lab
760 rows.append(row)
761 return True
763 def _optimize_take(template, accepted_combo, use_explore):
764 # inner objective in numeric subspace
765 def f_g_only_num(x_num: np.ndarray):
766 x_full = template.copy()
767 x_full[numeric_idx] = x_num
769 def add_repulsion(J: float, g_num: np.ndarray | None):
770 nonlocal accepted_combo
771 if accepted_combo and rep_sigma2 > 0.0 and rep_weight > 0.0:
772 for xk in accepted_combo:
773 d = x_num - xk
774 r2 = float(d @ d)
775 w = math.exp(-0.5 * r2 / rep_sigma2)
776 J += rep_weight * w
777 if g_num is not None:
778 g_num += rep_weight * w * (-d / rep_sigma2)
779 return J, g_num
781 if not use_explore:
782 J, g = obj_grad_exploit(x_full)
783 J, g_num = add_repulsion(J, g[numeric_idx])
784 return float(J), g_num
786 # exploration branch: -EI, numerical grad
787 def scalar_for_grad(xn: np.ndarray) -> float:
788 x_tmp = template.copy()
789 x_tmp[numeric_idx] = xn
790 J = obj_scalar_explore(x_tmp)
791 # include repulsion inside scalar for finite-diff consistency
792 if accepted_combo and rep_sigma2 > 0.0 and rep_weight > 0.0:
793 for xk in accepted_combo:
794 d = xn - xk
795 r2 = float(d @ d)
796 w = math.exp(-0.5 * r2 / rep_sigma2)
797 J += rep_weight * w
798 return float(J)
800 J = scalar_for_grad(x_num)
801 g_num = _numeric_grad(scalar_for_grad, x_num, eps=1e-4)
802 return float(J), g_num
804 # collect best from multistarts
805 from scipy.optimize import fmin_l_bfgs_b
806 starts = [sample_start()[numeric_idx] for _ in range(n_starts)]
807 if starts and starts[0].size == 0:
808 # no numeric dims free → just return a zero-length vector
809 return np.zeros((0,), float)
811 best_val = None
812 best_xnum = None
813 explore_candidates: list[tuple[np.ndarray, float]] = []
814 for x0 in starts:
815 xopt, fval, _ = fmin_l_bfgs_b(
816 func=lambda x: f_g_only_num(x),
817 x0=x0,
818 fprime=None,
819 bounds=num_bounds,
820 maxiter=max_iters,
821 )
822 fval = float(fval)
823 if not use_explore:
824 if (best_val is None) or (fval < best_val):
825 best_val = fval
826 best_xnum = xopt
827 else:
828 # candidate scored by gated EI (no repulsion in score)
829 x_tmp = template.copy()
830 x_tmp[numeric_idx] = xopt
831 gated_ei = -obj_scalar_explore(x_tmp)
832 if not any(np.linalg.norm(xopt - c[0]) < 1e-3 for c in explore_candidates):
833 explore_candidates.append((xopt, float(gated_ei)))
835 if use_explore:
836 if not explore_candidates:
837 return None
838 if softmax_temp and softmax_temp > 0.0:
839 eis = np.array([ei for _, ei in explore_candidates], dtype=float)
840 z = eis - np.max(eis)
841 probs = np.exp(z / float(softmax_temp))
842 probs = probs / probs.sum()
843 idx = rng.choice(len(explore_candidates), p=probs)
844 return explore_candidates[idx][0]
845 # greedy EI
846 idx = int(np.argmax([ei for _, ei in explore_candidates]))
847 return explore_candidates[idx][0]
848 return best_xnum
850 # ---------------- Core loop with dynamic allocation ----------------
851 rows: list[dict] = []
852 accepted_global: list[tuple[tuple[str, ...], np.ndarray]] = []
853 all_label_combos = all_label_combos or [()]
855 n_combos = max(1, len(all_label_combos))
856 for combo_idx, labels in enumerate(all_label_combos):
857 if len(rows) >= count:
858 break
859 labels_t = tuple(labels) if labels else tuple()
861 template = np.zeros(p, float)
862 for b, lab in zip(cat_bases, labels):
863 apply_onehot(template, b, lab)
865 accepted_combo: list[np.ndarray] = []
867 remain_total = count - len(rows)
868 remain_combos = max(1, n_combos - combo_idx)
869 k_each = max(1, math.ceil(remain_total / remain_combos))
871 takes = 0
872 while (takes < k_each) and (len(rows) < count):
873 use_explore = (rng.random() < float(explore))
874 best_xnum = _optimize_take(template, accepted_combo, use_explore)
875 if best_xnum is None:
876 # try the opposite mode once
877 best_xnum = _optimize_take(template, accepted_combo, not use_explore)
878 if best_xnum is None:
879 break # give up this take for this combo
880 ok = _accept_row(template, best_xnum, labels, labels_t, accepted_combo, accepted_global, rows)
881 if ok:
882 takes += 1
884 # ---------------- Refill loop if still short ----------------
885 if len(rows) < count:
886 # relax repulsion and try a few refill rounds with fresh starts
887 for _refill in range(3):
888 if len(rows) >= count:
889 break
890 rep_sigma2 *= 0.7
891 rep_weight *= 0.7
892 for combo_idx, labels in enumerate(all_label_combos):
893 if len(rows) >= count:
894 break
895 labels_t = tuple(labels) if labels else tuple()
896 template = np.zeros(p, float)
897 for b, lab in zip(cat_bases, labels):
898 apply_onehot(template, b, lab)
899 # start with an empty combo-accepted set to avoid over-repelling
900 accepted_combo = []
901 remain_total = count - len(rows)
902 remain_combos = max(1, n_combos - combo_idx)
903 k_each = max(1, math.ceil(remain_total / remain_combos))
904 takes = 0
905 while (takes < k_each) and (len(rows) < count):
906 use_explore = (rng.random() < float(explore))
907 best_xnum = _optimize_take(template, accepted_combo, use_explore)
908 if best_xnum is None:
909 best_xnum = _optimize_take(template, accepted_combo, not use_explore)
910 if best_xnum is None:
911 break
912 ok = _accept_row(template, best_xnum, labels, labels_t, accepted_combo, accepted_global, rows)
913 if ok:
914 takes += 1
916 # ---------------- Last-resort fill with random projections ----------------
917 # Only used if optimization couldn't find enough unique points but space likely allows more.
918 if len(rows) < count:
919 tries = 0
920 max_tries = max(200, 20 * (count - len(rows)))
921 while (len(rows) < count) and (tries < max_tries):
922 tries += 1
923 # random labels
924 labels = []
925 for b in cat_bases:
926 pool = [cat_fixed[b]] if b in cat_fixed else cat_allowed[b]
927 labels.append(pool[int(rng.integers(0, len(pool)))])
928 labels_t = tuple(labels) if labels else tuple()
929 # template
930 template = np.zeros(p, float)
931 for b, lab in zip(cat_bases, labels):
932 apply_onehot(template, b, lab)
933 # random numeric in bounds
934 x_num = np.zeros(len(numeric_idx), float)
935 for ii, j in enumerate(numeric_idx):
936 lo, hi = num_bounds[ii]
937 x_num[ii] = lo if lo == hi else rng.uniform(lo, hi)
938 # accept (weak dedupe by numeric subspace)
939 accepted_combo = [] # local (empty) so only global dedupe applies
940 _accept_row(template, x_num, labels, labels_t, accepted_combo, accepted_global, rows)
942 # ---------------- Assemble & rank ----------------
943 if not rows:
944 raise ValueError("No solutions produced; check constraints.")
946 df = pd.DataFrame(rows)
947 asc_mu = (direction != "max")
948 df = df.sort_values(["pred_p_success", "pred_target_mean"],
949 ascending=[False, asc_mu],
950 kind="mergesort").reset_index(drop=True)
951 # trim if we somehow overshot (shortfalls are handled by the refill loops above; keep the guard)
952 if len(df) > count:
953 df = df.head(count)
954 df["rank"] = np.arange(1, len(df) + 1)
956 if output:
957 output = Path(output)
958 output.parent.mkdir(parents=True, exist_ok=True)
959 df.to_csv(output, index=False)
961 try:
962 console.print(f"\n[bold]Top {len(df)} suggested candidates:[/]")
963 console.print(df_to_table(df)) # type: ignore[arg-type]
964 except Exception:
965 pass
966 return df
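# Hypothetical usage sketch (the file name and constraint values are illustrative, not from the report):
#
#     df = suggest("model.nc", count=5, explore=0.5, learning_rate=slice(1e-5, 1e-3),
#                  batch_size=(16, 32, 64), language="Linear B")
#
# returns a DataFrame ranked by pred_p_success and then pred_target_mean, one row per candidate.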
969def _collapse_onehot_to_categorical(df: pd.DataFrame, groups: dict[str, dict]) -> pd.DataFrame:
970 """
971 Collapse one-hot blocks (e.g. language=Linear A, language=Linear B) into a single
972 categorical column 'language'. Leaves <NA> only if a row is ambiguous (sum!=1).
973 """
974 out = df.copy()
976 for base, g in groups.items():
977 # column order must match label order
978 labels = list(g["labels"])
979 member_cols = [g["name_by_label"][lab] for lab in labels if g["name_by_label"][lab] in out.columns]
980 if not member_cols:
981 continue
983 # robust numeric block: NaN→0, float for safe sums/argmax
984 block = out[member_cols].to_numpy(dtype=float)
985 block = np.nan_to_num(block, nan=0.0, posinf=0.0, neginf=0.0)
987 row_sums = block.sum(axis=1)
988 argmax = np.argmax(block, axis=1)
990 # exactly-one-hot per row (tolerant to tiny fp wiggle)
991 valid = np.isfinite(row_sums) & (np.abs(row_sums - 1.0) <= 1e-9)
993 chosen = np.full(len(out), None, dtype=object)
994 if valid.any():
995 lab_arr = np.array(labels, dtype=object)
996 chosen[valid] = lab_arr[argmax[valid]]
998 # write the categorical column with proper alignment
999 out[base] = pd.Series(chosen, index=out.index, dtype="string")
1001 # drop the one-hot members
1002 out.drop(columns=[c for c in member_cols if c in out.columns], inplace=True)
1004 return out
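# Illustrative example: columns "language=Linear A" = [1, 0] and "language=Linear B" = [0, 1]
# collapse into a single string column language = ["Linear A", "Linear B"]; any row whose one-hot
# block does not sum to exactly 1 gets <NA>.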
1007def _inject_onehot_groups(
1008 cand_df: pd.DataFrame,
1009 groups: dict[str, dict],
1010 rng: np.random.Generator,
1011 cat_fixed_label: dict[str, str],
1012 cat_allowed: dict[str, set[str]],
1013) -> pd.DataFrame:
1014 """
1015 Ensure each one-hot block has exactly one '1' per row (or a fixed label),
1016 by initializing member columns to 0 then writing the chosen label as 1.
1017 """
1018 out = cand_df.copy()
1019 n = len(out)
1021 for base, g in groups.items():
1022 labels = g["labels"]
1023 member_cols = [g["name_by_label"][lab] for lab in labels]
1025 # Create/overwrite member columns with zeros to avoid NaNs
1026 for col in member_cols:
1027 out[col] = 0
1029 # Allowed labels for this base
1030 allowed = list(cat_allowed.get(base, set(labels)))
1031 if not allowed:
1032 allowed = labels
1034 # Choose a label per row
1035 if base in cat_fixed_label:
1036 chosen = np.full(n, cat_fixed_label[base], dtype=object)
1037 else:
1038 idx = rng.integers(0, len(allowed), size=n)
1039 chosen = np.array([allowed[i] for i in idx], dtype=object)
1041 # Set one-hot = 1 for the chosen label, keep others at 0
1042 for lab, col in zip(labels, member_cols):
1043 out.loc[chosen == lab, col] = 1
1045 # Enforce integer dtype (clean)
1046 out[member_cols] = out[member_cols].astype(int)
1048 return out
1051def _postfilter_numeric_constraints(
1052 df: pd.DataFrame,
1053 user_fixed_num: dict,
1054 user_ranges_num: dict,
1055 user_choices_num: dict,
1056) -> pd.DataFrame:
1057 """
1058 Keep rows satisfying numeric constraints (fixed / ranges / choices).
1059 Nonexistent columns are ignored.
1060 """
1061 if df.empty:
1062 return df
1064 mask = np.ones(len(df), dtype=bool)
1066 # ranges: inclusive
1067 for k, (lo, hi) in user_ranges_num.items():
1068 if k in df.columns:
1069 mask &= (df[k] >= lo) & (df[k] <= hi)
1071 # finite numeric choices
1072 for k, vals in user_choices_num.items():
1073 if k in df.columns:
1074 mask &= df[k].isin(vals)
1076 # fixed values (tolerate tiny float error)
1077 for k, val in user_fixed_num.items():
1078 if k in df.columns:
1079 col = df[k]
1080 if pd.api.types.is_integer_dtype(col.dtype):
1081 mask &= (col == int(round(val)))
1082 else:
1083 mask &= np.isfinite(col) & (np.abs(col - float(val)) <= 1e-12)
1085 return df.loc[mask].reset_index(drop=True)
1088def optimal(
1089 model: xr.Dataset | Path | str,
1090 output: Path | None = None,
1091 count: int = 10, # ignored (we always return 1)
1092 n_draws: int = 0, # ignored (mean-only optimizer)
1093 success_threshold: float = 0.8,
1094 seed: int | np.random.Generator | None = 42,
1095 **kwargs, # constraints in ORIGINAL units
1096) -> pd.DataFrame:
1097 """
1098 Best single candidate by optimizing the GP *mean* posterior under constraints,
1099 using L-BFGS-B on standardized features. Like `suggest` but returns 1 row.
1101 Objective (we minimize):
1102 J(x) = flip * μ_loss(x) + λ * softplus( threshold - p_success(x) )
1104 Notes:
1105 • Categorical bases handled by enumeration over allowed labels.
1106 • Numeric choices are projected to nearest allowed value after optimize.
1107 • `count` and `n_draws` are ignored (kept for API compatibility).
1108 """
1109 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
1111 rng = rng_for_dataset(ds, seed)
1113 # --- metadata
1114 feature_names = [str(n) for n in ds["feature"].values.tolist()]
1115 transforms = [str(t) for t in ds["feature_transform"].values.tolist()]
1116 mu_f = ds["feature_mean"].values.astype(float)
1117 sd_f = ds["feature_std"].values.astype(float)
1118 p = len(feature_names)
1119 name_to_idx = {nm: j for j, nm in enumerate(feature_names)}
1120 groups = _groups_from_feature_names(feature_names) # {base:{labels, name_by_label, members}}
1122 # --- GP heads (shared helper class)
1123 gp_s = _GPMarginal(
1124 Xtr=ds["Xn_train"].values.astype(float),
1125 ytr=ds["y_success"].values.astype(float),
1126 ell=ds["map_success_ell"].values.astype(float),
1127 eta=float(ds["map_success_eta"].values),
1128 sigma=float(ds["map_success_sigma"].values),
1129 mean_const=float(ds["map_success_beta0"].values),
1130 )
1131 cond_mean = float(ds["conditional_loss_mean"].values) if "conditional_loss_mean" in ds else 0.0
1132 gp_l = _GPMarginal(
1133 Xtr=ds["Xn_success_only"].values.astype(float),
1134 ytr=ds["y_loss_centered"].values.astype(float),
1135 ell=ds["map_loss_ell"].values.astype(float),
1136 eta=float(ds["map_loss_eta"].values),
1137 sigma=float(ds["map_loss_sigma"].values),
1138 mean_const=float(ds["map_loss_mean_const"].values),
1139 )
1141 # --- direction
1142 direction = str(ds.attrs.get("direction", "min"))
1143 flip = -1.0 if direction == "max" else 1.0
1145 # --- parse constraints (numeric vs categorical)
1146 cat_allowed: dict[str, list[str]] = {}
1147 cat_fixed: dict[str, str] = {}
1148 fixed_num_std: dict[int, float] = {}
1149 range_num_std: dict[int, tuple[float, float]] = {}
1150 choice_num: dict[int, np.ndarray] = {}
1152 # default allowed = all labels for each base
1153 for b, g in groups.items():
1154 cat_allowed[b] = list(g["labels"])
1156 import re
1157 def canon_key(k: str) -> str:
1158 raw = str(k)
1159 stripped = re.sub(r"[^a-z0-9]+", "", raw.lower())
1160 if raw in name_to_idx:
1161 return raw
1162 for base in groups.keys():
1163 if stripped == re.sub(r"[^a-z0-9]+", "", base.lower()):
1164 return base
1165 return raw
1167 for k, v in (kwargs or {}).items():
1168 ck = canon_key(k)
1169 if ck in groups:
1170 labels = groups[ck]["labels"]
1171 if isinstance(v, str):
1172 if v not in labels:
1173 raise ValueError(f"Unknown category for {ck}: {v}. Choices: {labels}")
1174 cat_fixed[ck] = v
1175 else:
1176 L = [x for x in (list(v) if isinstance(v, (list, tuple, set)) else [v])
1177 if isinstance(x, str) and x in labels]
1178 if not L:
1179 raise ValueError(f"No valid categories for {ck} in {v}. Choices: {labels}")
1180 if len(L) == 1:
1181 cat_fixed[ck] = L[0]
1182 else:
1183 cat_allowed[ck] = L
1184 elif ck in name_to_idx:
1185 j = name_to_idx[ck]
1186 if isinstance(v, slice):
1187 lo = _orig_to_std(j, v.start, transforms, mu_f, sd_f)
1188 hi = _orig_to_std(j, v.stop, transforms, mu_f, sd_f)
1189 lo, hi = float(np.nanmin([lo, hi])), float(np.nanmax([lo, hi]))
1190 range_num_std[j] = (lo, hi)
1191 elif isinstance(v, (list, tuple, np.ndarray)):
1192 arr = _orig_to_std(j, np.asarray(v, float), transforms, mu_f, sd_f)
1193 choice_num[j] = np.asarray(arr, float)
1194 else:
1195 fixed_num_std[j] = float(_orig_to_std(j, float(v), transforms, mu_f, sd_f))
1196 else:
1197 raise ValueError(f"Unknown constraint key: {k!r}")
1199 # --- numeric bounds (std space); allow outside training range
1200 Xn = ds["Xn_train"].values.astype(float)
1201 p01 = np.percentile(Xn, 1, axis=0)
1202 p99 = np.percentile(Xn, 99, axis=0)
1203 wide_lo = np.minimum(p01 - 1.0, -3.0)
1204 wide_hi = np.maximum(p99 + 1.0, 3.0)
1206 bounds: list[tuple[float, float] | None] = [None]*p
1207 for j in range(p):
1208 if j in fixed_num_std:
1209 v = fixed_num_std[j]
1210 bounds[j] = (v, v)
1211 elif j in range_num_std:
1212 bounds[j] = range_num_std[j]
1213 elif j in choice_num:
1214 lo = float(np.nanmin(choice_num[j])); hi = float(np.nanmax(choice_num[j]))
1215 bounds[j] = (lo, hi)
1216 else:
1217 bounds[j] = (float(wide_lo[j]), float(wide_hi[j]))
1219 # --- helpers
1220 def apply_onehot(vec_std: np.ndarray, base: str, label: str):
1221 for lab in groups[base]["labels"]:
1222 member_name = groups[base]["name_by_label"][lab]
1223 j = name_to_idx[member_name]
1224 raw = 1.0 if lab == label else 0.0
1225 vec_std[j] = _orig_to_std(j, raw, transforms, mu_f, sd_f)
1227 penalty_lambda = 1.0
1228 penalty_beta = 10.0
1230 def obj_grad(x_std_full: np.ndarray) -> tuple[float, np.ndarray]:
1231 # mean + grad for loss head (centered); add back cond_mean
1232 mu_l, g_l = gp_l.mean_and_grad(x_std_full)
1233 mu = mu_l + cond_mean
1234 # success head (smooth, not clipped)
1235 mu_p, g_p = gp_s.mean_and_grad(x_std_full)
1236 # softplus penalty for p<thr
1237 z = success_threshold - mu_p
1238 sig = 1.0 / (1.0 + np.exp(-penalty_beta * z))
1239 penalty = penalty_lambda * (np.log1p(np.exp(penalty_beta * z)) / penalty_beta)
1240 grad_pen = - penalty_lambda * sig * g_p
1241 J = float(flip * mu + penalty)
1242 gJ = (flip * g_l + grad_pen).astype(float)
1243 return J, gJ
1245 # exclude one-hot members from numeric optimization
1246 onehot_members = {m for g in groups.values() for m in g["members"]}
1247 numeric_idx = [j for j, nm in enumerate(feature_names) if nm not in onehot_members]
1248 num_bounds = [bounds[j] for j in numeric_idx]
1250 # uniform starts inside numeric bounds (std space)
1251 def sample_start() -> np.ndarray:
1252 x = np.zeros(p, float)
1253 for j in numeric_idx:
1254 lo, hi = num_bounds[numeric_idx.index(j)]
1255 x[j] = lo if lo == hi else rng.uniform(lo, hi)
1256 for j, choices in choice_num.items():
1257 x[j] = choices[np.argmin(np.abs(choices - x[j]))]
1258 for j, v in fixed_num_std.items():
1259 x[j] = v
1260 return x
1262 # enumerate categorical combos (fixed → single)
1263 cat_bases = list(groups.keys())
1264 combo_space = []
1265 for b in cat_bases:
1266 combo_space.append([cat_fixed[b]] if b in cat_fixed else cat_allowed[b])
1267 all_label_combos = list(itertools.product(*combo_space)) if combo_space else [()]
1269 # --- optimize (multi-start) and keep the single best over all combos
1270 from scipy.optimize import fmin_l_bfgs_b
1271 n_starts = 32
1272 max_iters = 200
1274 best_global_val: float | None = None
1275 best_global_x: np.ndarray | None = None
1276 best_global_labels: tuple[str, ...] | tuple = tuple()
1278 for labels in all_label_combos:
1279 template = np.zeros(p, float)
1280 for b, lab in zip(cat_bases, labels):
1281 apply_onehot(template, b, lab)
1283 # numeric-only wrapper
1284 def f_g_only_num(x_num: np.ndarray):
1285 x_full = template.copy()
1286 x_full[numeric_idx] = x_num
1287 J, g = obj_grad(x_full)
1288 return J, g[numeric_idx]
1290 # multi-starts
1291 starts = []
1292 for _ in range(n_starts):
1293 s = sample_start()
1294 for b, lab in zip(cat_bases, labels):
1295 apply_onehot(s, b, lab)
1296 starts.append(s[numeric_idx])
1298 # pick best for this combo
1299 best_val = None
1300 best_xnum = None
1301 for x0 in starts:
1302 xopt, fval, _ = fmin_l_bfgs_b(
1303 func=lambda x: f_g_only_num(x),
1304 x0=x0,
1305 fprime=None,
1306 bounds=num_bounds,
1307 maxiter=max_iters,
1308 )
1309 fval = float(fval)
1310 if (best_val is None) or (fval < best_val):
1311 best_val = fval
1312 best_xnum = xopt
1314 if best_xnum is None:
1315 continue
1317 # assemble full point, project choices, clip to bounds
1318 x_full = template.copy()
1319 x_full[numeric_idx] = best_xnum
1320 for j, choices in choice_num.items():
1321 x_full[j] = float(choices[np.argmin(np.abs(choices - x_full[j]))])
1322 for idx in numeric_idx:
1323 lo, hi = num_bounds[numeric_idx.index(idx)]
1324 x_full[idx] = float(np.clip(x_full[idx], lo, hi))
1326 if (best_global_val is None) or (best_val < best_global_val):
1327 best_global_val = float(best_val)
1328 best_global_x = x_full.copy()
1329 best_global_labels = tuple(labels) if labels else tuple()
1331 if best_global_x is None:
1332 raise ValueError("No feasible optimum produced; check/relax constraints.")
1334 # --- build single-row DataFrame in ORIGINAL units
1335 x_opt = best_global_x
1336 mu_l_opt, _ = gp_l.mean_and_grad(x_opt)
1337 mu_opt = float(mu_l_opt + cond_mean)
1338 p_opt, _ = gp_s.mean_and_grad(x_opt)
1339 p_opt = float(np.clip(p_opt, 0.0, 1.0))
1341 sd_opt = gp_l.sd_at(x_opt, include_observation_noise=True)
1343 onehot_members = {m for g in groups.values() for m in g["members"]}
1345 row: dict[str, object] = {
1346 "pred_p_success": p_opt,
1347 "pred_target_mean": mu_opt,
1348 "pred_target_sd": float(sd_opt),
1349 "rank": 1,
1350 }
1351 # numerics in original units (drop one-hot members)
1352 for j, nm in enumerate(feature_names):
1353 if nm in onehot_members:
1354 continue
1355 row[nm] = float(_std_to_orig(j, x_opt[j], transforms, mu_f, sd_f))
1356 # categorical base columns
1357 for b, lab in zip(cat_bases, best_global_labels):
1358 row[b] = lab
1360 df = pd.DataFrame([row])
1362 if output:
1363 output = Path(output)
1364 output.parent.mkdir(parents=True, exist_ok=True)
1365 df.to_csv(output, index=False)
1367 try:
1368 console.print(f"\n[bold]Optimal candidate (mean posterior):[/]")
1369 console.print(df_to_table(df)) # type: ignore[arg-type]
1370 except Exception:
1371 pass
1372 return df
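# Hypothetical usage sketch (the file name and the numeric constraint are illustrative):
#
#     best = optimal("model.nc", success_threshold=0.8, epochs=20,
#                    language=("Linear A", "Linear B"))
#
# returns a single-row DataFrame with pred_p_success, pred_target_mean, pred_target_sd and the
# chosen feature values in original units, plus one column per categorical base.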
1375def optimal_old(
1376 model: xr.Dataset | Path | str,
1377 output: Path | None = None,
1378 count: int = 10,
1379 n_draws: int = 0,
1380 success_threshold: float = 0.8,
1381 seed: int | np.random.Generator | None = 42,
1382 **kwargs,
1383) -> pd.DataFrame:
1384 """
1385 Rank candidates by probability of being the best feasible optimum (min/max),
1386 honoring numeric *and* categorical constraints.
1388 Constraints (original units):
1389 - number (int/float): fixed value, e.g. epochs=20
1390 - slice(lo, hi): inclusive float range, e.g. learning_rate=slice(1e-5, 1e-3)
1391 - list/tuple: finite numeric choices, e.g. batch_size=(16, 32, 64)
1392 - range(...): converted to tuple of ints (choices)
1393 - categorical base, e.g. language="Linear B" or language=("Linear A","Linear B")
1394 (use the *base* name; model stores one-hot members internally)
1395 """
1396 ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
1397 pred_success, pred_loss = _build_predictors(ds)
1399 if output:
1400 output = Path(output)
1401 output.parent.mkdir(parents=True, exist_ok=True)
1403 # --- model metadata
1404 feature_names = list(map(str, ds["feature"].values.tolist()))
1405 transforms = list(map(str, ds["feature_transform"].values.tolist()))
1406 feat_mean = ds["feature_mean"].values.astype(float)
1407 feat_std = ds["feature_std"].values.astype(float)
1409 # --- detect categorical one-hot groups from feature names
1410 groups = _onehot_groups(feature_names) # { base: {"labels":[...], "name_by_label":{label->member}, "members":[...]} }
1412 # --- infer numeric search specs from data (includes one-hot members but we’ll drop them below)
1413 specs_full = _infer_search_specs(ds, feature_names, transforms)
1415 # --- split user kwargs into numeric vs categorical constraints
1416 (groups, # same structure as above (returned for convenience)
1417 user_fixed_num, # {numeric_feature: value}
1418 user_ranges_num, # {numeric_feature: (lo, hi)}
1419 user_choices_num, # {numeric_feature: [choices]}
1420 cat_fixed_label, # {base: "Label"} (fixed single label)
1421 cat_allowed) = _split_constraints_for_numeric_and_categorical(feature_names, kwargs)
1423 # numeric fixed beats numeric ranges/choices
1424 for k in list(user_fixed_num.keys()):
1425 user_ranges_num.pop(k, None)
1426 user_choices_num.pop(k, None)
1428 # --- keep only *numeric* specs (drop one-hot members)
1429 numeric_specs = _numeric_specs_only(specs_full, groups)
1431 # apply numeric bounds/choices, normalize numeric fixed
1432 _apply_user_bounds(numeric_specs, user_ranges_num, user_choices_num)
1433 fixed_norm_num = _normalize_fixed(user_fixed_num, numeric_specs)
1435 # --- EI baseline: best feasible observed target
1436 direction = str(ds.attrs.get("direction", "min"))
1437 best_feasible = _best_feasible_observed(ds, direction)
1438 flip = -1.0 if direction == "max" else 1.0
1440 # --- sample candidate pool
1441 rng = get_rng(seed)
1442 target_pool = max(4000, count * 200) # make sure MC has enough variety
1444 def _sample_pool(n: int) -> pd.DataFrame:
1445 # sample numerics
1446 base_num = _sample_candidates(numeric_specs, n=n, rng=rng, fixed=fixed_norm_num)
1447 # inject legal one-hot blocks for categoricals
1448 with_cats = _inject_onehot_groups(base_num, groups, rng, cat_fixed_label, cat_allowed)
1449 # hard filter numerics (ranges/choices/fixed)
1450 filtered = _postfilter_numeric_constraints(with_cats, user_fixed_num, user_ranges_num, user_choices_num)
1451 return filtered
1453 cand_df = _sample_pool(target_pool)
1454 # if tight constraints reduce pool too much, try a few refills
1455 attempts = 0
1456 while len(cand_df) < max(count * 50, 1000) and attempts < 6:
1457 extra = _sample_pool(target_pool)
1458 if not extra.empty:
1459 cand_df = pd.concat([cand_df, extra], ignore_index=True).drop_duplicates()
1460 attempts += 1
1462 if cand_df.empty:
1463 raise ValueError("No candidates satisfy the provided constraints; relax the ranges or choices.")
1465 # --- predictions in model space (use full feature order incl. one-hot members)
1466 Xn_cands = _original_df_to_standardized(cand_df[feature_names], feature_names, transforms, feat_mean, feat_std)
1467 p = pred_success(Xn_cands)
1468 mu, sd = pred_loss(Xn_cands, include_observation_noise=True)
1469 sd = np.maximum(sd, 1e-12)
1471 # --- optional feasibility filter
1472 keep = p >= float(success_threshold)
1473 if not np.any(keep):
1474 keep = np.ones_like(p, dtype=bool)
1476 cand_df = cand_df.loc[keep].reset_index(drop=True)
1477 Xn_cands = Xn_cands[keep]
1478 p = p[keep]; mu = mu[keep]; sd = sd[keep]
1479 N = len(cand_df)
1480 if N == 0:
1481 raise ValueError("All sampled candidates were filtered out by success_threshold.")
1483 # --- mean-only mode when n_draws == 0
1484 if int(n_draws) <= 0:
1485 result = cand_df.copy()
1486 result["pred_p_success"] = p
1487 result["pred_target_mean"] = mu
1488 result["pred_target_sd"] = sd
1489 # keep columns for API parity
1490 result["wins"] = 0
1491 result["n_draws_effective"] = 0
1492 result["prob_best_feasible"] = 0.0
1493 result["conditioned_on"] = _pretty_conditioned_on(
1494 fixed_norm_numeric=fixed_norm_num,
1495 cat_fixed_label=cat_fixed_label,
1496 )
1498 # Direction-aware sort by μ, then lower σ, then higher p
1499 if str(ds.attrs.get("direction", "min")) == "max":
1500 sort_cols = ["pred_target_mean", "pred_target_sd", "pred_p_success"]
1501 ascending = [False, True, False]
1502 else: # "min"
1503 sort_cols = ["pred_target_mean", "pred_target_sd", "pred_p_success"]
1504 ascending = [True, True, False]
1506 result_sorted = result.sort_values(
1507 sort_cols, ascending=ascending, kind="mergesort"
1508 ).reset_index(drop=True)
1509 result_sorted["rank_prob_best"] = np.arange(1, len(result_sorted) + 1)
1511 top = result_sorted.head(count).reset_index(drop=True)
1512 # collapse one-hot → single categorical columns (e.g., 'language')
1513 top_view = _collapse_onehot_to_categorical(top, groups)
1515 if output:
1516 top_view.to_csv(output, index=False)
1518 console.print(f"\n[bold]Top {len(top_view)} optimal solutions (mean-only, n_draws=0):[/]")
1519 console.print(df_to_table(top_view))
1520 return top_view
1522 # --- Monte Carlo winner-take-all over feasible draws
1523 Z = mu[:, None] + sd[:, None] * rng.standard_normal((N, n_draws))
1524 success_mask = rng.random((N, n_draws)) < p[:, None]
1525 feasible_draw = success_mask.any(axis=0)
1526 if not feasible_draw.any():
1527 # fallback: deterministic sort (rare)
1528 result = cand_df.copy()
1529 result["pred_p_success"] = p
1530 result["pred_target_mean"] = mu
1531 result["pred_target_sd"] = sd
1532 result["prob_best_feasible"] = 0.0
1533 result["wins"] = 0
1534 result["n_draws_effective"] = 0
1535 # prettify conditioning (numeric fixed + categorical fixed)
1536 result["conditioned_on"] = _pretty_conditioned_on(
1537 fixed_norm_numeric=fixed_norm_num,
1538 cat_fixed_label=cat_fixed_label,
1539 )
1540 result_sorted = result.sort_values(
1541 ["pred_target_mean", "pred_target_sd", "pred_p_success"],
1542 ascending=[True, True, False],
1543 kind="mergesort",
1544 ).reset_index(drop=True)
1545 result_sorted["rank_prob_best"] = np.arange(1, len(result_sorted) + 1)
1546 top = result_sorted.head(count).reset_index(drop=True)
1547 # collapse one-hot → single categorical columns for output
1548 top_view = _collapse_onehot_to_categorical(top, groups)
1549 if output:
1550 top_view.to_csv(output, index=False)
1551 console.print(f"\n[bold]Top {len(top_view)} optimal solutions:[/]")
1552 console.print(df_to_table(top_view))
1553 return top_view
1555 Z_eff = flip * np.where(success_mask, Z, np.inf)
1556 Zf = Z_eff[:, feasible_draw]
1558 winner_idx = np.argmin(Zf, axis=0)
1559 counts = np.bincount(winner_idx, minlength=N)
1560 n_eff = int(feasible_draw.sum())
1561 prob_best = counts / float(n_eff)
1563 result = cand_df.copy()
1564 result["pred_p_success"] = p
1565 result["pred_target_mean"] = mu
1566 result["pred_target_sd"] = sd
1567 result["wins"] = counts
1568 result["n_draws_effective"] = n_eff
1569 result["prob_best_feasible"] = prob_best
1570 result["conditioned_on"] = _pretty_conditioned_on(
1571 fixed_norm_numeric=fixed_norm_num,
1572 cat_fixed_label=cat_fixed_label,
1573 )
1575 result_sorted = result.sort_values(
1576 ["prob_best_feasible", "pred_p_success", "pred_target_mean", "pred_target_sd"],
1577 ascending=[False, False, True, True],
1578 kind="mergesort",
1579 ).reset_index(drop=True)
1580 result_sorted["rank_prob_best"] = np.arange(1, len(result_sorted) + 1)
1582 top = result_sorted.head(count).reset_index(drop=True)
1583 # collapse one-hot → single categorical columns (e.g. 'language')
1584 top_view = _collapse_onehot_to_categorical(top, groups)
1586 if output:
1587 top_view.to_csv(output, index=False)
1589 console.print(f"\n[bold]Top {len(top_view)} optimal solutions:[/]")
1590 console.print(df_to_table(top_view))
1591 return top_view
1595# =============================================================================
1596# Predictors reconstructed from artifact (no PyMC at runtime)
1597# =============================================================================
1599def _build_predictors(ds: xr.Dataset) -> tuple[
1600 Callable[[np.ndarray], np.ndarray],
1601 Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]
1602]:
1603 """Return (predict_success_probability, predict_conditional_target)."""
1604 Xn_all = ds["Xn_train"].values.astype(float)
1605 y_success = ds["y_success"].values.astype(float) # used below to build the success-head alpha
1606 Xn_ok = ds["Xn_success_only"].values.astype(float)
1607 y_loss_centered = ds["y_loss_centered"].values.astype(float)
1609 # Success head MAP
1610 ell_s = ds["map_success_ell"].values.astype(float)
1611 eta_s = float(ds["map_success_eta"].values)
1612 sigma_s = float(ds["map_success_sigma"].values)
1613 beta0_s = float(ds["map_success_beta0"].values)
1615 # Loss head MAP
1616 ell_l = ds["map_loss_ell"].values.astype(float)
1617 eta_l = float(ds["map_loss_eta"].values)
1618 sigma_l = float(ds["map_loss_sigma"].values)
1619 mean_c = float(ds["map_loss_mean_const"].values)
1620 cond_mean = float(ds["conditional_loss_mean"].values)
1622 # Cholesky precomputations
1623 K_s = kernel_m52_ard(Xn_all, Xn_all, ell_s, eta_s) + (sigma_s**2) * np.eye(Xn_all.shape[0])
1624 L_s = np.linalg.cholesky(add_jitter(K_s))
1625 alpha_s = solve_chol(L_s, (y_success - beta0_s))
1627 K_l = kernel_m52_ard(Xn_ok, Xn_ok, ell_l, eta_l) + (sigma_l**2) * np.eye(Xn_ok.shape[0])
1628 L_l = np.linalg.cholesky(add_jitter(K_l))
1629 alpha_l = solve_chol(L_l, (y_loss_centered - mean_c))
1631 def predict_success_probability(Xn: np.ndarray) -> np.ndarray:
1632 Ks = kernel_m52_ard(Xn, Xn_all, ell_s, eta_s)
1633 mu = beta0_s + Ks @ alpha_s
1634 return np.clip(mu, 0.0, 1.0)
1636 def predict_conditional_target(Xn: np.ndarray, include_observation_noise: bool = True) -> tuple[np.ndarray, np.ndarray]:
1637 Kl = kernel_m52_ard(Xn, Xn_ok, ell_l, eta_l)
1638 mu_c = mean_c + Kl @ alpha_l
1639 mu = mu_c + cond_mean
1640 v = solve_lower(L_l, Kl.T)
1641 var = kernel_diag_m52(Xn, ell_l, eta_l) - np.sum(v * v, axis=0)
1642 var = np.maximum(var, 1e-12)
1643 if include_observation_noise:
1644 var = var + sigma_l**2
1645 sd = np.sqrt(var)
1646 return mu, sd
1648 return predict_success_probability, predict_conditional_target
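# A minimal usage sketch (not part of the original module; the helper name below is
# hypothetical): given an opened artifact Dataset `ds` and a standardized candidate
# matrix `Xn`, the two predictors returned above can be combined into a small
# summary frame. Relies on the module-level np / pd / xr imports.
def _example_predict_summary(ds: xr.Dataset, Xn: np.ndarray) -> pd.DataFrame:
    predict_p, predict_target = _build_predictors(ds)
    p = predict_p(Xn)            # clipped GP mean used as success probability
    mu, sd = predict_target(Xn)  # conditional target mean and sd (with observation noise)
    return pd.DataFrame(
        {"pred_p_success": p, "pred_target_mean": mu, "pred_target_sd": sd}
    )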
1651# =============================================================================
1652# Search space, conditioning, and featurization
1653# =============================================================================
1655def _infer_search_specs(
1656 ds: xr.Dataset,
1657 feature_names: list[str],
1658 transforms: list[str],
1659 pad_frac: float = 0.10,
1660) -> dict[str, dict]:
1661 """
1662 Build per-feature search specs from the *original-unit* columns present in the artifact.
1663 Returns dict: name -> spec, where spec is one of:
1664 {"type":"float", "lo":float, "hi":float}
1665 {"type":"int", "lo":int, "hi":int, "step":int (optional)}
1666 {"type":"choice","choices": list[int|float], "dtype":"int"|"float"}
1667 """
1668 specs: dict[str, dict] = {}
1670 df_raw = pd.DataFrame({k: ds[k].values for k in ds.data_vars if ds[k].dims == ("row",)})
1671 # prefer top-level columns if present
1672 for j, name in enumerate(feature_names):
1673 if name in df_raw.columns:
1674 vals = pd.to_numeric(pd.Series(df_raw[name]), errors="coerce").dropna().to_numpy()
1675 else:
1676 # fallback: reconstruct original units from standardized arrays if needed
1677 # (raw per-row columns are normally stored in the artifact, so this path is rarely used)
1678 try:
1679 base_vals = ds[name].values # raw per-row column, if present
1680 except KeyError:
1681 # Not stored as a data_var (e.g., one-hot feature); reconstruct from Xn_train
1682 # j is the feature index in feature_names; transforms[j] is 'identity' or 'log10'
1683 base_vals = feature_raw_from_artifact_or_reconstruct(ds, j, name, transforms[j])
1685 vals = pd.to_numeric(pd.Series(base_vals), errors="coerce").dropna().to_numpy()
1688 if vals.size == 0:
1689 # degenerate column; fall back to [0,1]
1690 specs[name] = {"type": "float", "lo": 0.0, "hi": 1.0}
1691 continue
1693 # detect integer-ish
1694 intish = np.all(np.isfinite(vals)) and np.allclose(vals, np.round(vals))
1696 # robust bounds with padding
1697 p1, p99 = np.percentile(vals, [1, 99])
1698 span = max(p99 - p1, 1e-12)
1699 lo = p1 - pad_frac * span
1700 hi = p99 + pad_frac * span
1702 if intish:
1703 lo_i = int(np.floor(lo))
1704 hi_i = int(np.ceil(hi))
1705 specs[name] = {"type": "int", "lo": lo_i, "hi": hi_i}
1706 else:
1707 specs[name] = {"type": "float", "lo": float(lo), "hi": float(hi)}
1708 return specs
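# Illustrative sketch only (hypothetical feature names): for an artifact with an
# integer "epochs" column and a float "learning_rate" column, the inferred specs
# would take roughly this shape, with bounds set by the observed 1st/99th
# percentiles plus 10% padding:
#
#   {
#       "epochs":        {"type": "int",   "lo": 4,      "hi": 22},
#       "learning_rate": {"type": "float", "lo": 7.2e-5, "hi": 0.011},
#   }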
1711def _normalize_fixed(
1712 fixed_raw: dict[str, object],
1713 specs: dict[str, dict],
1714) -> dict[str, object]:
1715 """
1716 Normalize user constraints to sanitized forms within inferred bounds.
1717 Keeps the *shape*:
1718 - number (int/float) -> fixed (clipped to [lo,hi])
1719 - slice(lo, hi) -> float range (clipped to [lo,hi])
1720 - list/tuple -> finite choices (filtered to within [lo,hi], cast to int for int specs)
1721 Returns a dict usable directly by _sample_candidates.
1722 """
1723 fixed_norm: dict[str, object] = {}
1725 for name, val in (fixed_raw or {}).items():
1726 if name not in specs:
1727 # unknown feature already warned upstream; skip silently here
1728 continue
1730 sp = specs[name]
1731 typ = sp["type"]
1733 # helper clamps
1734 def _clip_float(x: float) -> float:
1735 return float(np.clip(x, sp["lo"], sp["hi"]))
1737 def _clip_int(x: int) -> int:
1738 lo, hi = int(sp.get("lo", x)), int(sp.get("hi", x))
1739 return int(np.clip(int(round(x)), lo, hi))
1741 # numeric fixed
1742 if isinstance(val, (int, float, np.number)):
1743 if typ == "int":
1744 fixed_norm[name] = _clip_int(int(round(val)))
1745 elif typ == "choice" and sp.get("dtype") == "int":
1746 fixed_norm[name] = _clip_int(int(round(val)))
1747 else:
1748 fixed_norm[name] = _clip_float(float(val))
1749 continue
1751 # float range via slice(lo, hi)
1752 if isinstance(val, slice):
1753 lo = float(val.start)
1754 hi = float(val.stop)
1755 if lo > hi:
1756 lo, hi = hi, lo
1757 if typ in ("float", "choice") and sp.get("dtype") != "int":
1758 lo_c = _clip_float(lo); hi_c = _clip_float(hi)
1759 if lo_c > hi_c: lo_c, hi_c = hi_c, lo_c
1760 fixed_norm[name] = slice(lo_c, hi_c)
1761 else:
1762 # int spec: convert to inclusive integer tuple
1763 lo_i = _clip_int(int(np.floor(lo)))
1764 hi_i = _clip_int(int(np.ceil(hi)))
1765 choices = tuple(range(lo_i, hi_i + 1))
1766 fixed_norm[name] = choices
1767 continue
1769 # choices via list/tuple
1770 if isinstance(val, (list, tuple)):
1771 if typ in ("int",) or (typ == "choice" and sp.get("dtype") == "int"):
1772 vv = [ _clip_int(int(round(x))) for x in val ]
1773 # de-dup and sort
1774 vv = sorted(set(vv))
1775 if not vv:
1776 # fallback to center
1777 center = _clip_int(int(np.round((sp["lo"] + sp["hi"]) / 2)))
1778 vv = [center]
1779 fixed_norm[name] = tuple(vv)
1780 else:
1781 vv = [ _clip_float(float(x)) for x in val ]
1782 vv = sorted(set(vv))
1783 if not vv:
1784 center = _clip_float((sp["lo"] + sp["hi"]) / 2.0)
1785 vv = [center]
1786 # keep list/tuple shape (tuple preferred)
1787 fixed_norm[name] = tuple(vv)
1788 continue
1790 # otherwise: ignore incompatible type
1791 # (you could raise here if you prefer a hard failure)
1792 return fixed_norm
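# Illustrative sketch only (hypothetical feature names, each assumed to lie within
# its inferred bounds): with an int spec for "epochs" and "batch_size" and a float
# spec for "learning_rate", constraints are sanitized while keeping their shape:
#
#   _normalize_fixed(
#       {"epochs": slice(4, 8), "learning_rate": 3e-4, "batch_size": [16, 32, 32]},
#       specs,
#   )
#   # -> {"epochs": (4, 5, 6, 7, 8), "learning_rate": 0.0003, "batch_size": (16, 32)}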
1795def _sample_candidates(
1796 specs: dict[str, dict],
1797 n: int,
1798 rng: np.random.Generator,
1799 fixed: dict[str, object] | None = None,
1800) -> pd.DataFrame:
1801 """
1802 Sample n candidates in ORIGINAL units given search specs and optional fixed constraints.
1803 """
1804 fixed = fixed or {}
1805 cols: dict[str, np.ndarray] = {}
1807 for name, sp in specs.items():
1808 typ = sp["type"]
1810 # If fixed: honor numeric / slice / choices shape
1811 if name in fixed:
1812 val = fixed[name]
1814 # numeric: constant column
1815 if isinstance(val, (int, float, np.number)):
1816 cols[name] = np.full(n, val, dtype=float)
1818 # float range slice
1819 elif isinstance(val, slice):
1820 lo = float(val.start); hi = float(val.stop)
1821 if lo > hi: lo, hi = hi, lo
1822 cols[name] = rng.uniform(lo, hi, size=n)
1824 # choices: list/tuple -> sample from set
1825 elif isinstance(val, (list, tuple)):
1826 arr = np.array(val, dtype=float)
1827 if arr.size == 0:
1828 # fallback to center of spec
1829 if typ == "int":
1830 center = int(np.round((sp["lo"] + sp["hi"]) / 2))
1831 arr = np.array([center], dtype=float)
1832 else:
1833 center = (sp["lo"] + sp["hi"]) / 2.0
1834 arr = np.array([center], dtype=float)
1835 idx = rng.integers(0, len(arr), size=n)
1836 cols[name] = arr[idx]
1838 else:
1839 # unknown fixed type; fallback to spec sampling
1840 if typ == "choice":
1841 choices = np.asarray(sp["choices"], dtype=float)
1842 idx = rng.integers(0, len(choices), size=n)
1843 cols[name] = choices[idx]
1844 elif typ == "int":
1845 cols[name] = rng.integers(int(sp["lo"]), int(sp["hi"]) + 1, size=n).astype(float)
1846 else:
1847 cols[name] = rng.uniform(sp["lo"], sp["hi"], size=n)
1849 else:
1850 # Not fixed: sample from spec
1851 if typ == "choice":
1852 choices = np.asarray(sp["choices"], dtype=float)
1853 idx = rng.integers(0, len(choices), size=n)
1854 cols[name] = choices[idx]
1855 elif typ == "int":
1856 cols[name] = rng.integers(int(sp["lo"]), int(sp["hi"]) + 1, size=n).astype(float)
1857 else:
1858 cols[name] = rng.uniform(sp["lo"], sp["hi"], size=n)
1860 df = pd.DataFrame(cols)
1861 # ensure integer columns are ints if the spec says so (pretty output)
1862 for name, sp in specs.items():
1863 if sp["type"] == "int" or (sp["type"] == "choice" and sp.get("dtype") == "int"):
1864 df[name] = df[name].round().astype(int)
1865 return df
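# Minimal usage sketch (hypothetical helper and feature names; `fixed` would
# normally come from _normalize_fixed): draw candidates in original units with one
# feature pinned and another restricted to a range.
def _example_candidate_frame(specs: dict[str, dict], seed: int = 0) -> pd.DataFrame:
    rng = np.random.default_rng(seed)
    fixed = {"epochs": 12, "learning_rate": slice(1e-4, 1e-3)}
    return _sample_candidates(specs, n=1_000, rng=rng, fixed=fixed)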
1868def _original_df_to_standardized(
1869 df: pd.DataFrame,
1870 feature_names: list[str],
1871 transforms: list[str],
1872 feat_mean: np.ndarray,
1873 feat_std: np.ndarray,
1874) -> np.ndarray:
1875 cols = []
1876 for j, name in enumerate(feature_names):
1877 x = df[name].to_numpy().astype(float)
1878 tr = transforms[j]
1879 if tr == "log10":
1880 x = np.where(x <= 0, np.nan, x)
1881 x = np.log10(x)
1882 cols.append((x - feat_mean[j]) / feat_std[j])
1883 return np.column_stack(cols).astype(float)
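# Standardization used above: for feature j with transform t_j (identity or log10),
# the model-space value is xn_j = (t_j(x_j) - feat_mean[j]) / feat_std[j];
# non-positive inputs under log10 become NaN rather than raising.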
1886# =============================================================================
1887# Acquisition functions & utilities
1888# =============================================================================
1890def _expected_improvement_minimize(mu: np.ndarray, sd: np.ndarray, best_y: float) -> np.ndarray:
1891 sd = np.maximum(sd, 1e-12)
1892 z = (best_y - mu) / sd
1893 Phi = ndtr(z)
1894 phi = np.exp(-0.5 * z * z) / np.sqrt(2.0 * np.pi)
1895 return sd * (z * Phi + phi)
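# Closed-form expected improvement for minimization, as implemented above:
# with z = (best_y - mu) / sd,  EI = sd * (z * Phi(z) + phi(z)),
# where Phi is the standard normal CDF (ndtr) and phi the standard normal pdf.
# In the limit sd -> 0 this reduces to max(best_y - mu, 0).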
1898def _constrained_EI(mu: np.ndarray, sd: np.ndarray, p_success: np.ndarray, best_y: float,
1899 p_threshold: float = 0.8, softness: float = 0.05) -> np.ndarray:
1900 ei = _expected_improvement_minimize(mu, sd, best_y)
1901 s = 1.0 / (1.0 + np.exp(-(p_success - p_threshold) / max(softness, 1e-6)))
1902 return ei * s
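# The feasibility gate is a logistic ramp centred at p_threshold: with the default
# softness of 0.05, a candidate at p_success = 0.8 keeps 50% of its EI, at 0.9
# roughly 88%, and at 0.6 only about 2%.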
1905def _exploration_score(sd_loss: np.ndarray, p_success: np.ndarray,
1906 w_sd: float = 1.0, w_boundary: float = 0.5) -> np.ndarray:
1907 return w_sd * sd_loss + w_boundary * (p_success * (1.0 - p_success))
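# The boundary term p * (1 - p) peaks at p_success = 0.5, so exploration is steered
# toward candidates whose target is uncertain (large sd_loss) or that sit near the
# predicted feasible/infeasible boundary.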
1910def _novelty_score(Xn_cands: np.ndarray, Xn_seen: np.ndarray) -> np.ndarray:
1911 m = Xn_cands.shape[0]
1912 batch = 1024
1913 out = np.empty(m, dtype=float)
1914 for i in range(0, m, batch):
1915 sl = slice(i, min(i + batch, m))
1916 diff = Xn_cands[sl, None, :] - Xn_seen[None, :, :]
1917 d = np.linalg.norm(diff, axis=2)
1918 out[sl] = np.min(d, axis=1)
1919 return out
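# Novelty is the Euclidean distance from each candidate to its nearest
# already-evaluated point in standardized space, computed in 1024-row batches to
# bound the (batch x n_seen x d) intermediate array.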
1922def _maybe_flip_for_direction(mu: np.ndarray, best_y: float, direction: str) -> tuple[np.ndarray, float]:
1923 if direction == "max":
1924 return -mu, -best_y
1925 return mu, best_y
1928def _best_feasible_observed(ds: xr.Dataset, direction: str) -> float:
1929 y_ok = ds["y_loss_success"].values.astype(float)
1930 if y_ok.size == 0:
1931 return np.inf if direction != "max" else -np.inf
1932 if direction == "max":
1933 return float(np.nanmax(y_ok))
1934 return float(np.nanmin(y_ok))
1937def _is_number(x) -> bool:
1938 return isinstance(x, (int, float, np.integer, np.floating))
1941def _fmt_num(x) -> str:
1942 try:
1943 return f"{float(x):.6g}"
1944 except Exception:
1945 return str(x)
1948def _fixed_as_string(fixed: dict) -> str:
1949 """
1950 Human-readable constraints:
1951 - number -> k=12 or k=0.00123
1952 - slice -> k=lo:hi (inclusive; None shows as -inf/inf)
1953 - list/tuple -> k=[v1, v2, ...]
1954 - range -> k=[start, stop, step] (rare; usually normalized earlier)
1955 - other scalars (str/bool) -> k=value
1956 Keys are sorted for stability.
1957 """
1958 parts: list[str] = []
1959 for k in sorted(fixed.keys()):
1960 v = fixed[k]
1961 if isinstance(v, slice):
1962 a = "-inf" if v.start is None else _fmt_num(v.start)
1963 b = "inf" if v.stop is None else _fmt_num(v.stop)
1964 parts.append(f"{k}={a}:{b}")
1965 elif isinstance(v, range):
1966 parts.append(f"{k}=[{', '.join(_fmt_num(u) for u in (v.start, v.stop, v.step))}]")
1967 elif isinstance(v, (list, tuple, np.ndarray)):
1968 elems = ", ".join(_fmt_num(u) if _is_number(u) else str(u) for u in v)
1969 parts.append(f"{k}=[{elems}]")
1970 elif _is_number(v):
1971 parts.append(f"{k}={_fmt_num(v)}")
1972 else:
1973 # fallback for str/bool/other scalars
1974 parts.append(f"{k}={v}")
1975 return ", ".join(parts)
def _apply_user_bounds(
    specs: dict[str, dict[str, Any]],
    ranges: dict[str, tuple[float, float]],
    choices: dict[str, list[float]],
) -> None:
    """
    Mutate `specs` in place with user-provided ranges/choices, using the same
    spec keys as `_infer_search_specs` and `_sample_candidates`
    ("type", "lo", "hi", "choices", "dtype").
    """
    for name, (lo, hi) in ranges.items():
        if name not in specs:
            continue
        sp = specs[name]
        if sp.get("type") == "choice":
            # A range supplied for a choice variable converts it back to a float range
            sp["type"] = "float"
            sp.pop("dtype", None)
        sp["lo"] = float(lo)
        sp["hi"] = float(hi)
        sp.pop("choices", None)

    for name, opts in choices.items():
        if name not in specs:
            continue
        sp = specs[name]
        # Store the explicit value set as a choice spec
        sp["type"] = "choice"
        # Cast to ints (and record the dtype) when every value is integer-valued
        if all(abs(v - round(v)) < 1e-12 for v in opts):
            sp["choices"] = [int(round(v)) for v in opts]
            sp["dtype"] = "int"
        else:
            sp["choices"] = [float(v) for v in opts]
            sp["dtype"] = "float"
        # Bounds are not used for choice specs
        sp.pop("lo", None)
        sp.pop("hi", None)