Coverage for psyop/model.py: 78.09% (324 statements)
coverage.py v7.10.6, created at 2025-09-10 06:02 +0000

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Make BLAS single-threaded to avoid oversubscription / macOS crashes
import os

for _env_var in (
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "OMP_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ.setdefault(_env_var, "1")

from pathlib import Path
import base64
import pickle
import json

import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr
from rich.console import Console
from rich.table import Table

from .util import get_rng, df_to_table


def _safe_vec(ds: xr.Dataset, name: str, nF: int) -> np.ndarray:
    if name not in ds:
        return np.full(nF, np.nan)
    arr = np.asarray(ds[name].values)
    if arr.size == nF:
        return arr
    out = np.full(nF, np.nan)
    out[:min(nF, arr.size)] = arr.ravel()[:min(nF, arr.size)]
    return out


def _safe_scalar(ds: xr.Dataset, name: str) -> float:
    if name in ds:
        try:
            return float(np.asarray(ds[name].values).item())
        except Exception:
            return float(np.nan)
    if name in ds.attrs:
        try:
            return float(ds.attrs[name])
        except Exception:
            return float(np.nan)
    return float(np.nan)


def _diagnostic_feature_dataframe(ds: xr.Dataset, top_k: int = 20) -> pd.DataFrame:
    """Top-K features by |corr_loss_success| without any repeated globals."""
    f = [str(x) for x in ds["feature"].values]
    nF = len(f)

    is_oh = (
        ds["feature_is_onehot_member"].values.astype(int)
        if "feature_is_onehot_member" in ds
        else np.zeros(nF, int)
    )
    base = (
        ds["feature_onehot_base"].values
        if "feature_onehot_base" in ds
        else np.array([""] * nF, object)
    )

    df = pd.DataFrame({
        "feature": f,
        "type": np.where(is_oh == 1, "categorical(one-hot)", "numeric"),
        "onehot_base": np.where(is_oh == 1, base, ""),
        "n_unique_raw": _safe_vec(ds, "n_unique_raw", nF),
        "raw_min": _safe_vec(ds, "raw_min", nF),
        "raw_max": _safe_vec(ds, "raw_max", nF),
        "Xn_span": _safe_vec(ds, "Xn_span", nF),
        "ell_s": _safe_vec(ds, "map_success_ell", nF),
        "ell_l": _safe_vec(ds, "map_loss_ell", nF),
        "ell/span_s": _safe_vec(ds, "ell_over_span_success", nF),
        "ell/span_l": _safe_vec(ds, "ell_over_span_loss", nF),
        "corr_success": _safe_vec(ds, "corr_success", nF),
        "corr_loss_success": _safe_vec(ds, "corr_loss_success", nF),
    })
    df["|corr_loss_success|"] = np.abs(df["corr_loss_success"].astype(float))
    return (
        df.sort_values("|corr_loss_success|", ascending=False, kind="mergesort")
        .head(top_k)
        .reset_index(drop=True)
    )


def _diagnostic_global_dataframe(ds: xr.Dataset) -> pd.DataFrame:
    """Single-row DataFrame of global/model-level diagnostics."""
    # Scalars from attrs
    target = ds.attrs.get("target", "")
    direction = ds.attrs.get("direction", "")
    n_rows = int(ds.attrs.get("n_rows", 0))  # default 0 when unknown; int(np.nan) would raise
    if "n_success_rows" in ds.attrs:
        n_success_rows = int(ds.attrs["n_success_rows"])
    elif "success_mask" in ds:
        n_success_rows = int(np.sum(np.asarray(ds["success_mask"].values)))
    else:
        n_success_rows = 0
    success_rate = float(ds.attrs.get("success_rate", np.nan))
    rng_bitgen = ds.attrs.get("rng_bitgen", "")
    numpy_version = ds.attrs.get("numpy_version", "")
    pymc_version = ds.attrs.get("pymc_version", "")

    # Scalars from data_vars (with attr fallback)
    conditional_loss_mean = _safe_scalar(ds, "conditional_loss_mean")

    # GP MAP scalars
    map_success_eta = _safe_scalar(ds, "map_success_eta")
    map_success_sigma = _safe_scalar(ds, "map_success_sigma")
    map_success_beta0 = _safe_scalar(ds, "map_success_beta0")

    map_loss_eta = _safe_scalar(ds, "map_loss_eta")
    map_loss_sigma = _safe_scalar(ds, "map_loss_sigma")
    map_loss_mean_c = _safe_scalar(ds, "map_loss_mean_const")

    # Data dispersion on successes
    y_ls_std = float(np.nanstd(ds["y_loss_success"].values)) if "y_loss_success" in ds else np.nan

    # Handy amplitude/noise ratio for the loss head
    eta_l_over_sigma_l = (
        map_loss_eta / map_loss_sigma
        if (np.isfinite(map_loss_eta) and np.isfinite(map_loss_sigma) and map_loss_sigma != 0)
        else np.nan
    )

    rows = [{
        "target": target,
        "direction": direction,
        "n_rows": n_rows,
        "n_success_rows": n_success_rows,
        "success_rate": success_rate,
        "conditional_loss_mean": conditional_loss_mean,
        "map_success_eta": map_success_eta,
        "map_success_sigma": map_success_sigma,
        "map_success_beta0": map_success_beta0,
        "map_loss_eta": map_loss_eta,
        "map_loss_sigma": map_loss_sigma,
        "map_loss_mean_const": map_loss_mean_c,
        "y_loss_success_std": y_ls_std,
        "eta_l/sigma_l": eta_l_over_sigma_l,
        "rng_bitgen": rng_bitgen,
        "numpy_version": numpy_version,
        "pymc_version": pymc_version,
    }]
    return pd.DataFrame(rows)


def _print_diagnostics_table(ds: xr.Dataset, top_k: int = 20) -> None:
    """3-sigfig Rich table with key diagnostics."""
    feat_df = _diagnostic_feature_dataframe(ds, top_k=top_k)
    global_df = _diagnostic_global_dataframe(ds)

    console = Console()

    console.print("\n[bold]Model diagnostics (top by |corr_loss_success|):[/]")
    console.print(df_to_table(feat_df, transpose=False, show_index=False))  # regular table

    console.print("\n[bold]Model globals:[/]")
    # transpose for key/value look, with magenta header column
    console.print(df_to_table(global_df, transpose=True))


def build_model(
    input: pd.DataFrame | Path | str,
    target: str,
    output: Path | str | None = None,
    exclude: list[str] | str | None = None,
    direction: str = "auto",
    seed: int | np.random.Generator | None = 42,
    compress: bool = True,
) -> xr.Dataset:
    """
    Fit a two-head GP (success probability + conditional loss) and save a single
    NetCDF artifact. Also stores rich diagnostics to help debug flat PD curves
    and model wiring.
    """
    # ---------- Load ----------
    if isinstance(input, pd.DataFrame):
        df = input
    else:
        input_path = Path(input)
        if not input_path.exists():
            raise FileNotFoundError(f"Input CSV not found: {input_path.resolve()}")
        df = pd.read_csv(input_path)

    if target not in df.columns:
        raise ValueError(f"Target column '{target}' not found in CSV.")

    # Keep a raw copy for the artifact (strings as pandas 'string' so NetCDF can handle them)
    df_raw = df.copy()
    for c in df_raw.columns:
        if df_raw[c].dtype == object:
            df_raw[c] = df_raw[c].astype("string")

    # ---------- Success inference ----------
    success = (~df[target].isna()).to_numpy().astype(int)
    has_success = bool(np.any(success == 1))
    if not has_success:
        raise RuntimeError("No successful rows detected (cannot fit conditional-loss GP).")

    # ---------- Feature selection ----------
    reserved_internals = {"__success__", "__fail__", "__status__"}
    exclude = [exclude] if isinstance(exclude, str) else (exclude or [])
    excluded = set(exclude) | {target} | reserved_internals

    numeric_cols = [
        c for c in df.columns
        if c not in excluded and pd.api.types.is_numeric_dtype(df[c])
    ]
    cat_cols = [
        c for c in df.columns
        if c not in excluded
        and (pd.api.types.is_string_dtype(df[c]) or pd.api.types.is_categorical_dtype(df[c]))
    ]

    feature_names: list[str] = []
    transforms: list[str] = []
    X_raw_cols: list[np.ndarray] = []

    # Diagnostics to fill
    onehot_base_per_feature: list[str] = []   # "" for numeric / non-onehot
    is_onehot_member: list[int] = []
    categorical_groups: dict[str, dict] = {}  # base -> {"labels": [...], "members": [...]}

    # Numeric features
    for name in numeric_cols:
        col = df[name].to_numpy(dtype=float)
        tr = _choose_transform(name, col)
        transforms.append(tr)
        feature_names.append(name)
        X_raw_cols.append(_apply_transform(tr, col))
        onehot_base_per_feature.append("")
        is_onehot_member.append(0)

    # Categoricals -> one-hot
    for base in cat_cols:
        s_cat = pd.Categorical(df[base].astype("string").fillna("<NA>"))
        H = pd.get_dummies(s_cat, prefix=base, prefix_sep="=", dtype=float)  # e.g., language=Linear A
        members = []
        labels = []
        for new_col in H.columns:
            feature_names.append(new_col)
            transforms.append("identity")
            X_raw_cols.append(H[new_col].to_numpy(dtype=float))
            onehot_base_per_feature.append(base)
            is_onehot_member.append(1)
            members.append(new_col)
            # label is the part after "base="
            labels.append(str(new_col.split("=", 1)[1]) if "=" in new_col else str(new_col))
        categorical_groups[base] = {"labels": labels, "members": members}

    X_raw = np.column_stack(X_raw_cols).astype(float)
    n, p = X_raw.shape

    # ---------- Standardize ----------
    X_mean = X_raw.mean(axis=0)
    X_std = X_raw.std(axis=0)
    X_std = np.where(X_std == 0.0, 1.0, X_std)  # keep inert dims harmless
    Xn = (X_raw - X_mean) / X_std

    # Targets
    y_success = success.astype(float)
    ok_mask = (success == 1)
    y_loss_success = df.loc[ok_mask, target].to_numpy(dtype=float)

    conditional_loss_mean = float(np.nanmean(y_loss_success)) if len(y_loss_success) else 0.0
    y_loss_centered = y_loss_success - conditional_loss_mean
    Xn_success_only = Xn[ok_mask, :]

    rng = get_rng(seed)
    state_bytes = pickle.dumps(rng.bit_generator.state)
    rng_state_b64 = base64.b64encode(state_bytes).decode("ascii")

    base_success_rate = float(y_success.mean())

    # ---------- Fit Head A: success ----------
    with pm.Model() as model_s:
        ell_s = pm.HalfNormal("ell", sigma=2.0, shape=p)
        eta_s = pm.HalfNormal("eta", sigma=2.0)
        sigma_s = pm.HalfNormal("sigma", sigma=0.3)
        beta0_s = pm.Normal("beta0", mu=base_success_rate, sigma=0.15)

        K_s = eta_s**2 * pm.gp.cov.Matern52(input_dim=p, ls=ell_s)
        m_s = pm.gp.mean.Constant(beta0_s)
        gp_s = pm.gp.Marginal(mean_func=m_s, cov_func=K_s)

        _ = gp_s.marginal_likelihood("y_obs_s", X=Xn, y=y_success, sigma=sigma_s)
        map_s = pm.find_MAP()

    with model_s:
        mu_s, var_s = gp_s.predict(Xn, point=map_s, diag=True, pred_noise=True)
        mu_s = np.clip(mu_s, 0.0, 1.0)

    # ---------- Fit Head B: conditional loss (success-only) ----------
    if Xn_success_only.shape[0] == 0:
        raise RuntimeError("No successful rows to fit the conditional-loss GP.")

    with pm.Model() as model_l:
        ell_l = pm.TruncatedNormal("ell", mu=1.0, sigma=0.5, lower=0.1, shape=p)
        eta_l = pm.HalfNormal("eta", sigma=1.0)
        sigma_l = pm.HalfNormal("sigma", sigma=1.0)
        mean_c = pm.Normal("mean_const", mu=0.0, sigma=10.0)

        K_l = eta_l**2 * pm.gp.cov.Matern52(input_dim=p, ls=ell_l)
        m_l = pm.gp.mean.Constant(mean_c)
        gp_l = pm.gp.Marginal(mean_func=m_l, cov_func=K_l)

        _ = gp_l.marginal_likelihood("y_obs", X=Xn_success_only, y=y_loss_centered, sigma=sigma_l)
        map_l = pm.find_MAP()

    with model_l:
        mu_l_c, var_l = gp_l.predict(Xn_success_only, point=map_l, diag=True, pred_noise=True)
        mu_l = mu_l_c + conditional_loss_mean
        sd_l = np.sqrt(var_l)

    # ---------- Diagnostics (per-feature) ----------
    # raw stats in ORIGINAL units (before any transform)
    raw_stats = {
        "raw_min": [], "raw_max": [], "raw_mean": [], "raw_std": [], "n_unique_raw": [],
    }
    for k, name in enumerate(feature_names):
        # Try to recover a raw column if present; otherwise invert the transform on X_raw[:, k]
        if name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[name]):
            raw_col = df_raw[name].to_numpy(dtype=float)
        else:
            # Inverse-transform the "internal" values only if it's a simple one:
            internal = X_raw[:, k]
            tr = transforms[k]
            if tr == "log10":
                raw_col = np.power(10.0, internal)
            else:
                raw_col = internal
        x = np.asarray(raw_col, dtype=float)
        x_finite = x[np.isfinite(x)]
        if x_finite.size == 0:
            x_finite = np.array([np.nan])
        raw_stats["raw_min"].append(float(np.nanmin(x_finite)))
        raw_stats["raw_max"].append(float(np.nanmax(x_finite)))
        raw_stats["raw_mean"].append(float(np.nanmean(x_finite)))
        raw_stats["raw_std"].append(float(np.nanstd(x_finite)))
        raw_stats["n_unique_raw"].append(int(np.unique(np.round(x_finite, 12)).size))

    # internal (transformed) stats PRIOR to standardization
    internal_min = np.nanmin(X_raw, axis=0)
    internal_max = np.nanmax(X_raw, axis=0)
    internal_mean = np.nanmean(X_raw, axis=0)
    internal_std = np.nanstd(X_raw, axis=0)

    # standardized 1–99% span (what PD uses)
    Xn_p01 = np.percentile(Xn, 1, axis=0)
    Xn_p99 = np.percentile(Xn, 99, axis=0)
    Xn_span = Xn_p99 - Xn_p01
    Xn_span = np.where(np.isfinite(Xn_span), Xn_span, np.nan)

    # lengthscale-to-span ratios (big -> likely flat)
    ell_s_arr = _np1d(map_s["ell"], p)
    ell_l_arr = _np1d(map_l["ell"], p)
    with np.errstate(divide="ignore", invalid="ignore"):
        ell_over_span_success = np.where(Xn_span > 0, ell_s_arr / Xn_span, np.nan)
        ell_over_span_loss = np.where(Xn_span > 0, ell_l_arr / Xn_span, np.nan)

    # simple correlations (training)
    def _safe_corr(a, b):
        a = np.asarray(a, float)
        b = np.asarray(b, float)
        m = np.isfinite(a) & np.isfinite(b)
        if m.sum() < 3:
            return np.nan
        va = np.var(a[m])
        vb = np.var(b[m])
        if va == 0 or vb == 0:
            return np.nan
        return float(np.corrcoef(a[m], b[m])[0, 1])

    corr_success = np.array([_safe_corr(X_raw[:, j], y_success) for j in range(p)], dtype=float)
    corr_loss_success = np.array(
        [_safe_corr(X_raw[ok_mask, j], y_loss_success) if ok_mask.any() else np.nan for j in range(p)],
        dtype=float,
    )

    # ---------- Build xarray Dataset ----------
    feature_coord = xr.DataArray(np.array(feature_names, dtype=object), dims=("feature",), name="feature")
    row_coord = xr.DataArray(np.arange(n, dtype=np.int64), dims=("row",), name="row")
    row_ok_coord = xr.DataArray(np.where(ok_mask)[0].astype(np.int64), dims=("row_success",), name="row_success")

    # Raw columns (strings and numerics) -- from df_raw
    raw_vars = {}
    for col in df_raw.columns:
        s = df_raw[col]
        if (
            pd.api.types.is_integer_dtype(s)
            or pd.api.types.is_float_dtype(s)
            or pd.api.types.is_bool_dtype(s)
        ):
            raw_vars[col] = (("row",), s.to_numpy())
        else:
            s_str = s.astype("string").fillna("<NA>")
            vals = s_str.to_numpy(dtype="U")
            raw_vars[col] = (("row",), vals)

    ds = xr.Dataset(
        data_vars={
            # Design matrices (standardized)
            "Xn_train": (("row", "feature"), Xn),
            "Xn_success_only": (("row_success", "feature"), Xn_success_only),

            # Targets
            "y_success": (("row",), y_success),
            "y_loss_success": (("row_success",), y_loss_success),
            "y_loss_centered": (("row_success",), y_loss_centered),

            # Standardization + transforms
            "feature_mean": (("feature",), X_mean),
            "feature_std": (("feature",), X_std),
            "feature_transform": (("feature",), np.array(transforms, dtype=object)),

            # Masks / indexing
            "success_mask": (("row",), ok_mask.astype(np.int8)),

            # Head A (success) MAP params
            "map_success_ell": (("feature",), ell_s_arr),
            "map_success_eta": ((), float(np.asarray(map_s["eta"]))),
            "map_success_sigma": ((), float(np.asarray(map_s["sigma"]))),
            "map_success_beta0": ((), float(np.asarray(map_s["beta0"]))),

            # Head B (loss | success) MAP params
            "map_loss_ell": (("feature",), ell_l_arr),
            "map_loss_eta": ((), float(np.asarray(map_l["eta"]))),
            "map_loss_sigma": ((), float(np.asarray(map_l["sigma"]))),
            "map_loss_mean_const": ((), float(np.asarray(map_l["mean_const"]))),

            # Convenience predictions on training data
            "pred_success_mu_train": (("row",), mu_s),
            "pred_success_var_train": (("row",), var_s),
            "pred_loss_mu_success_train": (("row_success",), mu_l),
            "pred_loss_sd_success_train": (("row_success",), sd_l),

            # Useful scalars for predictors
            "conditional_loss_mean": ((), float(conditional_loss_mean)),

            # ------- Diagnostics (per-feature) -------
            "raw_min": (("feature",), np.array(raw_stats["raw_min"], dtype=float)),
            "raw_max": (("feature",), np.array(raw_stats["raw_max"], dtype=float)),
            "raw_mean": (("feature",), np.array(raw_stats["raw_mean"], dtype=float)),
            "raw_std": (("feature",), np.array(raw_stats["raw_std"], dtype=float)),
            "n_unique_raw": (("feature",), np.array(raw_stats["n_unique_raw"], dtype=np.int32)),

            "internal_min": (("feature",), internal_min.astype(float)),
            "internal_max": (("feature",), internal_max.astype(float)),
            "internal_mean": (("feature",), internal_mean.astype(float)),
            "internal_std": (("feature",), internal_std.astype(float)),

            "Xn_p01": (("feature",), Xn_p01.astype(float)),
            "Xn_p99": (("feature",), Xn_p99.astype(float)),
            "Xn_span": (("feature",), Xn_span.astype(float)),

            "ell_over_span_success": (("feature",), ell_over_span_success.astype(float)),
            "ell_over_span_loss": (("feature",), ell_over_span_loss.astype(float)),

            "corr_success": (("feature",), corr_success),
            "corr_loss_success": (("feature",), corr_loss_success),

            "feature_is_onehot_member": (("feature",), np.array(is_onehot_member, dtype=np.int8)),
            "feature_onehot_base": (("feature",), np.array(onehot_base_per_feature, dtype=object)),
        },
        coords={
            "feature": feature_coord,
            "row": row_coord,
            "row_success": row_ok_coord,
        },
        attrs={
            "artifact_version": "0.1.2",  # bumped for diagnostics
            "target": target,
            "direction": direction,
            "rng_state": rng_state_b64,
            "rng_bitgen": rng.bit_generator.__class__.__name__,
            "numpy_version": np.__version__,
            "pymc_version": pm.__version__,
            "n_rows": int(n),
            "n_success_rows": int(ok_mask.sum()),
            "success_rate": float(base_success_rate),
            "categorical_groups_json": json.dumps(categorical_groups),  # base -> {labels, members}
        },
    )

    # Attach raw columns
    for k, v in raw_vars.items():
        ds[k] = v

    # ---------- Save ----------
    if output:
        output = Path(output)
        output.parent.mkdir(parents=True, exist_ok=True)
        engine, encoding = _select_netcdf_engine_and_encoding(ds, compress=compress)
        ds.to_netcdf(output, engine=engine, encoding=encoding)

    _print_diagnostics_table(ds, top_k=20)

    return ds
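

# Usage sketch (illustrative; the CSV name and column names below are
# assumptions, not part of the package):
#
#     ds = build_model("runs.csv", target="loss", output="model.nc", seed=0)
#     # Rows whose "loss" is NaN count as failures; the artifact stores both
#     # heads' MAP parameters plus the diagnostics printed above.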


def kernel_diag_m52(XA: np.ndarray, ls: np.ndarray, eta: float) -> np.ndarray:
    return np.full(XA.shape[0], eta ** 2, dtype=float)


def kernel_m52_ard(XA: np.ndarray, XB: np.ndarray, ls: np.ndarray, eta: float) -> np.ndarray:
    XA = np.asarray(XA, float)
    XB = np.asarray(XB, float)
    ls = np.asarray(ls, float).reshape(1, 1, -1)
    diff = (XA[:, None, :] - XB[None, :, :]) / ls
    r2 = np.sum(diff * diff, axis=2)
    r = np.sqrt(np.maximum(r2, 0.0))
    sqrt5_r = np.sqrt(5.0) * r
    k = (eta ** 2) * (1.0 + sqrt5_r + (5.0 / 3.0) * r2) * np.exp(-sqrt5_r)
    return k


def add_jitter(K: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    jitter = eps * float(np.mean(np.diag(K)) + 1.0)
    return K + jitter * np.eye(K.shape[0], dtype=K.dtype)


def solve_chol(L: np.ndarray, b: np.ndarray) -> np.ndarray:
    y = np.linalg.solve(L, b)
    return np.linalg.solve(L.T, y)


def solve_lower(L: np.ndarray, B: np.ndarray) -> np.ndarray:
    return np.linalg.solve(L, B)
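

# The NumPy helpers above let a saved artifact be queried without PyMC. A
# minimal sketch of how they compose into standard GP regression (illustrative
# only; this is not the package's actual predictor, and the function name is
# ours):
def _gp_posterior_sketch(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_new: np.ndarray,
    ls: np.ndarray,
    eta: float,
    sigma: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Posterior mean/variance of a Matern-5/2 GP at X_new (sketch)."""
    # K = k(X, X) + sigma^2 I, stabilized with jitter before the Cholesky
    K = kernel_m52_ard(X_train, X_train, ls, eta) + (sigma ** 2) * np.eye(X_train.shape[0])
    L = np.linalg.cholesky(add_jitter(K))
    alpha = solve_chol(L, y_train)   # alpha = K^{-1} y
    K_s = kernel_m52_ard(X_new, X_train, ls, eta)
    mu = K_s @ alpha                 # posterior mean at X_new
    V = solve_lower(L, K_s.T)        # V = L^{-1} k(X, X_new)
    var = kernel_diag_m52(X_new, ls, eta) - np.sum(V * V, axis=0)
    return mu, np.maximum(var, 0.0)
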

def feature_raw_from_artifact_or_reconstruct(
    ds: xr.Dataset,
    j: int,
    name: str,
    transform: str,
) -> np.ndarray:
    """
    Return the feature values in ORIGINAL units for each training row.
    Prefer a stored raw column (ds[name]) if present; otherwise reconstruct
    from Xn_train using feature_mean/std and the recorded transform.
    """
    # 1) Use the stored raw per-row column if present
    if name in ds.data_vars:
        da = ds[name]
        if "row" in da.dims and da.sizes.get("row", None) == ds.sizes.get("row", None):
            vals = np.asarray(da.values, dtype=float)
            return vals

    # 2) Reconstruct from the standardized training matrix
    Xn = ds["Xn_train"].values.astype(float)  # (N, p)
    mu = ds["feature_mean"].values.astype(float)[j]
    sd = ds["feature_std"].values.astype(float)[j]
    x_internal = Xn[:, j] * sd + mu  # internal model space
    if transform == "log10":
        raw = 10.0 ** x_internal
    else:
        raw = x_internal
    return raw
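

# Example of the reconstruction arithmetic (illustrative values): with
# transform == "log10", feature_mean == -3.0 and feature_std == 1.0, a
# standardized value of -1.0 maps to internal -4.0 and back to raw
# 10.0 ** -4.0 == 1e-4.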


# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------
def _to_bool01(arr: np.ndarray) -> np.ndarray:
    """Map arbitrary truthy/falsy values to {0,1} int array."""
    if arr.dtype == bool:
        return arr.astype(np.int32)
    if np.issubdtype(arr.dtype, np.number):
        return (arr != 0).astype(np.int32)
    truthy = {"1", "true", "yes", "y", "ok", "success"}
    return np.array([1 if (str(x).strip().lower() in truthy) else 0 for x in arr], dtype=np.int32)
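

# Example (illustrative): _to_bool01(np.array(["yes", "0", "OK", "no"]))
# returns array([1, 0, 1, 0]) -- strings are matched case-insensitively
# against the truthy set, and everything else maps to 0.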

def _choose_transform(name: str, col: np.ndarray) -> str:
    """
    Choose a simple per-feature transform: 'identity' or 'log10'.
    Heuristics:
      - if the column name looks like a learning rate (lr/learning_rate) AND the
        column is strictly positive => log10
      - else if strictly positive and p99/p1 >= 1e3 => log10
      - else identity
    """
    name_l = name.lower()
    strictly_pos = np.all(np.isfinite(col)) and np.nanmin(col) > 0.0
    looks_lr = ("learning_rate" in name_l) or (name_l == "lr")
    if strictly_pos and (looks_lr or _large_dynamic_range(col)):
        return "log10"
    return "identity"


def _large_dynamic_range(col: np.ndarray) -> bool:
    x = col[np.isfinite(col)]
    if x.size == 0:
        return False
    p1, p99 = np.percentile(x, [1, 99])
    p1 = max(p1, 1e-12)
    return (p99 / p1) >= 1e3


def _apply_transform(tr: str, col: np.ndarray) -> np.ndarray:
    if tr == "log10":
        x = np.asarray(col, dtype=float)
        x = np.where(x <= 0, np.nan, x)
        return np.log10(x)
    return np.asarray(col, dtype=float)
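

# Example (illustrative values): a strictly positive learning-rate column
# spanning several decades is log-transformed, while a narrow numeric column
# is left alone:
#
#     _choose_transform("lr", np.array([1e-5, 1e-3, 1e-1]))          # -> "log10"
#     _choose_transform("batch_size", np.array([16.0, 32.0, 64.0]))  # -> "identity"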

def _np1d(x: np.ndarray, p: int) -> np.ndarray:
    a = np.asarray(x, dtype=float).ravel()
    if a.size != p:
        a = np.full((p,), float(a.item()) if a.size == 1 else np.nan, dtype=float)
    return a


# ---- choose engine + encoding safely across backends ----
def _select_netcdf_engine_and_encoding(ds: xr.Dataset, compress: bool):
    # Prefer netcdf4
    try:
        import netCDF4  # noqa: F401
        engine = "netcdf4"
        if not compress:
            return engine, None
        enc = {}
        for name, da in ds.data_vars.items():
            if np.issubdtype(da.dtype, np.number):
                enc[name] = {"zlib": True, "complevel": 4}
        return engine, enc
    except Exception:
        pass

    # Then h5netcdf
    try:
        import h5netcdf  # noqa: F401
        engine = "h5netcdf"
        if not compress:
            return engine, None
        enc = {}
        for name, da in ds.data_vars.items():
            if np.issubdtype(da.dtype, np.number):
                enc[name] = {"compression": "gzip", "compression_opts": 4}
        return engine, enc
    except Exception:
        pass

    # Finally scipy (no compression supported)
    engine = "scipy"
    return engine, None