Coverage for psyop/model.py: 78.09% (324 statements)

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Make BLAS single-threaded to avoid oversubscription / macOS crashes
import os
for _env_var in (
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "OMP_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ.setdefault(_env_var, "1")
from pathlib import Path
import base64
import pickle
import json

import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr
from rich.console import Console
from rich.table import Table

from .util import get_rng, df_to_table


def _safe_vec(ds: xr.Dataset, name: str, nF: int) -> np.ndarray:
    if name not in ds:
        return np.full(nF, np.nan)
    arr = np.asarray(ds[name].values)
    if arr.size == nF:
        return arr
    out = np.full(nF, np.nan)
    m = min(nF, arr.size)
    out[:m] = arr.ravel()[:m]
    return out


def _safe_scalar(ds: xr.Dataset, name: str) -> float:
    if name in ds:
        try:
            return float(np.asarray(ds[name].values).item())
        except Exception:
            return float(np.nan)
    if name in ds.attrs:
        try:
            return float(ds.attrs[name])
        except Exception:
            return float(np.nan)
    return float(np.nan)


def _diagnostic_feature_dataframe(ds: xr.Dataset, top_k: int = 20) -> pd.DataFrame:
    """Top-K features by |corr_loss_success| without any repeated globals."""
    f = [str(x) for x in ds["feature"].values]
    nF = len(f)

    is_oh = (
        ds["feature_is_onehot_member"].values.astype(int)
        if "feature_is_onehot_member" in ds
        else np.zeros(nF, int)
    )
    base = ds["feature_onehot_base"].values if "feature_onehot_base" in ds else np.array([""] * nF, object)

    df = pd.DataFrame({
        "feature": f,
        "type": np.where(is_oh == 1, "categorical(one-hot)", "numeric"),
        "onehot_base": np.where(is_oh == 1, base, ""),
        "n_unique_raw": _safe_vec(ds, "n_unique_raw", nF),
        "raw_min": _safe_vec(ds, "raw_min", nF),
        "raw_max": _safe_vec(ds, "raw_max", nF),
        "Xn_span": _safe_vec(ds, "Xn_span", nF),
        "ell_s": _safe_vec(ds, "map_success_ell", nF),
        "ell_l": _safe_vec(ds, "map_loss_ell", nF),
        "ell/span_s": _safe_vec(ds, "ell_over_span_success", nF),
        "ell/span_l": _safe_vec(ds, "ell_over_span_loss", nF),
        "corr_success": _safe_vec(ds, "corr_success", nF),
        "corr_loss_success": _safe_vec(ds, "corr_loss_success", nF),
    })
    df["|corr_loss_success|"] = np.abs(df["corr_loss_success"].astype(float))
    return (
        df.sort_values("|corr_loss_success|", ascending=False, kind="mergesort")
        .head(top_k)
        .reset_index(drop=True)
    )


def _diagnostic_global_dataframe(ds: xr.Dataset) -> pd.DataFrame:
    """Single-row DataFrame of global/model-level diagnostics."""
    # Scalars from attrs
    target = ds.attrs.get("target", "")
    direction = ds.attrs.get("direction", "")
    n_rows = int(ds.attrs.get("n_rows", 0))
    if "n_success_rows" in ds.attrs:
        n_success_rows = int(ds.attrs["n_success_rows"])
    elif "success_mask" in ds:
        n_success_rows = int(np.sum(np.asarray(ds["success_mask"].values)))
    else:
        n_success_rows = 0
    success_rate = float(ds.attrs.get("success_rate", np.nan))
    rng_bitgen = ds.attrs.get("rng_bitgen", "")
    numpy_version = ds.attrs.get("numpy_version", "")
    pymc_version = ds.attrs.get("pymc_version", "")

    # Scalars from data_vars (with attr fallback)
    conditional_loss_mean = _safe_scalar(ds, "conditional_loss_mean")

    # GP MAP scalars
    map_success_eta = _safe_scalar(ds, "map_success_eta")
    map_success_sigma = _safe_scalar(ds, "map_success_sigma")
    map_success_beta0 = _safe_scalar(ds, "map_success_beta0")

    map_loss_eta = _safe_scalar(ds, "map_loss_eta")
    map_loss_sigma = _safe_scalar(ds, "map_loss_sigma")
    map_loss_mean_c = _safe_scalar(ds, "map_loss_mean_const")

    # Data dispersion on successes
    y_ls_std = float(np.nanstd(ds["y_loss_success"].values)) if "y_loss_success" in ds else np.nan

    # Handy amplitude/noise ratio for the loss head
    eta_l_over_sigma_l = (
        map_loss_eta / map_loss_sigma
        if (np.isfinite(map_loss_eta) and np.isfinite(map_loss_sigma) and map_loss_sigma != 0)
        else np.nan
    )

    rows = [{
        "target": target,
        "direction": direction,
        "n_rows": n_rows,
        "n_success_rows": n_success_rows,
        "success_rate": success_rate,
        "conditional_loss_mean": conditional_loss_mean,
        "map_success_eta": map_success_eta,
        "map_success_sigma": map_success_sigma,
        "map_success_beta0": map_success_beta0,
        "map_loss_eta": map_loss_eta,
        "map_loss_sigma": map_loss_sigma,
        "map_loss_mean_const": map_loss_mean_c,
        "y_loss_success_std": y_ls_std,
        "eta_l/sigma_l": eta_l_over_sigma_l,
        "rng_bitgen": rng_bitgen,
        "numpy_version": numpy_version,
        "pymc_version": pymc_version,
    }]
    return pd.DataFrame(rows)


def _print_diagnostics_table(ds: xr.Dataset, top_k: int = 20) -> None:
    """3-sigfig Rich table with key diagnostics."""
    feat_df = _diagnostic_feature_dataframe(ds, top_k=top_k)
    global_df = _diagnostic_global_dataframe(ds)

    console = Console()

    console.print("\n[bold]Model diagnostics (top by |corr_loss_success|):[/]")
    console.print(df_to_table(feat_df, transpose=False, show_index=False))  # regular table

    console.print("\n[bold]Model globals:[/]")
    # transpose for a key/value look, with magenta header column
    console.print(df_to_table(global_df, transpose=True))


def build_model(
    input: pd.DataFrame | Path | str,
    target: str,
    output: Path | str | None = None,
    exclude: list[str] | str | None = None,
    direction: str = "auto",
    seed: int | np.random.Generator | None = 42,
    compress: bool = True,
) -> xr.Dataset:
    """
    Fit a two-head GP (success probability + conditional loss) and save a single
    NetCDF artifact. Also stores rich diagnostics to help debug flat PD curves
    and model wiring.
    """
    # ---------- Load ----------
    if isinstance(input, pd.DataFrame):
        df = input
    else:
        input_path = Path(input)
        if not input_path.exists():
            raise FileNotFoundError(f"Input CSV not found: {input_path.resolve()}")
        df = pd.read_csv(input_path)

    if target not in df.columns:
        raise ValueError(f"Target column '{target}' not found in CSV.")

    # Keep a raw copy for the artifact (strings as pandas 'string' so NetCDF can handle them)
    df_raw = df.copy()
    for c in df_raw.columns:
        if df_raw[c].dtype == object:
            df_raw[c] = df_raw[c].astype("string")

    # ---------- Success inference ----------
    success = (~df[target].isna()).to_numpy().astype(int)
    has_success = bool(np.any(success == 1))
    if not has_success:
        raise RuntimeError("No successful rows detected (cannot fit conditional-loss GP).")
    # ---------- Feature selection ----------
    reserved_internals = {"__success__", "__fail__", "__status__"}
    exclude = [exclude] if isinstance(exclude, str) else (exclude or [])
    excluded = set(exclude) | {target} | reserved_internals

    numeric_cols = [
        c for c in df.columns
        if c not in excluded and pd.api.types.is_numeric_dtype(df[c])
    ]
    cat_cols = [
        c for c in df.columns
        if c not in excluded
        and (pd.api.types.is_string_dtype(df[c]) or pd.api.types.is_categorical_dtype(df[c]))
    ]

    feature_names: list[str] = []
    transforms: list[str] = []
    X_raw_cols: list[np.ndarray] = []

    # Diagnostics to fill
    onehot_base_per_feature: list[str] = []  # "" for numeric / non-onehot
    is_onehot_member: list[int] = []
    categorical_groups: dict[str, dict] = {}  # base -> {"labels": [...], "members": [...]}

    # numeric features
    for name in numeric_cols:
        col = df[name].to_numpy(dtype=float)
        tr = _choose_transform(name, col)
        transforms.append(tr)
        feature_names.append(name)
        X_raw_cols.append(_apply_transform(tr, col))
        onehot_base_per_feature.append("")
        is_onehot_member.append(0)

    # categoricals → one-hot
    for base in cat_cols:
        s_cat = pd.Categorical(df[base].astype("string").fillna("<NA>"))
        H = pd.get_dummies(s_cat, prefix=base, prefix_sep="=", dtype=float)  # e.g., language=Linear A
        members = []
        labels = []
        for new_col in H.columns:
            feature_names.append(new_col)
            transforms.append("identity")
            X_raw_cols.append(H[new_col].to_numpy(dtype=float))
            onehot_base_per_feature.append(base)
            is_onehot_member.append(1)
            members.append(new_col)
            # label is the part after "base="
            labels.append(str(new_col.split("=", 1)[1]) if "=" in new_col else str(new_col))
        categorical_groups[base] = {"labels": labels, "members": members}
    X_raw = np.column_stack(X_raw_cols).astype(float)
    n, p = X_raw.shape

    # ---------- Standardize ----------
    X_mean = X_raw.mean(axis=0)
    X_std = X_raw.std(axis=0)
    X_std = np.where(X_std == 0.0, 1.0, X_std)  # keep inert dims harmless
    Xn = (X_raw - X_mean) / X_std

    # Targets
    y_success = success.astype(float)
    ok_mask = (success == 1)
    y_loss_success = df.loc[ok_mask, target].to_numpy(dtype=float)

    conditional_loss_mean = float(np.nanmean(y_loss_success)) if len(y_loss_success) else 0.0
    y_loss_centered = y_loss_success - conditional_loss_mean
    Xn_success_only = Xn[ok_mask, :]

    rng = get_rng(seed)
    state_bytes = pickle.dumps(rng.bit_generator.state)
    rng_state_b64 = base64.b64encode(state_bytes).decode("ascii")

    base_success_rate = float(y_success.mean())
    # ---------- Fit Head A: success ----------
    with pm.Model() as model_s:
        ell_s = pm.HalfNormal("ell", sigma=2.0, shape=p)
        eta_s = pm.HalfNormal("eta", sigma=2.0)
        sigma_s = pm.HalfNormal("sigma", sigma=0.3)
        beta0_s = pm.Normal("beta0", mu=base_success_rate, sigma=0.15)

        K_s = eta_s**2 * pm.gp.cov.Matern52(input_dim=p, ls=ell_s)
        m_s = pm.gp.mean.Constant(beta0_s)
        gp_s = pm.gp.Marginal(mean_func=m_s, cov_func=K_s)

        _ = gp_s.marginal_likelihood("y_obs_s", X=Xn, y=y_success, sigma=sigma_s)
        map_s = pm.find_MAP()

    with model_s:
        mu_s, var_s = gp_s.predict(Xn, point=map_s, diag=True, pred_noise=True)
        mu_s = np.clip(mu_s, 0.0, 1.0)
    # ---------- Fit Head B: conditional loss (success-only) ----------
    if Xn_success_only.shape[0] == 0:
        raise RuntimeError("No successful rows to fit the conditional-loss GP.")

    with pm.Model() as model_l:
        ell_l = pm.TruncatedNormal("ell", mu=1.0, sigma=0.5, lower=0.1, shape=p)
        eta_l = pm.HalfNormal("eta", sigma=1.0)
        sigma_l = pm.HalfNormal("sigma", sigma=1.0)
        mean_c = pm.Normal("mean_const", mu=0.0, sigma=10.0)

        K_l = eta_l**2 * pm.gp.cov.Matern52(input_dim=p, ls=ell_l)
        m_l = pm.gp.mean.Constant(mean_c)
        gp_l = pm.gp.Marginal(mean_func=m_l, cov_func=K_l)

        _ = gp_l.marginal_likelihood("y_obs", X=Xn_success_only, y=y_loss_centered, sigma=sigma_l)
        map_l = pm.find_MAP()

    with model_l:
        mu_l_c, var_l = gp_l.predict(Xn_success_only, point=map_l, diag=True, pred_noise=True)
        mu_l = mu_l_c + conditional_loss_mean
        sd_l = np.sqrt(var_l)
    # ---------- Diagnostics (per-feature) ----------
    # raw stats in ORIGINAL units (before any transform)
    raw_stats = {
        "raw_min": [], "raw_max": [], "raw_mean": [], "raw_std": [], "n_unique_raw": [],
    }
    for k, name in enumerate(feature_names):
        # Try to recover a raw column if present; otherwise invert the transform on X_raw[:, k]
        if name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[name]):
            raw_col = df_raw[name].to_numpy(dtype=float)
        else:
            # Inverse transform of "internal" values only if it's a simple one:
            internal = X_raw[:, k]
            tr = transforms[k]
            if tr == "log10":
                raw_col = np.power(10.0, internal)
            else:
                raw_col = internal
        x = np.asarray(raw_col, dtype=float)
        x_finite = x[np.isfinite(x)]
        if x_finite.size == 0:
            x_finite = np.array([np.nan])
        raw_stats["raw_min"].append(float(np.nanmin(x_finite)))
        raw_stats["raw_max"].append(float(np.nanmax(x_finite)))
        raw_stats["raw_mean"].append(float(np.nanmean(x_finite)))
        raw_stats["raw_std"].append(float(np.nanstd(x_finite)))
        raw_stats["n_unique_raw"].append(int(np.unique(np.round(x_finite, 12)).size))

    # internal (transformed) stats PRIOR to standardization
    internal_min = np.nanmin(X_raw, axis=0)
    internal_max = np.nanmax(X_raw, axis=0)
    internal_mean = np.nanmean(X_raw, axis=0)
    internal_std = np.nanstd(X_raw, axis=0)

    # standardized 1–99% span (what PD uses)
    Xn_p01 = np.percentile(Xn, 1, axis=0)
    Xn_p99 = np.percentile(Xn, 99, axis=0)
    Xn_span = Xn_p99 - Xn_p01
    Xn_span = np.where(np.isfinite(Xn_span), Xn_span, np.nan)

    # lengthscale-to-span ratios (big → likely flat)
    ell_s_arr = _np1d(map_s["ell"], p)
    ell_l_arr = _np1d(map_l["ell"], p)
    with np.errstate(divide="ignore", invalid="ignore"):
        ell_over_span_success = np.where(Xn_span > 0, ell_s_arr / Xn_span, np.nan)
        ell_over_span_loss = np.where(Xn_span > 0, ell_l_arr / Xn_span, np.nan)

    # simple correlations (training)
    def _safe_corr(a, b):
        a = np.asarray(a, float)
        b = np.asarray(b, float)
        m = np.isfinite(a) & np.isfinite(b)
        if m.sum() < 3:
            return np.nan
        va = np.var(a[m])
        vb = np.var(b[m])
        if va == 0 or vb == 0:
            return np.nan
        return float(np.corrcoef(a[m], b[m])[0, 1])

    corr_success = np.array([_safe_corr(X_raw[:, j], y_success) for j in range(p)], dtype=float)
    corr_loss_success = np.array(
        [_safe_corr(X_raw[ok_mask, j], y_loss_success) if ok_mask.any() else np.nan for j in range(p)],
        dtype=float,
    )
    # ---------- Build xarray Dataset ----------
    feature_coord = xr.DataArray(np.array(feature_names, dtype=object), dims=("feature",), name="feature")
    row_coord = xr.DataArray(np.arange(n, dtype=np.int64), dims=("row",), name="row")
    row_ok_coord = xr.DataArray(np.where(ok_mask)[0].astype(np.int64), dims=("row_success",), name="row_success")

    # Raw columns (strings and numerics) — from df_raw
    raw_vars = {}
    for col in df_raw.columns:
        s = df_raw[col]
        if (
            pd.api.types.is_integer_dtype(s)
            or pd.api.types.is_float_dtype(s)
            or pd.api.types.is_bool_dtype(s)
        ):
            raw_vars[col] = (("row",), s.to_numpy())
        else:
            s_str = s.astype("string").fillna("<NA>")
            vals = s_str.to_numpy(dtype="U")
            raw_vars[col] = (("row",), vals)
    ds = xr.Dataset(
        data_vars={
            # Design matrices (standardized)
            "Xn_train": (("row", "feature"), Xn),
            "Xn_success_only": (("row_success", "feature"), Xn_success_only),

            # Targets
            "y_success": (("row",), y_success),
            "y_loss_success": (("row_success",), y_loss_success),
            "y_loss_centered": (("row_success",), y_loss_centered),

            # Standardization + transforms
            "feature_mean": (("feature",), X_mean),
            "feature_std": (("feature",), X_std),
            "feature_transform": (("feature",), np.array(transforms, dtype=object)),

            # Masks / indexing
            "success_mask": (("row",), ok_mask.astype(np.int8)),

            # Head A (success) MAP params
            "map_success_ell": (("feature",), ell_s_arr),
            "map_success_eta": ((), float(np.asarray(map_s["eta"]))),
            "map_success_sigma": ((), float(np.asarray(map_s["sigma"]))),
            "map_success_beta0": ((), float(np.asarray(map_s["beta0"]))),

            # Head B (loss|success) MAP params
            "map_loss_ell": (("feature",), ell_l_arr),
            "map_loss_eta": ((), float(np.asarray(map_l["eta"]))),
            "map_loss_sigma": ((), float(np.asarray(map_l["sigma"]))),
            "map_loss_mean_const": ((), float(np.asarray(map_l["mean_const"]))),

            # Convenience predictions on training data
            "pred_success_mu_train": (("row",), mu_s),
            "pred_success_var_train": (("row",), var_s),
            "pred_loss_mu_success_train": (("row_success",), mu_l),
            "pred_loss_sd_success_train": (("row_success",), sd_l),

            # Useful scalars for predictors
            "conditional_loss_mean": ((), float(conditional_loss_mean)),

            # ------- Diagnostics (per-feature) -------
            "raw_min": (("feature",), np.array(raw_stats["raw_min"], dtype=float)),
            "raw_max": (("feature",), np.array(raw_stats["raw_max"], dtype=float)),
            "raw_mean": (("feature",), np.array(raw_stats["raw_mean"], dtype=float)),
            "raw_std": (("feature",), np.array(raw_stats["raw_std"], dtype=float)),
            "n_unique_raw": (("feature",), np.array(raw_stats["n_unique_raw"], dtype=np.int32)),

            "internal_min": (("feature",), internal_min.astype(float)),
            "internal_max": (("feature",), internal_max.astype(float)),
            "internal_mean": (("feature",), internal_mean.astype(float)),
            "internal_std": (("feature",), internal_std.astype(float)),

            "Xn_p01": (("feature",), Xn_p01.astype(float)),
            "Xn_p99": (("feature",), Xn_p99.astype(float)),
            "Xn_span": (("feature",), Xn_span.astype(float)),

            "ell_over_span_success": (("feature",), ell_over_span_success.astype(float)),
            "ell_over_span_loss": (("feature",), ell_over_span_loss.astype(float)),

            "corr_success": (("feature",), corr_success),
            "corr_loss_success": (("feature",), corr_loss_success),

            "feature_is_onehot_member": (("feature",), np.array(is_onehot_member, dtype=np.int8)),
            "feature_onehot_base": (("feature",), np.array(onehot_base_per_feature, dtype=object)),
        },
        coords={
            "feature": feature_coord,
            "row": row_coord,
            "row_success": row_ok_coord,
        },
        attrs={
            "artifact_version": "0.1.2",  # bumped for diagnostics
            "target": target,
            "direction": direction,
            "rng_state": rng_state_b64,
            "rng_bitgen": rng.bit_generator.__class__.__name__,
            "numpy_version": np.__version__,
            "pymc_version": pm.__version__,
            "n_rows": int(n),
            "n_success_rows": int(ok_mask.sum()),
            "success_rate": float(base_success_rate),
            "categorical_groups_json": json.dumps(categorical_groups),  # base -> {labels, members}
        },
    )
    # Attach raw columns
    for k, v in raw_vars.items():
        ds[k] = v

    # ---------- Save ----------
    # Engine-specific encoding is built in _select_netcdf_engine_and_encoding;
    # compression options differ per backend, so nothing is precomputed here.
    if output:
        output = Path(output)
        output.parent.mkdir(parents=True, exist_ok=True)
        engine, encoding = _select_netcdf_engine_and_encoding(ds, compress=compress)
        ds.to_netcdf(output, engine=engine, encoding=encoding)

    _print_diagnostics_table(ds, top_k=20)

    return ds
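

# Usage sketch (hypothetical file/column names; assumes NaN in the target column
# marks a failed run, which is exactly how `build_model` infers success above):
#
#   ds = build_model("runs.csv", target="loss", output="model.nc", seed=0)
#   ds["pred_success_mu_train"]        # P(success) MAP-GP mean on training rows
#   ds["pred_loss_mu_success_train"]   # conditional-loss mean on successful rows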


def kernel_diag_m52(XA: np.ndarray, ls: np.ndarray, eta: float) -> np.ndarray:
    """Diagonal of the Matern-5/2 kernel: k(x, x) = eta**2, independent of `ls`."""
    return np.full(XA.shape[0], eta ** 2, dtype=float)


def kernel_m52_ard(XA: np.ndarray, XB: np.ndarray, ls: np.ndarray, eta: float) -> np.ndarray:
    """ARD Matern-5/2 cross-covariance between the row sets XA and XB."""
    XA = np.asarray(XA, float)
    XB = np.asarray(XB, float)
    ls = np.asarray(ls, float).reshape(1, 1, -1)
    diff = (XA[:, None, :] - XB[None, :, :]) / ls
    r2 = np.sum(diff * diff, axis=2)
    r = np.sqrt(np.maximum(r2, 0.0))
    sqrt5_r = np.sqrt(5.0) * r
    k = (eta ** 2) * (1.0 + sqrt5_r + (5.0 / 3.0) * r2) * np.exp(-sqrt5_r)
    return k
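

# A minimal sanity-check sketch (hypothetical helper, not part of the artifact):
# on identical inputs the ARD Matern-5/2 diagonal must equal eta**2, which is
# exactly what `kernel_diag_m52` returns without building the full matrix.
def _example_kernel_diag_consistency() -> None:
    rng = np.random.default_rng(0)
    XA = rng.normal(size=(5, 3))
    K = kernel_m52_ard(XA, XA, ls=np.ones(3), eta=2.0)
    # both sides are eta**2 == 4.0 on the diagonal
    assert np.allclose(np.diag(K), kernel_diag_m52(XA, ls=np.ones(3), eta=2.0))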


def add_jitter(K: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Add a small diagonal ridge (scaled to K's mean diagonal) for Cholesky stability."""
    jitter = eps * float(np.mean(np.diag(K)) + 1.0)
    return K + jitter * np.eye(K.shape[0], dtype=K.dtype)


def solve_chol(L: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Solve K x = b given the lower Cholesky factor L of K (two triangular solves)."""
    y = np.linalg.solve(L, b)
    return np.linalg.solve(L.T, y)


def solve_lower(L: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Solve the lower-triangular system L X = B."""
    return np.linalg.solve(L, B)
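

# How these helpers compose, as a sketch: the MAP posterior mean of a zero-mean
# GP at query points, via the standard Cholesky recipe. The names `Xtr`, `y`,
# `Xq` and this function itself are illustrative, not part of the module's API.
def _example_gp_posterior_mean(
    Xtr: np.ndarray, y: np.ndarray, Xq: np.ndarray,
    ls: np.ndarray, eta: float, sigma: float,
) -> np.ndarray:
    K = kernel_m52_ard(Xtr, Xtr, ls, eta) + (sigma ** 2) * np.eye(Xtr.shape[0])
    L = np.linalg.cholesky(add_jitter(K))  # jitter guards against numerical PSD failures
    alpha = solve_chol(L, y)               # alpha = K^{-1} y via two triangular solves
    return kernel_m52_ard(Xq, Xtr, ls, eta) @ alpha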


def feature_raw_from_artifact_or_reconstruct(
    ds: xr.Dataset,
    j: int,
    name: str,
    transform: str,
) -> np.ndarray:
    """
    Return the feature values in ORIGINAL units for each training row.
    Prefer a stored raw column (ds[name]) if present; otherwise reconstruct
    from Xn_train using feature_mean/std and the recorded transform.
    """
    # 1) Use the stored raw per-row column if present
    if name in ds.data_vars:
        da = ds[name]
        if "row" in da.dims and da.sizes.get("row", None) == ds.sizes.get("row", None):
            vals = np.asarray(da.values, dtype=float)
            return vals

    # 2) Reconstruct from the standardized training matrix
    Xn = ds["Xn_train"].values.astype(float)  # (N, p)
    mu = ds["feature_mean"].values.astype(float)[j]
    sd = ds["feature_std"].values.astype(float)[j]
    x_internal = Xn[:, j] * sd + mu  # internal model space
    if transform == "log10":
        raw = 10.0 ** x_internal
    else:
        raw = x_internal
    return raw


# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------
def _to_bool01(arr: np.ndarray) -> np.ndarray:
    """Map arbitrary truthy/falsy values to a {0, 1} int array."""
    if arr.dtype == bool:
        return arr.astype(np.int32)
    if np.issubdtype(arr.dtype, np.number):
        return (arr != 0).astype(np.int32)
    truthy = {"1", "true", "yes", "y", "ok", "success"}
    return np.array([1 if (str(x).strip().lower() in truthy) else 0 for x in arr], dtype=np.int32)
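
# For example, matching the truthy set above (note "0" is a string, not a number):
#   _to_bool01(np.array(["Yes", "0", "ok"], dtype=object)) -> array([1, 0, 1], dtype=int32)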


def _choose_transform(name: str, col: np.ndarray) -> str:
    """
    Choose a simple per-feature transform: 'identity' or 'log10'.
    Heuristics:
      - if the column name looks like a learning rate (lr/learning_rate) AND values are > 0 => log10
      - else if strictly positive and p99/p1 >= 1e3 => log10
      - else identity
    """
    name_l = name.lower()
    strictly_pos = np.all(np.isfinite(col)) and np.nanmin(col) > 0.0
    looks_lr = ("learning_rate" in name_l) or (name_l == "lr")
    if strictly_pos and (looks_lr or _large_dynamic_range(col)):
        return "log10"
    return "identity"


def _large_dynamic_range(col: np.ndarray) -> bool:
    x = col[np.isfinite(col)]
    if x.size == 0:
        return False
    p1, p99 = np.percentile(x, [1, 99])
    p1 = max(p1, 1e-12)
    return (p99 / p1) >= 1e3


def _apply_transform(tr: str, col: np.ndarray) -> np.ndarray:
    if tr == "log10":
        x = np.asarray(col, dtype=float)
        x = np.where(x <= 0, np.nan, x)
        return np.log10(x)
    return np.asarray(col, dtype=float)
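

# A round-trip sketch (hypothetical values): learning-rate-like columns are sent
# through log10, and exponentiation recovers the original units.
def _example_transform_roundtrip() -> None:
    col = np.array([1e-4, 1e-3, 1e-2])
    tr = _choose_transform("learning_rate", col)  # "log10": name matches and values are positive
    internal = _apply_transform(tr, col)
    assert tr == "log10" and np.allclose(10.0 ** internal, col)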


def _np1d(x: np.ndarray, p: int) -> np.ndarray:
    """Coerce x to a length-p float vector; broadcast a scalar, else fill with NaN."""
    a = np.asarray(x, dtype=float).ravel()
    if a.size != p:
        a = np.full((p,), float(a.item()) if a.size == 1 else np.nan, dtype=float)
    return a


# ---- choose engine + encoding safely across backends ----
def _select_netcdf_engine_and_encoding(ds: xr.Dataset, compress: bool):
    # Prefer netcdf4
    try:
        import netCDF4  # noqa: F401
        engine = "netcdf4"
        if not compress:
            return engine, None
        enc = {}
        for name, da in ds.data_vars.items():
            if np.issubdtype(da.dtype, np.number):
                enc[name] = {"zlib": True, "complevel": 4}
        return engine, enc
    except Exception:
        pass

    # Then h5netcdf
    try:
        import h5netcdf  # noqa: F401
        engine = "h5netcdf"
        if not compress:
            return engine, None
        enc = {}
        for name, da in ds.data_vars.items():
            if np.issubdtype(da.dtype, np.number):
                enc[name] = {"compression": "gzip", "compression_opts": 4}
        return engine, enc
    except Exception:
        pass

    # Finally scipy (no compression supported)
    engine = "scipy"
    return engine, None
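

# Usage sketch: the save path in `build_model` does exactly this,
#   engine, enc = _select_netcdf_engine_and_encoding(ds, compress=True)
#   ds.to_netcdf("model.nc", engine=engine, encoding=enc)
# falling back from netcdf4 to h5netcdf to scipy depending on what is installed.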