#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Coverage for psyop/model.py: 74.65% (355 statements; coverage.py v7.10.6)

# Make BLAS single-threaded to avoid oversubscription / macOS crashes
import os
for _env_var in (
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "OMP_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ.setdefault(_env_var, "1")

from pathlib import Path
import base64
import pickle
import json

import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr
from rich.console import Console
from rich.table import Table

from .util import get_rng, df_to_table


def _safe_vec(ds: xr.Dataset, name: str, nF: int) -> np.ndarray:
    if name not in ds:
        return np.full(nF, np.nan)
    arr = np.asarray(ds[name].values)
    if arr.size == nF:
        return arr
    out = np.full(nF, np.nan)
    out[:min(nF, arr.size)] = arr.ravel()[:min(nF, arr.size)]
    return out
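

# Illustrative sketch (not part of the original module): _safe_vec pads or
# truncates mismatched arrays to the feature count, so diagnostics never crash
# on shape drift. The tiny dataset here is hypothetical.
def _example_safe_vec() -> None:
    ds = xr.Dataset({"v": (("feature",), np.array([1.0, 2.0]))})
    vec = _safe_vec(ds, "v", 4)
    assert vec.shape == (4,) and np.isnan(vec[2:]).all()  # padded with NaN
    assert np.isnan(_safe_vec(ds, "missing", 3)).all()  # absent name -> all NaN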


def _safe_scalar(ds: xr.Dataset, name: str) -> float:
    if name in ds:
        try:
            return float(np.asarray(ds[name].values).item())
        except Exception:
            return float(np.nan)
    if name in ds.attrs:
        try:
            return float(ds.attrs[name])
        except Exception:
            return float(np.nan)
    return float(np.nan)


def _diagnostic_feature_dataframe(ds: xr.Dataset, top_k: int = 20) -> pd.DataFrame:
    """Top-K features by |corr_loss_success| without any repeated globals."""
    f = [str(x) for x in ds["feature"].values]
    nF = len(f)

    is_oh = (
        ds["feature_is_onehot_member"].values.astype(int)
        if "feature_is_onehot_member" in ds
        else np.zeros(nF, int)
    )
    base = (
        ds["feature_onehot_base"].values
        if "feature_onehot_base" in ds
        else np.array([""] * nF, object)
    )

    df = pd.DataFrame({
        "feature": f,
        "type": np.where(is_oh == 1, "categorical(one-hot)", "numeric"),
        "onehot_base": np.where(is_oh == 1, base, ""),
        "n_unique_raw": _safe_vec(ds, "n_unique_raw", nF),
        "raw_min": _safe_vec(ds, "raw_min", nF),
        "raw_max": _safe_vec(ds, "raw_max", nF),
        "Xn_span": _safe_vec(ds, "Xn_span", nF),
        "ell_s": _safe_vec(ds, "map_success_ell", nF),
        "ell_l": _safe_vec(ds, "map_loss_ell", nF),
        "ell/span_s": _safe_vec(ds, "ell_over_span_success", nF),
        "ell/span_l": _safe_vec(ds, "ell_over_span_loss", nF),
        "corr_success": _safe_vec(ds, "corr_success", nF),
        "corr_loss_success": _safe_vec(ds, "corr_loss_success", nF),
    })
    df["|corr_loss_success|"] = np.abs(df["corr_loss_success"].astype(float))
    return (
        df.sort_values("|corr_loss_success|", ascending=False, kind="mergesort")
        .head(top_k)
        .reset_index(drop=True)
    )
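

# Usage sketch (hypothetical, minimal dataset): rank features of a loaded
# artifact by |corr_loss_success|. With only the "feature" coordinate present,
# every other column degrades gracefully to NaN via _safe_vec.
def _example_feature_diagnostics() -> None:
    ds = xr.Dataset(coords={"feature": ["lr", "layers"]})
    top = _diagnostic_feature_dataframe(ds, top_k=2)
    assert list(top["feature"]) == ["lr", "layers"]
    assert top["type"].eq("numeric").all()  # no one-hot metadata present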


def _diagnostic_global_dataframe(ds: xr.Dataset) -> pd.DataFrame:
    """Single-row DataFrame of global/model-level diagnostics."""
    # Scalars from attrs
    target = ds.attrs.get("target", "")
    direction = ds.attrs.get("direction", "")
    n_rows = int(ds.attrs.get("n_rows", 0))  # attr may be absent; int(NaN) would raise
    if "n_success_rows" in ds.attrs:
        n_success_rows = int(ds.attrs["n_success_rows"])
    elif "success_mask" in ds:
        n_success_rows = int(np.sum(np.asarray(ds["success_mask"].values)))
    else:
        n_success_rows = 0
    success_rate = float(ds.attrs.get("success_rate", np.nan))
    rng_bitgen = ds.attrs.get("rng_bitgen", "")
    numpy_version = ds.attrs.get("numpy_version", "")
    pymc_version = ds.attrs.get("pymc_version", "")

    # Scalars from data_vars (with attr fallback)
    conditional_loss_mean = _safe_scalar(ds, "conditional_loss_mean")

    # GP MAP scalars
    map_success_eta = _safe_scalar(ds, "map_success_eta")
    map_success_sigma = _safe_scalar(ds, "map_success_sigma")
    map_success_beta0 = _safe_scalar(ds, "map_success_beta0")

    map_loss_eta = _safe_scalar(ds, "map_loss_eta")
    map_loss_sigma = _safe_scalar(ds, "map_loss_sigma")
    map_loss_mean_c = _safe_scalar(ds, "map_loss_mean_const")

    # Data dispersion on successes
    y_ls_std = float(np.nanstd(ds["y_loss_success"].values)) if "y_loss_success" in ds else np.nan

    # Handy amplitude/noise ratio for the loss head
    eta_l_over_sigma_l = (
        map_loss_eta / map_loss_sigma
        if (np.isfinite(map_loss_eta) and np.isfinite(map_loss_sigma) and map_loss_sigma != 0)
        else np.nan
    )

    rows = [{
        "target": target,
        "direction": direction,
        "n_rows": n_rows,
        "n_success_rows": n_success_rows,
        "success_rate": success_rate,
        "conditional_loss_mean": conditional_loss_mean,
        "map_success_eta": map_success_eta,
        "map_success_sigma": map_success_sigma,
        "map_success_beta0": map_success_beta0,
        "map_loss_eta": map_loss_eta,
        "map_loss_sigma": map_loss_sigma,
        "map_loss_mean_const": map_loss_mean_c,
        "y_loss_success_std": y_ls_std,
        "eta_l/sigma_l": eta_l_over_sigma_l,
        "rng_bitgen": rng_bitgen,
        "numpy_version": numpy_version,
        "pymc_version": pymc_version,
    }]
    return pd.DataFrame(rows)


def _print_diagnostics_table(ds: xr.Dataset, top_k: int = 20) -> None:
    """3-sigfig Rich table with key diagnostics."""
    feat_df = _diagnostic_feature_dataframe(ds, top_k=top_k)  # honor the caller's top_k
    global_df = _diagnostic_global_dataframe(ds)

    console = Console()

    console.print("\n[bold]Model diagnostics (top by |corr_loss_success|):[/]")
    console.print(df_to_table(feat_df, transpose=False, show_index=False))  # regular table

    console.print("\n[bold]Model globals:[/]")
    # transpose for a key/value look, with magenta header column
    console.print(df_to_table(global_df, transpose=True))
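

# Usage sketch (hypothetical path; assumes an artifact previously written by
# build_model): load and pretty-print diagnostics without refitting.
#     ds = xr.load_dataset("runs/model.nc")
#     _print_diagnostics_table(ds, top_k=10)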


def build_model(
    input: pd.DataFrame | Path | str,
    target: str,
    output: Path | str | None = None,
    exclude: list[str] | str | None = None,
    direction: str = "auto",
    seed: int | np.random.Generator | None = 42,
    compress: bool = True,
    prior_model: Path | str | xr.Dataset | None = None,
) -> xr.Dataset:
    """
    Fit a two-head GP (success probability + conditional loss) and save a single
    NetCDF artifact. Also stores rich diagnostics to help debug flat PD curves
    and model wiring.
    """
    # ---------- Load ----------
    if isinstance(input, pd.DataFrame):
        df = input
    else:
        input_path = Path(input)
        if not input_path.exists():
            raise FileNotFoundError(f"Input CSV not found: {input_path.resolve()}")
        df = pd.read_csv(input_path)

    if target not in df.columns:
        raise ValueError(f"Target column '{target}' not found in CSV.")

    # Keep a raw copy for the artifact (strings as pandas 'string' so NetCDF can handle them)
    df_raw = df.copy()
    for c in df_raw.columns:
        if df_raw[c].dtype == object:
            df_raw[c] = df_raw[c].astype("string")

    # ---------- Success inference ----------
    success = (~df[target].isna()).to_numpy().astype(int)
    has_success = bool(np.any(success == 1))
    if not has_success:
        raise RuntimeError("No successful rows detected (cannot fit conditional-loss GP).")

    # ---------- Feature selection ----------
    reserved_internals = {"__success__", "__fail__", "__status__"}
    exclude = [exclude] if isinstance(exclude, str) else (exclude or [])
    excluded = set(exclude) | {target} | reserved_internals

    numeric_cols = [
        c for c in df.columns
        if c not in excluded and pd.api.types.is_numeric_dtype(df[c])
    ]
    cat_cols = [
        c for c in df.columns
        if c not in excluded
        and (pd.api.types.is_string_dtype(df[c]) or pd.api.types.is_categorical_dtype(df[c]))
    ]

    feature_names: list[str] = []
    transforms: list[str] = []
    X_raw_cols: list[np.ndarray] = []

    # Diagnostics to fill
    onehot_base_per_feature: list[str] = []  # "" for numeric / non-onehot
    is_onehot_member: list[int] = []
    categorical_groups: dict[str, dict] = {}  # base -> {"labels": [...], "members": [...]}

    # numeric features
    for name in numeric_cols:
        col = df[name].to_numpy(dtype=float)
        tr = _choose_transform(name, col)
        transforms.append(tr)
        feature_names.append(name)
        X_raw_cols.append(_apply_transform(tr, col))
        onehot_base_per_feature.append("")
        is_onehot_member.append(0)

    # categoricals → one-hot
    for base in cat_cols:
        s_cat = pd.Categorical(df[base].astype("string").fillna("<NA>"))
        H = pd.get_dummies(s_cat, prefix=base, prefix_sep="=", dtype=float)  # e.g., language=Linear A
        members = []
        labels = []
        for new_col in H.columns:
            feature_names.append(new_col)
            transforms.append("identity")
            X_raw_cols.append(H[new_col].to_numpy(dtype=float))
            onehot_base_per_feature.append(base)
            is_onehot_member.append(1)
            members.append(new_col)
            # label is the part after "base="
            labels.append(str(new_col.split("=", 1)[1]) if "=" in new_col else str(new_col))
        categorical_groups[base] = {"labels": labels, "members": members}

    X_raw = np.column_stack(X_raw_cols).astype(float)
    n, p = X_raw.shape

    prior_start_success: dict[str, np.ndarray | float] | None = None
    prior_start_loss: dict[str, np.ndarray | float] | None = None
    if prior_model is not None:
        if isinstance(prior_model, xr.Dataset):
            prior_ds = prior_model
        else:
            prior_path = Path(prior_model)
            if not prior_path.exists():
                raise FileNotFoundError(f"Prior model artifact not found: {prior_path.resolve()}")
            prior_ds = xr.load_dataset(prior_path)
        try:
            prior_feature_names = [str(v) for v in prior_ds["feature"].values.tolist()]
        except Exception:
            prior_ds = None
        else:
            if prior_feature_names != feature_names:
                prior_ds = None
        if prior_ds is not None:
            try:
                ell_s_prev = np.asarray(prior_ds["map_success_ell"].values, dtype=float)
                eta_s_prev = float(np.asarray(prior_ds["map_success_eta"].values))
                sigma_s_prev = float(np.asarray(prior_ds["map_success_sigma"].values))
                beta0_s_prev = float(np.asarray(prior_ds["map_success_beta0"].values))
                ell_l_prev = np.asarray(prior_ds["map_loss_ell"].values, dtype=float)
                eta_l_prev = float(np.asarray(prior_ds["map_loss_eta"].values))
                sigma_l_prev = float(np.asarray(prior_ds["map_loss_sigma"].values))
                mean_c_prev = float(np.asarray(prior_ds["map_loss_mean_const"].values))
                if ell_s_prev.shape == (p,) and ell_l_prev.shape == (p,):
                    prior_start_success = {
                        "ell": ell_s_prev,
                        "eta": eta_s_prev,
                        "sigma": sigma_s_prev,
                        "beta0": beta0_s_prev,
                    }
                    prior_start_loss = {
                        "ell": ell_l_prev,
                        "eta": eta_l_prev,
                        "sigma": sigma_l_prev,
                        "mean_const": mean_c_prev,
                    }
            except Exception:
                prior_start_success = None
                prior_start_loss = None

    # ---------- Standardize ----------
    X_mean = X_raw.mean(axis=0)
    X_std = X_raw.std(axis=0)
    X_std = np.where(X_std == 0.0, 1.0, X_std)  # keep inert dims harmless
    Xn = (X_raw - X_mean) / X_std

    # Targets
    y_success = success.astype(float)
    ok_mask = (success == 1)
    y_loss_success = df.loc[ok_mask, target].to_numpy(dtype=float)

    conditional_loss_mean = float(np.nanmean(y_loss_success)) if len(y_loss_success) else 0.0
    y_loss_centered = y_loss_success - conditional_loss_mean
    Xn_success_only = Xn[ok_mask, :]

    rng = get_rng(seed)
    state_bytes = pickle.dumps(rng.bit_generator.state)
    rng_state_b64 = base64.b64encode(state_bytes).decode("ascii")

    base_success_rate = float(y_success.mean())

    # ---------- Fit Head A: success ----------
    with pm.Model() as model_s:
        ell_s = pm.HalfNormal("ell", sigma=2.0, shape=p)
        eta_s = pm.HalfNormal("eta", sigma=2.0)
        sigma_s = pm.HalfNormal("sigma", sigma=0.3)
        beta0_s = pm.Normal("beta0", mu=base_success_rate, sigma=0.15)

        K_s = eta_s**2 * pm.gp.cov.Matern52(input_dim=p, ls=ell_s)
        m_s = pm.gp.mean.Constant(beta0_s)
        gp_s = pm.gp.Marginal(mean_func=m_s, cov_func=K_s)

        _ = gp_s.marginal_likelihood("y_obs_s", X=Xn, y=y_success, sigma=sigma_s)
        map_s = pm.find_MAP(start=prior_start_success) if prior_start_success is not None else pm.find_MAP()

    with model_s:
        mu_s, var_s = gp_s.predict(Xn, point=map_s, diag=True, pred_noise=True)
        mu_s = np.clip(mu_s, 0.0, 1.0)

    # ---------- Fit Head B: conditional loss (success-only) ----------
    if Xn_success_only.shape[0] == 0:
        raise RuntimeError("No successful rows to fit the conditional-loss GP.")

    with pm.Model() as model_l:
        ell_l = pm.TruncatedNormal("ell", mu=1.0, sigma=0.5, lower=0.1, shape=p)
        eta_l = pm.HalfNormal("eta", sigma=1.0)
        sigma_l = pm.HalfNormal("sigma", sigma=1.0)
        mean_c = pm.Normal("mean_const", mu=0.0, sigma=10.0)

        K_l = eta_l**2 * pm.gp.cov.Matern52(input_dim=p, ls=ell_l)
        m_l = pm.gp.mean.Constant(mean_c)
        gp_l = pm.gp.Marginal(mean_func=m_l, cov_func=K_l)

        _ = gp_l.marginal_likelihood("y_obs", X=Xn_success_only, y=y_loss_centered, sigma=sigma_l)
        map_l = pm.find_MAP(start=prior_start_loss) if prior_start_loss is not None else pm.find_MAP()

    with model_l:
        mu_l_c, var_l = gp_l.predict(Xn_success_only, point=map_l, diag=True, pred_noise=True)
        mu_l = mu_l_c + conditional_loss_mean
        sd_l = np.sqrt(var_l)

    # ---------- Diagnostics (per-feature) ----------
    # raw stats in ORIGINAL units (before any transform)
    raw_stats = {
        "raw_min": [], "raw_max": [], "raw_mean": [], "raw_std": [], "n_unique_raw": [],
    }
    for k, name in enumerate(feature_names):
        # Try to recover a raw column if present; otherwise invert transform on X_raw[:, k]
        if name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[name]):
            raw_col = df_raw[name].to_numpy(dtype=float)
        else:
            # Inverse transform of "internal" values only if it's a simple one:
            internal = X_raw[:, k]
            tr = transforms[k]
            if tr == "log10":
                raw_col = np.power(10.0, internal)
            else:
                raw_col = internal
        x = np.asarray(raw_col, dtype=float)
        x_finite = x[np.isfinite(x)]
        if x_finite.size == 0:
            x_finite = np.array([np.nan])
        raw_stats["raw_min"].append(float(np.nanmin(x_finite)))
        raw_stats["raw_max"].append(float(np.nanmax(x_finite)))
        raw_stats["raw_mean"].append(float(np.nanmean(x_finite)))
        raw_stats["raw_std"].append(float(np.nanstd(x_finite)))
        raw_stats["n_unique_raw"].append(int(np.unique(np.round(x_finite, 12)).size))

    # internal (transformed) stats PRIOR to standardization
    internal_min = np.nanmin(X_raw, axis=0)
    internal_max = np.nanmax(X_raw, axis=0)
    internal_mean = np.nanmean(X_raw, axis=0)
    internal_std = np.nanstd(X_raw, axis=0)

    # standardized 1–99% span (what PD uses)
    Xn_p01 = np.percentile(Xn, 1, axis=0)
    Xn_p99 = np.percentile(Xn, 99, axis=0)
    Xn_span = Xn_p99 - Xn_p01
    Xn_span = np.where(np.isfinite(Xn_span), Xn_span, np.nan)

    # lengthscale-to-span ratios (big → likely flat)
    ell_s_arr = _np1d(map_s["ell"], p)
    ell_l_arr = _np1d(map_l["ell"], p)
    with np.errstate(divide="ignore", invalid="ignore"):
        ell_over_span_success = np.where(Xn_span > 0, ell_s_arr / Xn_span, np.nan)
        ell_over_span_loss = np.where(Xn_span > 0, ell_l_arr / Xn_span, np.nan)

    # simple correlations (training)
    def _safe_corr(a, b):
        a = np.asarray(a, float)
        b = np.asarray(b, float)
        m = np.isfinite(a) & np.isfinite(b)
        if m.sum() < 3:
            return np.nan
        va = np.var(a[m])
        vb = np.var(b[m])
        if va == 0 or vb == 0:
            return np.nan
        return float(np.corrcoef(a[m], b[m])[0, 1])

    corr_success = np.array([_safe_corr(X_raw[:, j], y_success) for j in range(p)], dtype=float)
    corr_loss_success = np.array(
        [_safe_corr(X_raw[ok_mask, j], y_loss_success) if ok_mask.any() else np.nan for j in range(p)],
        dtype=float,
    )

    # ---------- Build xarray Dataset ----------
    feature_coord = xr.DataArray(np.array(feature_names, dtype=object), dims=("feature",), name="feature")
    row_coord = xr.DataArray(np.arange(n, dtype=np.int64), dims=("row",), name="row")
    row_ok_coord = xr.DataArray(np.where(ok_mask)[0].astype(np.int64), dims=("row_success",), name="row_success")

    # Raw columns (strings and numerics) — from df_raw
    raw_vars = {}
    for col in df_raw.columns:
        s = df_raw[col]
        if (
            pd.api.types.is_integer_dtype(s)
            or pd.api.types.is_float_dtype(s)
            or pd.api.types.is_bool_dtype(s)
        ):
            raw_vars[col] = (("row",), s.to_numpy())
        else:
            s_str = s.astype("string").fillna("<NA>")
            vals = s_str.to_numpy(dtype="U")
            raw_vars[col] = (("row",), vals)

    ds = xr.Dataset(
        data_vars={
            # Design matrices (standardized)
            "Xn_train": (("row", "feature"), Xn),
            "Xn_success_only": (("row_success", "feature"), Xn_success_only),

            # Targets
            "y_success": (("row",), y_success),
            "y_loss_success": (("row_success",), y_loss_success),
            "y_loss_centered": (("row_success",), y_loss_centered),

            # Standardization + transforms
            "feature_mean": (("feature",), X_mean),
            "feature_std": (("feature",), X_std),
            "feature_transform": (("feature",), np.array(transforms, dtype=object)),

            # Masks / indexing
            "success_mask": (("row",), ok_mask.astype(np.int8)),

            # Head A (success) MAP params
            "map_success_ell": (("feature",), ell_s_arr),
            "map_success_eta": ((), float(np.asarray(map_s["eta"]))),
            "map_success_sigma": ((), float(np.asarray(map_s["sigma"]))),
            "map_success_beta0": ((), float(np.asarray(map_s["beta0"]))),

            # Head B (loss|success) MAP params
            "map_loss_ell": (("feature",), ell_l_arr),
            "map_loss_eta": ((), float(np.asarray(map_l["eta"]))),
            "map_loss_sigma": ((), float(np.asarray(map_l["sigma"]))),
            "map_loss_mean_const": ((), float(np.asarray(map_l["mean_const"]))),

            # Convenience predictions on training data
            "pred_success_mu_train": (("row",), mu_s),
            "pred_success_var_train": (("row",), var_s),
            "pred_loss_mu_success_train": (("row_success",), mu_l),
            "pred_loss_sd_success_train": (("row_success",), sd_l),

            # Useful scalars for predictors
            "conditional_loss_mean": ((), float(conditional_loss_mean)),

            # ------- Diagnostics (per-feature) -------
            "raw_min": (("feature",), np.array(raw_stats["raw_min"], dtype=float)),
            "raw_max": (("feature",), np.array(raw_stats["raw_max"], dtype=float)),
            "raw_mean": (("feature",), np.array(raw_stats["raw_mean"], dtype=float)),
            "raw_std": (("feature",), np.array(raw_stats["raw_std"], dtype=float)),
            "n_unique_raw": (("feature",), np.array(raw_stats["n_unique_raw"], dtype=np.int32)),

            "internal_min": (("feature",), internal_min.astype(float)),
            "internal_max": (("feature",), internal_max.astype(float)),
            "internal_mean": (("feature",), internal_mean.astype(float)),
            "internal_std": (("feature",), internal_std.astype(float)),

            "Xn_p01": (("feature",), Xn_p01.astype(float)),
            "Xn_p99": (("feature",), Xn_p99.astype(float)),
            "Xn_span": (("feature",), Xn_span.astype(float)),

            "ell_over_span_success": (("feature",), ell_over_span_success.astype(float)),
            "ell_over_span_loss": (("feature",), ell_over_span_loss.astype(float)),

            "corr_success": (("feature",), corr_success),
            "corr_loss_success": (("feature",), corr_loss_success),

            "feature_is_onehot_member": (("feature",), np.array(is_onehot_member, dtype=np.int8)),
            "feature_onehot_base": (("feature",), np.array(onehot_base_per_feature, dtype=object)),
        },
        coords={
            "feature": feature_coord,
            "row": row_coord,
            "row_success": row_ok_coord,
        },
        attrs={
            "artifact_version": "0.1.2",  # bumped for diagnostics
            "target": target,
            "direction": direction,
            "rng_state": rng_state_b64,
            "rng_bitgen": rng.bit_generator.__class__.__name__,
            "numpy_version": np.__version__,
            "pymc_version": pm.__version__,
            "n_rows": int(n),
            "n_success_rows": int(ok_mask.sum()),
            "success_rate": float(base_success_rate),
            "categorical_groups_json": json.dumps(categorical_groups),  # base -> {labels, members}
        },
    )

    # Attach raw columns
    for k, v in raw_vars.items():
        ds[k] = v

    # ---------- Save ----------
    # Engine-specific encoding is chosen per backend by the helper below.
    if output:
        output = Path(output)
        output.parent.mkdir(parents=True, exist_ok=True)
        engine, encoding = _select_netcdf_engine_and_encoding(ds, compress=compress)
        ds.to_netcdf(output, engine=engine, encoding=encoding)

    _print_diagnostics_table(ds, top_k=20)

    return ds
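

# Usage sketch for build_model (hypothetical file and column names; the call
# signature is the one defined above):
#     ds = build_model("sweep.csv", target="val_loss",
#                      output="sweep_model.nc", exclude=["run_id"])
#     # -> fits both GP heads, prints diagnostics, and writes sweep_model.nc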


def kernel_diag_m52(XA: np.ndarray, ls: np.ndarray, eta: float) -> np.ndarray:
    """Diagonal of the Matern-5/2 kernel: constant eta**2 for each row of XA."""
    return np.full(XA.shape[0], eta ** 2, dtype=float)


def kernel_m52_ard(XA: np.ndarray, XB: np.ndarray, ls: np.ndarray, eta: float) -> np.ndarray:
    """Matern-5/2 kernel with per-dimension (ARD) lengthscales."""
    XA = np.asarray(XA, float)
    XB = np.asarray(XB, float)
    ls = np.asarray(ls, float).reshape(1, 1, -1)
    diff = (XA[:, None, :] - XB[None, :, :]) / ls
    r2 = np.sum(diff * diff, axis=2)
    r = np.sqrt(np.maximum(r2, 0.0))
    sqrt5_r = np.sqrt(5.0) * r
    k = (eta ** 2) * (1.0 + sqrt5_r + (5.0 / 3.0) * r2) * np.exp(-sqrt5_r)
    return k
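

# Self-check sketch (synthetic inputs): with XA == XB the Matern-5/2 matrix is
# symmetric and its diagonal equals eta**2, i.e. it matches kernel_diag_m52.
def _example_kernel_m52() -> None:
    X = np.random.default_rng(0).normal(size=(5, 3))
    K = kernel_m52_ard(X, X, ls=np.ones(3), eta=2.0)
    assert np.allclose(K, K.T)
    assert np.allclose(np.diag(K), kernel_diag_m52(X, np.ones(3), 2.0))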


def add_jitter(K: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Add a small scaled ridge to the diagonal so Cholesky stays stable."""
    jitter = eps * float(np.mean(np.diag(K)) + 1.0)
    return K + jitter * np.eye(K.shape[0], dtype=K.dtype)


def solve_chol(L: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Solve K x = b given the lower Cholesky factor L (K = L @ L.T)."""
    y = np.linalg.solve(L, b)
    return np.linalg.solve(L.T, y)
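

# Sketch of how these helpers compose into a GP posterior mean (synthetic data;
# this mirrors the Cholesky identities solve_chol implements, not a public API):
def _example_gp_posterior_mean() -> None:
    rng = np.random.default_rng(1)
    Xtr = rng.normal(size=(20, 2))
    y = np.sin(Xtr[:, 0]) + 0.05 * rng.normal(size=20)
    ls, eta, sigma = np.ones(2), 1.0, 0.1
    K = add_jitter(kernel_m52_ard(Xtr, Xtr, ls, eta)) + sigma**2 * np.eye(20)
    alpha = solve_chol(np.linalg.cholesky(K), y)  # K^{-1} y via two triangular solves
    mu = kernel_m52_ard(Xtr[:3], Xtr, ls, eta) @ alpha  # posterior mean at 3 points
    assert mu.shape == (3,)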


def solve_lower(L: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Solve the lower-triangular system L x = B."""
    return np.linalg.solve(L, B)


def feature_raw_from_artifact_or_reconstruct(
    ds: xr.Dataset,
    j: int,
    name: str,
    transform: str,
) -> np.ndarray:
    """
    Return the feature values in ORIGINAL units for each training row.
    Prefer a stored raw column (ds[name]) if present; otherwise reconstruct
    from Xn_train using feature_mean/std and the recorded transform.
    """
    # 1) Use stored raw per-row column if present
    if name in ds.data_vars:
        da = ds[name]
        if "row" in da.dims and da.sizes.get("row", None) == ds.sizes.get("row", None):
            vals = np.asarray(da.values, dtype=float)
            return vals

    # 2) Reconstruct from standardized training matrix
    Xn = ds["Xn_train"].values.astype(float)  # (N, p)
    mu = ds["feature_mean"].values.astype(float)[j]
    sd = ds["feature_std"].values.astype(float)[j]
    x_internal = Xn[:, j] * sd + mu  # internal model space
    if transform == "log10":
        raw = 10.0 ** x_internal
    else:
        raw = x_internal
    return raw
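

# Usage sketch (hypothetical artifact): recover feature j in original units,
# honoring the transform recorded at fit time.
#     j = 0
#     tr = str(ds["feature_transform"].values[j])
#     name = str(ds["feature"].values[j])
#     raw = feature_raw_from_artifact_or_reconstruct(ds, j, name, tr)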


# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------
def _to_bool01(arr: np.ndarray) -> np.ndarray:
    """Map arbitrary truthy/falsy values to a {0, 1} int array."""
    if arr.dtype == bool:
        return arr.astype(np.int32)
    if np.issubdtype(arr.dtype, np.number):
        return (arr != 0).astype(np.int32)
    truthy = {"1", "true", "yes", "y", "ok", "success"}
    return np.array([1 if (str(x).strip().lower() in truthy) else 0 for x in arr], dtype=np.int32)
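

# Illustrative check (synthetic values): strings, numbers, and bools all land
# in {0, 1}.
def _example_to_bool01() -> None:
    assert _to_bool01(np.array(["yes", "no", "OK", ""], dtype=object)).tolist() == [1, 0, 1, 0]
    assert _to_bool01(np.array([0.0, 2.5])).tolist() == [0, 1]
    assert _to_bool01(np.array([True, False])).tolist() == [1, 0]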


def _choose_transform(name: str, col: np.ndarray) -> str:
    """
    Choose a simple per-feature transform: 'identity' or 'log10'.

    Heuristics:
      - if the column name looks like a learning rate (lr/learning_rate) AND values are > 0 => log10
      - else if strictly positive and p99/p1 >= 1e3 => log10
      - else identity
    """
    name_l = name.lower()
    strictly_pos = np.all(np.isfinite(col)) and np.nanmin(col) > 0.0
    looks_lr = ("learning_rate" in name_l) or (name_l == "lr")
    if strictly_pos and (looks_lr or _large_dynamic_range(col)):
        return "log10"
    return "identity"
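

# Illustrative check (synthetic columns): a positive learning-rate column is
# sent to log10, while a narrow-range column stays on the identity transform.
def _example_choose_transform() -> None:
    assert _choose_transform("learning_rate", np.array([1e-5, 1e-4, 1e-3])) == "log10"
    assert _choose_transform("width", np.array([10.0, 11.0, 12.0])) == "identity"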


def _large_dynamic_range(col: np.ndarray) -> bool:
    """True when the finite values span at least three decades (p99/p1 >= 1e3)."""
    x = col[np.isfinite(col)]
    if x.size == 0:
        return False
    p1, p99 = np.percentile(x, [1, 99])
    p1 = max(p1, 1e-12)  # guard against division by ~0
    return (p99 / p1) >= 1e3


def _apply_transform(tr: str, col: np.ndarray) -> np.ndarray:
    if tr == "log10":
        x = np.asarray(col, dtype=float)
        x = np.where(x <= 0, np.nan, x)  # log10 undefined at <= 0
        return np.log10(x)
    return np.asarray(col, dtype=float)


def _np1d(x: np.ndarray, p: int) -> np.ndarray:
    """Coerce x to a length-p float vector; broadcast a scalar, else fill with NaN."""
    a = np.asarray(x, dtype=float).ravel()
    if a.size != p:
        a = np.full((p,), float(a.item()) if a.size == 1 else np.nan, dtype=float)
    return a


# ---- choose engine + encoding safely across backends ----
def _select_netcdf_engine_and_encoding(ds: xr.Dataset, compress: bool):
    # Prefer netcdf4
    try:
        import netCDF4  # noqa: F401

        engine = "netcdf4"
        if not compress:
            return engine, None
        enc = {}
        for name, da in ds.data_vars.items():
            if np.issubdtype(da.dtype, np.number):
                enc[name] = {"zlib": True, "complevel": 4}
        return engine, enc
    except Exception:
        pass

    # Then h5netcdf
    try:
        import h5netcdf  # noqa: F401

        engine = "h5netcdf"
        if not compress:
            return engine, None
        enc = {}
        for name, da in ds.data_vars.items():
            if np.issubdtype(da.dtype, np.number):
                enc[name] = {"compression": "gzip", "compression_opts": 4}
        return engine, enc
    except Exception:
        pass

    # Finally scipy (no compression supported)
    engine = "scipy"
    return engine, None
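

# Usage sketch: mirrors the save path inside build_model, for standalone writes.
#     engine, encoding = _select_netcdf_engine_and_encoding(ds, compress=True)
#     ds.to_netcdf("artifact.nc", engine=engine, encoding=encoding)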