Coverage for psyop/model.py: 74.65%

355 statements  

coverage.py v7.10.6, created at 2025-10-29 03:44 +0000

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Make BLAS single-threaded to avoid oversubscription / macOS crashes
import os
for _env_var in (
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "OMP_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ.setdefault(_env_var, "1")

from pathlib import Path
import base64
import pickle
import json

import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr
from rich.console import Console
from rich.table import Table

from .util import get_rng, df_to_table

def _safe_vec(ds: xr.Dataset, name: str, nF: int) -> np.ndarray:
    """Return ds[name] as a length-nF vector, NaN-padding or truncating if sizes disagree."""
    if name not in ds:
        return np.full(nF, np.nan)
    arr = np.asarray(ds[name].values)
    if arr.size == nF:
        return arr
    out = np.full(nF, np.nan)
    out[:min(nF, arr.size)] = arr.ravel()[:min(nF, arr.size)]
    return out
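
# Illustrative sketch (not called by the pipeline): _safe_vec pads or
# truncates so downstream DataFrame construction never fails on a
# mis-sized diagnostic variable.
def _example_safe_vec_padding() -> np.ndarray:
    ds = xr.Dataset({"v": (("feature",), np.array([1.0, 2.0]))})
    return _safe_vec(ds, "v", 4)  # -> array([1., 2., nan, nan])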

def _safe_scalar(ds: xr.Dataset, name: str) -> float:
    if name in ds:
        try:
            return float(np.asarray(ds[name].values).item())
        except Exception:
            return float(np.nan)
    if name in ds.attrs:
        try:
            return float(ds.attrs[name])
        except Exception:
            return float(np.nan)
    return float(np.nan)
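
# Illustrative sketch: _safe_scalar checks data_vars first, then falls back
# to attrs, coercing to float and returning NaN on any failure.
def _example_safe_scalar() -> float:
    ds = xr.Dataset(attrs={"success_rate": "0.75"})
    return _safe_scalar(ds, "success_rate")  # -> 0.75 (read from attrs)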

def _diagnostic_feature_dataframe(ds: xr.Dataset, top_k: int = 20) -> pd.DataFrame:
    """Top-K features by |corr_loss_success| without any repeated globals."""
    f = [str(x) for x in ds["feature"].values]
    nF = len(f)

    is_oh = (
        ds["feature_is_onehot_member"].values.astype(int)
        if "feature_is_onehot_member" in ds
        else np.zeros(nF, int)
    )
    base = (
        ds["feature_onehot_base"].values
        if "feature_onehot_base" in ds
        else np.array([""] * nF, object)
    )

    df = pd.DataFrame({
        "feature": f,
        "type": np.where(is_oh == 1, "categorical(one-hot)", "numeric"),
        "onehot_base": np.where(is_oh == 1, base, ""),
        "n_unique_raw": _safe_vec(ds, "n_unique_raw", nF),
        "raw_min": _safe_vec(ds, "raw_min", nF),
        "raw_max": _safe_vec(ds, "raw_max", nF),
        "Xn_span": _safe_vec(ds, "Xn_span", nF),
        "ell_s": _safe_vec(ds, "map_success_ell", nF),
        "ell_l": _safe_vec(ds, "map_loss_ell", nF),
        "ell/span_s": _safe_vec(ds, "ell_over_span_success", nF),
        "ell/span_l": _safe_vec(ds, "ell_over_span_loss", nF),
        "corr_success": _safe_vec(ds, "corr_success", nF),
        "corr_loss_success": _safe_vec(ds, "corr_loss_success", nF),
    })
    df["|corr_loss_success|"] = np.abs(df["corr_loss_success"].astype(float))
    return (
        df.sort_values("|corr_loss_success|", ascending=False, kind="mergesort")
        .head(top_k)
        .reset_index(drop=True)
    )
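
# Illustrative usage (hypothetical path): rank the features of a saved
# artifact without printing the Rich table.
def _example_top_features(path: str = "model.nc") -> pd.DataFrame:
    ds = xr.load_dataset(path)
    return _diagnostic_feature_dataframe(ds, top_k=5)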

def _diagnostic_global_dataframe(ds: xr.Dataset) -> pd.DataFrame:
    """Single-row DataFrame of global/model-level diagnostics."""
    # Scalars from attrs
    target = ds.attrs.get("target", "")
    direction = ds.attrs.get("direction", "")
    n_rows = int(ds.attrs.get("n_rows", -1))  # -1 when absent; int(np.nan) would raise
    if "n_success_rows" in ds.attrs:
        n_success_rows = int(ds.attrs["n_success_rows"])
    elif "success_mask" in ds:
        n_success_rows = int(np.sum(np.asarray(ds["success_mask"].values)))
    else:
        n_success_rows = -1
    success_rate = float(ds.attrs.get("success_rate", np.nan))
    rng_bitgen = ds.attrs.get("rng_bitgen", "")
    numpy_version = ds.attrs.get("numpy_version", "")
    pymc_version = ds.attrs.get("pymc_version", "")

    # Scalars from data_vars (with attr fallback)
    conditional_loss_mean = _safe_scalar(ds, "conditional_loss_mean")

    # GP MAP scalars
    map_success_eta = _safe_scalar(ds, "map_success_eta")
    map_success_sigma = _safe_scalar(ds, "map_success_sigma")
    map_success_beta0 = _safe_scalar(ds, "map_success_beta0")

    map_loss_eta = _safe_scalar(ds, "map_loss_eta")
    map_loss_sigma = _safe_scalar(ds, "map_loss_sigma")
    map_loss_mean_c = _safe_scalar(ds, "map_loss_mean_const")

    # Data dispersion on successes
    y_ls_std = float(np.nanstd(ds["y_loss_success"].values)) if "y_loss_success" in ds else np.nan

    # Handy amplitude/noise ratio for the loss head
    if np.isfinite(map_loss_eta) and np.isfinite(map_loss_sigma) and map_loss_sigma != 0:
        eta_l_over_sigma_l = map_loss_eta / map_loss_sigma
    else:
        eta_l_over_sigma_l = np.nan

    rows = [{
        "target": target,
        "direction": direction,
        "n_rows": n_rows,
        "n_success_rows": n_success_rows,
        "success_rate": success_rate,
        "conditional_loss_mean": conditional_loss_mean,
        "map_success_eta": map_success_eta,
        "map_success_sigma": map_success_sigma,
        "map_success_beta0": map_success_beta0,
        "map_loss_eta": map_loss_eta,
        "map_loss_sigma": map_loss_sigma,
        "map_loss_mean_const": map_loss_mean_c,
        "y_loss_success_std": y_ls_std,
        "eta_l/sigma_l": eta_l_over_sigma_l,
        "rng_bitgen": rng_bitgen,
        "numpy_version": numpy_version,
        "pymc_version": pymc_version,
    }]
    return pd.DataFrame(rows)

def _print_diagnostics_table(ds: xr.Dataset, top_k: int = 20) -> None:
    """3-sigfig Rich table with key diagnostics."""
    feat_df = _diagnostic_feature_dataframe(ds, top_k=top_k)
    global_df = _diagnostic_global_dataframe(ds)

    console = Console()

    console.print("\n[bold]Model diagnostics (top by |corr_loss_success|):[/]")
    console.print(df_to_table(feat_df, transpose=False, show_index=False))  # regular table

    console.print("\n[bold]Model globals:[/]")
    # transpose for a key/value layout
    console.print(df_to_table(global_df, transpose=True))

def build_model(
    input: pd.DataFrame | Path | str,
    target: str,
    output: Path | str | None = None,
    exclude: list[str] | str | None = None,
    direction: str = "auto",
    seed: int | np.random.Generator | None = 42,
    compress: bool = True,
    prior_model: Path | str | xr.Dataset | None = None,
) -> xr.Dataset:
    """
    Fit a two-head GP (success probability + conditional loss) and save a single NetCDF artifact.
    Also stores rich diagnostics to help debug flat PD curves and model wiring.
    """

    # ---------- Load ----------
    if isinstance(input, pd.DataFrame):
        df = input
    else:
        input_path = Path(input)
        if not input_path.exists():
            raise FileNotFoundError(f"Input CSV not found: {input_path.resolve()}")
        df = pd.read_csv(input_path)

    if target not in df.columns:
        raise ValueError(f"Target column '{target}' not found in CSV.")

    # Keep a raw copy for the artifact (strings as pandas 'string' so NetCDF can handle them)
    df_raw = df.copy()
    for c in df_raw.columns:
        if df_raw[c].dtype == object:
            df_raw[c] = df_raw[c].astype("string")

    # ---------- Success inference ----------
    success = (~df[target].isna()).to_numpy().astype(int)
    has_success = bool(np.any(success == 1))
    if not has_success:
        raise RuntimeError("No successful rows detected (cannot fit conditional-loss GP).")

    # ---------- Feature selection ----------
    reserved_internals = {"__success__", "__fail__", "__status__"}
    exclude = [exclude] if isinstance(exclude, str) else (exclude or [])
    excluded = set(exclude) | {target} | reserved_internals

    numeric_cols = [
        c for c in df.columns
        if c not in excluded and pd.api.types.is_numeric_dtype(df[c])
    ]
    cat_cols = [
        c for c in df.columns
        if c not in excluded
        and (pd.api.types.is_string_dtype(df[c]) or isinstance(df[c].dtype, pd.CategoricalDtype))
    ]

    feature_names: list[str] = []
    transforms: list[str] = []
    X_raw_cols: list[np.ndarray] = []

    # Diagnostics to fill
    onehot_base_per_feature: list[str] = []  # "" for numeric / non-onehot
    is_onehot_member: list[int] = []
    categorical_groups: dict[str, dict] = {}  # base -> {"labels": [...], "members": [...]}

    # numeric features
    for name in numeric_cols:
        col = df[name].to_numpy(dtype=float)
        tr = _choose_transform(name, col)
        transforms.append(tr)
        feature_names.append(name)
        X_raw_cols.append(_apply_transform(tr, col))
        onehot_base_per_feature.append("")
        is_onehot_member.append(0)

    # categoricals → one-hot
    for base in cat_cols:
        s_cat = pd.Categorical(df[base].astype("string").fillna("<NA>"))
        H = pd.get_dummies(s_cat, prefix=base, prefix_sep="=", dtype=float)  # e.g., language=Linear A
        members = []
        labels = []
        for new_col in H.columns:
            feature_names.append(new_col)
            transforms.append("identity")
            X_raw_cols.append(H[new_col].to_numpy(dtype=float))
            onehot_base_per_feature.append(base)
            is_onehot_member.append(1)
            members.append(new_col)
            # label is the part after "base="
            labels.append(str(new_col.split("=", 1)[1]) if "=" in new_col else str(new_col))
        categorical_groups[base] = {"labels": labels, "members": members}

    X_raw = np.column_stack(X_raw_cols).astype(float)
    n, p = X_raw.shape

    prior_start_success: dict[str, np.ndarray | float] | None = None
    prior_start_loss: dict[str, np.ndarray | float] | None = None
    if prior_model is not None:
        if isinstance(prior_model, xr.Dataset):
            prior_ds = prior_model
        else:
            prior_path = Path(prior_model)
            if not prior_path.exists():
                raise FileNotFoundError(f"Prior model artifact not found: {prior_path.resolve()}")
            prior_ds = xr.load_dataset(prior_path)
        try:
            prior_feature_names = [str(v) for v in prior_ds["feature"].values.tolist()]
        except Exception:
            prior_ds = None
        else:
            if prior_feature_names != feature_names:
                prior_ds = None
        if prior_ds is not None:
            try:
                ell_s_prev = np.asarray(prior_ds["map_success_ell"].values, dtype=float)
                eta_s_prev = float(np.asarray(prior_ds["map_success_eta"].values))
                sigma_s_prev = float(np.asarray(prior_ds["map_success_sigma"].values))
                beta0_s_prev = float(np.asarray(prior_ds["map_success_beta0"].values))
                ell_l_prev = np.asarray(prior_ds["map_loss_ell"].values, dtype=float)
                eta_l_prev = float(np.asarray(prior_ds["map_loss_eta"].values))
                sigma_l_prev = float(np.asarray(prior_ds["map_loss_sigma"].values))
                mean_c_prev = float(np.asarray(prior_ds["map_loss_mean_const"].values))
                if ell_s_prev.shape == (p,) and ell_l_prev.shape == (p,):
                    prior_start_success = {
                        "ell": ell_s_prev,
                        "eta": eta_s_prev,
                        "sigma": sigma_s_prev,
                        "beta0": beta0_s_prev,
                    }
                    prior_start_loss = {
                        "ell": ell_l_prev,
                        "eta": eta_l_prev,
                        "sigma": sigma_l_prev,
                        "mean_const": mean_c_prev,
                    }
            except Exception:
                prior_start_success = None
                prior_start_loss = None

    # ---------- Standardize ----------
    X_mean = X_raw.mean(axis=0)
    X_std = X_raw.std(axis=0)
    X_std = np.where(X_std == 0.0, 1.0, X_std)  # keep inert dims harmless
    Xn = (X_raw - X_mean) / X_std

    # Targets
    y_success = success.astype(float)
    ok_mask = (success == 1)
    y_loss_success = df.loc[ok_mask, target].to_numpy(dtype=float)

    conditional_loss_mean = float(np.nanmean(y_loss_success)) if len(y_loss_success) else 0.0
    y_loss_centered = y_loss_success - conditional_loss_mean
    Xn_success_only = Xn[ok_mask, :]

    rng = get_rng(seed)
    state_bytes = pickle.dumps(rng.bit_generator.state)
    rng_state_b64 = base64.b64encode(state_bytes).decode("ascii")

    base_success_rate = float(y_success.mean())

    # ---------- Fit Head A: success ----------
    with pm.Model() as model_s:
        ell_s = pm.HalfNormal("ell", sigma=2.0, shape=p)
        eta_s = pm.HalfNormal("eta", sigma=2.0)
        sigma_s = pm.HalfNormal("sigma", sigma=0.3)
        beta0_s = pm.Normal("beta0", mu=base_success_rate, sigma=0.15)

        K_s = eta_s**2 * pm.gp.cov.Matern52(input_dim=p, ls=ell_s)
        m_s = pm.gp.mean.Constant(beta0_s)
        gp_s = pm.gp.Marginal(mean_func=m_s, cov_func=K_s)

        _ = gp_s.marginal_likelihood("y_obs_s", X=Xn, y=y_success, sigma=sigma_s)
        map_s = pm.find_MAP(start=prior_start_success) if prior_start_success is not None else pm.find_MAP()

    with model_s:
        mu_s, var_s = gp_s.predict(Xn, point=map_s, diag=True, pred_noise=True)
        mu_s = np.clip(mu_s, 0.0, 1.0)

    # ---------- Fit Head B: conditional loss (success-only) ----------
    if Xn_success_only.shape[0] == 0:
        raise RuntimeError("No successful rows to fit the conditional-loss GP.")

    with pm.Model() as model_l:
        ell_l = pm.TruncatedNormal("ell", mu=1.0, sigma=0.5, lower=0.1, shape=p)
        eta_l = pm.HalfNormal("eta", sigma=1.0)
        sigma_l = pm.HalfNormal("sigma", sigma=1.0)
        mean_c = pm.Normal("mean_const", mu=0.0, sigma=10.0)

        K_l = eta_l**2 * pm.gp.cov.Matern52(input_dim=p, ls=ell_l)
        m_l = pm.gp.mean.Constant(mean_c)
        gp_l = pm.gp.Marginal(mean_func=m_l, cov_func=K_l)

        _ = gp_l.marginal_likelihood("y_obs", X=Xn_success_only, y=y_loss_centered, sigma=sigma_l)
        map_l = pm.find_MAP(start=prior_start_loss) if prior_start_loss is not None else pm.find_MAP()

    with model_l:
        mu_l_c, var_l = gp_l.predict(Xn_success_only, point=map_l, diag=True, pred_noise=True)
        mu_l = mu_l_c + conditional_loss_mean
        sd_l = np.sqrt(var_l)

    # ---------- Diagnostics (per-feature) ----------
    # raw stats in ORIGINAL units (before any transform)
    raw_stats = {
        "raw_min": [], "raw_max": [], "raw_mean": [], "raw_std": [], "n_unique_raw": [],
    }
    for k, name in enumerate(feature_names):
        # Try to recover a raw column if present; otherwise invert the transform on X_raw[:, k]
        if name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[name]):
            raw_col = df_raw[name].to_numpy(dtype=float)
        else:
            # Inverse-transform "internal" values only if it's a simple one:
            internal = X_raw[:, k]
            tr = transforms[k]
            if tr == "log10":
                raw_col = np.power(10.0, internal)
            else:
                raw_col = internal
        x = np.asarray(raw_col, dtype=float)
        x_finite = x[np.isfinite(x)]
        if x_finite.size == 0:
            x_finite = np.array([np.nan])
        raw_stats["raw_min"].append(float(np.nanmin(x_finite)))
        raw_stats["raw_max"].append(float(np.nanmax(x_finite)))
        raw_stats["raw_mean"].append(float(np.nanmean(x_finite)))
        raw_stats["raw_std"].append(float(np.nanstd(x_finite)))
        raw_stats["n_unique_raw"].append(int(np.unique(np.round(x_finite, 12)).size))

    # internal (transformed) stats PRIOR to standardization
    internal_min = np.nanmin(X_raw, axis=0)
    internal_max = np.nanmax(X_raw, axis=0)
    internal_mean = np.nanmean(X_raw, axis=0)
    internal_std = np.nanstd(X_raw, axis=0)

    # standardized 1–99% span (what PD uses)
    Xn_p01 = np.percentile(Xn, 1, axis=0)
    Xn_p99 = np.percentile(Xn, 99, axis=0)
    Xn_span = Xn_p99 - Xn_p01
    Xn_span = np.where(np.isfinite(Xn_span), Xn_span, np.nan)

    # lengthscale-to-span ratios (big → likely flat)
    ell_s_arr = _np1d(map_s["ell"], p)
    ell_l_arr = _np1d(map_l["ell"], p)
    with np.errstate(divide="ignore", invalid="ignore"):
        ell_over_span_success = np.where(Xn_span > 0, ell_s_arr / Xn_span, np.nan)
        ell_over_span_loss = np.where(Xn_span > 0, ell_l_arr / Xn_span, np.nan)

    # simple correlations (training)
    def _safe_corr(a, b):
        a = np.asarray(a, float)
        b = np.asarray(b, float)
        m = np.isfinite(a) & np.isfinite(b)
        if m.sum() < 3:
            return np.nan
        va = np.var(a[m])
        vb = np.var(b[m])
        if va == 0 or vb == 0:
            return np.nan
        return float(np.corrcoef(a[m], b[m])[0, 1])

    corr_success = np.array([_safe_corr(X_raw[:, j], y_success) for j in range(p)], dtype=float)
    corr_loss_success = np.array(
        [_safe_corr(X_raw[ok_mask, j], y_loss_success) if ok_mask.any() else np.nan for j in range(p)],
        dtype=float,
    )

    # ---------- Build xarray Dataset ----------
    feature_coord = xr.DataArray(np.array(feature_names, dtype=object), dims=("feature",), name="feature")
    row_coord = xr.DataArray(np.arange(n, dtype=np.int64), dims=("row",), name="row")
    row_ok_coord = xr.DataArray(np.where(ok_mask)[0].astype(np.int64), dims=("row_success",), name="row_success")

    # Raw columns (strings and numerics) — from df_raw
    raw_vars = {}
    for col in df_raw.columns:
        s = df_raw[col]
        if (
            pd.api.types.is_integer_dtype(s)
            or pd.api.types.is_float_dtype(s)
            or pd.api.types.is_bool_dtype(s)
        ):
            raw_vars[col] = (("row",), s.to_numpy())
        else:
            s_str = s.astype("string").fillna("<NA>")
            vals = s_str.to_numpy(dtype="U")
            raw_vars[col] = (("row",), vals)

    ds = xr.Dataset(
        data_vars={
            # Design matrices (standardized)
            "Xn_train": (("row", "feature"), Xn),
            "Xn_success_only": (("row_success", "feature"), Xn_success_only),

            # Targets
            "y_success": (("row",), y_success),
            "y_loss_success": (("row_success",), y_loss_success),
            "y_loss_centered": (("row_success",), y_loss_centered),

            # Standardization + transforms
            "feature_mean": (("feature",), X_mean),
            "feature_std": (("feature",), X_std),
            "feature_transform": (("feature",), np.array(transforms, dtype=object)),

            # Masks / indexing
            "success_mask": (("row",), ok_mask.astype(np.int8)),

            # Head A (success) MAP params
            "map_success_ell": (("feature",), ell_s_arr),
            "map_success_eta": ((), float(np.asarray(map_s["eta"]))),
            "map_success_sigma": ((), float(np.asarray(map_s["sigma"]))),
            "map_success_beta0": ((), float(np.asarray(map_s["beta0"]))),

            # Head B (loss|success) MAP params
            "map_loss_ell": (("feature",), ell_l_arr),
            "map_loss_eta": ((), float(np.asarray(map_l["eta"]))),
            "map_loss_sigma": ((), float(np.asarray(map_l["sigma"]))),
            "map_loss_mean_const": ((), float(np.asarray(map_l["mean_const"]))),

            # Convenience predictions on training data
            "pred_success_mu_train": (("row",), mu_s),
            "pred_success_var_train": (("row",), var_s),
            "pred_loss_mu_success_train": (("row_success",), mu_l),
            "pred_loss_sd_success_train": (("row_success",), sd_l),

            # Useful scalars for predictors
            "conditional_loss_mean": ((), float(conditional_loss_mean)),

            # ------- Diagnostics (per-feature) -------
            "raw_min": (("feature",), np.array(raw_stats["raw_min"], dtype=float)),
            "raw_max": (("feature",), np.array(raw_stats["raw_max"], dtype=float)),
            "raw_mean": (("feature",), np.array(raw_stats["raw_mean"], dtype=float)),
            "raw_std": (("feature",), np.array(raw_stats["raw_std"], dtype=float)),
            "n_unique_raw": (("feature",), np.array(raw_stats["n_unique_raw"], dtype=np.int32)),

            "internal_min": (("feature",), internal_min.astype(float)),
            "internal_max": (("feature",), internal_max.astype(float)),
            "internal_mean": (("feature",), internal_mean.astype(float)),
            "internal_std": (("feature",), internal_std.astype(float)),

            "Xn_p01": (("feature",), Xn_p01.astype(float)),
            "Xn_p99": (("feature",), Xn_p99.astype(float)),
            "Xn_span": (("feature",), Xn_span.astype(float)),

            "ell_over_span_success": (("feature",), ell_over_span_success.astype(float)),
            "ell_over_span_loss": (("feature",), ell_over_span_loss.astype(float)),

            "corr_success": (("feature",), corr_success),
            "corr_loss_success": (("feature",), corr_loss_success),

            "feature_is_onehot_member": (("feature",), np.array(is_onehot_member, dtype=np.int8)),
            "feature_onehot_base": (("feature",), np.array(onehot_base_per_feature, dtype=object)),
        },
        coords={
            "feature": feature_coord,
            "row": row_coord,
            "row_success": row_ok_coord,
        },
        attrs={
            "artifact_version": "0.1.2",  # bumped for diagnostics
            "target": target,
            "direction": direction,
            "rng_state": rng_state_b64,
            "rng_bitgen": rng.bit_generator.__class__.__name__,
            "numpy_version": np.__version__,
            "pymc_version": pm.__version__,
            "n_rows": int(n),
            "n_success_rows": int(ok_mask.sum()),
            "success_rate": float(base_success_rate),
            "categorical_groups_json": json.dumps(categorical_groups),  # base -> {labels, members}
        },
    )

    # Attach raw columns
    for k, v in raw_vars.items():
        ds[k] = v

    # ---------- Save ----------
    # Engine-appropriate encoding is chosen in _select_netcdf_engine_and_encoding;
    # precomputing a zlib encoding here as well was dead code and has been dropped.
    if output:
        output = Path(output)
        output.parent.mkdir(parents=True, exist_ok=True)
        engine, encoding = _select_netcdf_engine_and_encoding(ds, compress=compress)
        ds.to_netcdf(output, engine=engine, encoding=encoding)

    _print_diagnostics_table(ds, top_k=20)

    return ds
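
# Illustrative usage (hypothetical column names, not part of the module):
# fit both GP heads on an in-memory DataFrame. NaN in the target marks a
# failed run; everything else counts as a success.
def _example_build_model() -> xr.Dataset:
    gen = np.random.default_rng(0)
    df = pd.DataFrame({
        "lr": 10.0 ** gen.uniform(-5, -1, 40),
        "optimizer": ["adam", "sgd"] * 20,
        "loss": gen.normal(1.0, 0.1, 40),
    })
    df.loc[:3, "loss"] = np.nan  # four failed runs
    return build_model(df, target="loss", output="model.nc", seed=0)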

def kernel_diag_m52(XA: np.ndarray, ls: np.ndarray, eta: float) -> np.ndarray:
    """Diagonal of the Matérn-5/2 kernel: constant eta**2 (ls kept for signature parity)."""
    return np.full(XA.shape[0], eta ** 2, dtype=float)

def kernel_m52_ard(XA: np.ndarray, XB: np.ndarray, ls: np.ndarray, eta: float) -> np.ndarray:
    """Matérn-5/2 ARD kernel: eta**2 * (1 + sqrt(5) r + 5 r^2 / 3) * exp(-sqrt(5) r)."""
    XA = np.asarray(XA, float)
    XB = np.asarray(XB, float)
    ls = np.asarray(ls, float).reshape(1, 1, -1)
    diff = (XA[:, None, :] - XB[None, :, :]) / ls
    r2 = np.sum(diff * diff, axis=2)
    r = np.sqrt(np.maximum(r2, 0.0))
    sqrt5_r = np.sqrt(5.0) * r
    k = (eta ** 2) * (1.0 + sqrt5_r + (5.0 / 3.0) * r2) * np.exp(-sqrt5_r)
    return k
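
# Illustrative sanity check: the kernel is symmetric and its diagonal matches
# kernel_diag_m52 (r = 0 gives exactly eta**2).
def _example_kernel_sanity() -> bool:
    X = np.random.default_rng(0).normal(size=(5, 3))
    K = kernel_m52_ard(X, X, ls=np.ones(3), eta=2.0)
    return bool(
        np.allclose(K, K.T)
        and np.allclose(np.diag(K), kernel_diag_m52(X, np.ones(3), 2.0))
    )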

def add_jitter(K: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Add a small diagonal jitter (scaled to the mean diagonal) for numerical stability."""
    jitter = eps * float(np.mean(np.diag(K)) + 1.0)
    return K + jitter * np.eye(K.shape[0], dtype=K.dtype)


def solve_chol(L: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Solve (L L^T) x = b given a lower-triangular Cholesky factor L."""
    y = np.linalg.solve(L, b)
    return np.linalg.solve(L.T, y)
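
# Illustrative sketch of how these helpers combine into a zero-mean GP
# posterior mean: alpha = K^{-1} y via one Cholesky factorization, then
# mu_* = K_* @ alpha. The argument names are assumptions, not pipeline API.
def _example_gp_posterior_mean(
    Xtr: np.ndarray, ytr: np.ndarray, Xte: np.ndarray,
    ls: np.ndarray, eta: float, sigma: float,
) -> np.ndarray:
    K = kernel_m52_ard(Xtr, Xtr, ls, eta) + (sigma ** 2) * np.eye(len(Xtr))
    L = np.linalg.cholesky(add_jitter(K))
    alpha = solve_chol(L, ytr)  # alpha = K^{-1} ytr
    return kernel_m52_ard(Xte, Xtr, ls, eta) @ alpha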

def solve_lower(L: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Solve L X = B for a lower-triangular L."""
    return np.linalg.solve(L, B)

def feature_raw_from_artifact_or_reconstruct(
    ds: xr.Dataset,
    j: int,
    name: str,
    transform: str,
) -> np.ndarray:
    """
    Return the feature values in ORIGINAL units for each training row.
    Prefer a stored raw column (ds[name]) if present; otherwise reconstruct
    from Xn_train using feature_mean/std and the recorded transform.
    """
    # 1) Use the stored raw per-row column if present
    if name in ds.data_vars:
        da = ds[name]
        if "row" in da.dims and da.sizes.get("row", None) == ds.sizes.get("row", None):
            vals = np.asarray(da.values, dtype=float)
            return vals

    # 2) Reconstruct from the standardized training matrix
    Xn = ds["Xn_train"].values.astype(float)  # (N, p)
    mu = ds["feature_mean"].values.astype(float)[j]
    sd = ds["feature_std"].values.astype(float)[j]
    x_internal = Xn[:, j] * sd + mu  # internal model space
    if transform == "log10":
        raw = 10.0 ** x_internal
    else:
        raw = x_internal
    return raw
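
# Illustrative sketch: recover original-unit values for feature j of a loaded
# artifact, honoring the transform recorded at fit time.
def _example_feature_raw(ds: xr.Dataset, j: int) -> np.ndarray:
    name = str(ds["feature"].values[j])
    transform = str(ds["feature_transform"].values[j])
    return feature_raw_from_artifact_or_reconstruct(ds, j, name, transform)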

# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------
def _to_bool01(arr: np.ndarray) -> np.ndarray:
    """Map arbitrary truthy/falsy values to a {0, 1} int array."""
    if arr.dtype == bool:
        return arr.astype(np.int32)
    if np.issubdtype(arr.dtype, np.number):
        return (arr != 0).astype(np.int32)
    truthy = {"1", "true", "yes", "y", "ok", "success"}
    return np.array([1 if (str(x).strip().lower() in truthy) else 0 for x in arr], dtype=np.int32)
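
# Illustrative sketch: string inputs are stripped, lowercased, and matched
# against the truthy set; anything else maps to 0.
def _example_to_bool01() -> np.ndarray:
    arr = np.array(["yes", "0", " OK ", "nope"], dtype=object)
    return _to_bool01(arr)  # -> array([1, 0, 1, 0], dtype=int32)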

def _choose_transform(name: str, col: np.ndarray) -> str:
    """
    Choose a simple per-feature transform: 'identity' or 'log10'.
    Heuristics:
      - if the column name looks like a learning rate (lr/learning_rate) AND values are > 0 => log10
      - else if strictly positive and p99/p1 >= 1e3 => log10
      - else identity
    """
    name_l = name.lower()
    strictly_pos = np.all(np.isfinite(col)) and np.nanmin(col) > 0.0
    looks_lr = ("learning_rate" in name_l) or (name_l == "lr")
    if strictly_pos and (looks_lr or _large_dynamic_range(col)):
        return "log10"
    return "identity"
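
# Illustrative sketch of the heuristic: an lr-named positive column maps to
# log10 even with a narrow range; a wide-range positive column does too.
def _example_choose_transform() -> tuple[str, str, str]:
    lr = np.array([1e-4, 3e-4, 1e-3])
    wide = np.array([1.0, 50.0, 2e4])
    narrow = np.array([0.1, 0.2, 0.3])
    return (
        _choose_transform("lr", lr),         # "log10" (name heuristic)
        _choose_transform("width", wide),    # "log10" (p99/p1 >= 1e3)
        _choose_transform("width", narrow),  # "identity"
    )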

def _large_dynamic_range(col: np.ndarray) -> bool:
    x = col[np.isfinite(col)]
    if x.size == 0:
        return False
    p1, p99 = np.percentile(x, [1, 99])
    p1 = max(p1, 1e-12)
    return (p99 / p1) >= 1e3

def _apply_transform(tr: str, col: np.ndarray) -> np.ndarray:
    if tr == "log10":
        x = np.asarray(col, dtype=float)
        x = np.where(x <= 0, np.nan, x)
        return np.log10(x)
    return np.asarray(col, dtype=float)

def _np1d(x: np.ndarray, p: int) -> np.ndarray:
    """Coerce x to a 1-D float array of length p (broadcast a scalar, NaN-fill on mismatch)."""
    a = np.asarray(x, dtype=float).ravel()
    if a.size != p:
        a = np.full((p,), float(a.item()) if a.size == 1 else np.nan, dtype=float)
    return a

# ---- choose engine + encoding safely across backends ----
def _select_netcdf_engine_and_encoding(ds: xr.Dataset, compress: bool):
    # Prefer netcdf4
    try:
        import netCDF4  # noqa: F401
        engine = "netcdf4"
        if not compress:
            return engine, None
        enc = {}
        for name, da in ds.data_vars.items():
            if np.issubdtype(da.dtype, np.number):
                enc[name] = {"zlib": True, "complevel": 4}
        return engine, enc
    except Exception:
        pass

    # Then h5netcdf
    try:
        import h5netcdf  # noqa: F401
        engine = "h5netcdf"
        if not compress:
            return engine, None
        enc = {}
        for name, da in ds.data_vars.items():
            if np.issubdtype(da.dtype, np.number):
                enc[name] = {"compression": "gzip", "compression_opts": 4}
        return engine, enc
    except Exception:
        pass

    # Finally scipy (no compression supported)
    engine = "scipy"
    return engine, None
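
# Illustrative usage: pick the best available backend, then save. With netcdf4
# or h5netcdf installed, numeric variables are compressed at level 4; with
# only scipy available, encoding is None and the file is written uncompressed.
def _example_save(ds: xr.Dataset, path: str = "model.nc") -> None:
    engine, encoding = _select_netcdf_engine_and_encoding(ds, compress=True)
    ds.to_netcdf(path, engine=engine, encoding=encoding)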