Coverage for psyop/main.py: 64.07%

270 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-10-29 03:44 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4# Make BLAS single-threaded to avoid oversubscription / macOS crashes 

5import os 

# Only fill in a variable when the user/environment has not already chosen a value.
for _env_var in ("MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "OMP_NUM_THREADS",
                 "VECLIB_MAXIMUM_THREADS", "NUMEXPR_NUM_THREADS"):
    if _env_var not in os.environ:
        os.environ[_env_var] = "1"

14 

15import re 

16from enum import Enum 

17from pathlib import Path 

18from typing import Optional 

19import xarray as xr 

20import typer 

21from rich.console import Console 

22 

23from .model import build_model 

24from . import viz, opt 

25 

__version__ = "0.1.0"

# Shared Rich console used for all user-facing output in this module.
console = Console()

# Root Typer application; the commands below register on it via @app.command.
app = typer.Typer(
    add_completion=False,
    no_args_is_help=True,
    rich_markup_mode="rich",
)

30 

31 

class Direction(str, Enum):
    """Optimization direction for the model target.

    Members subclass ``str``; their values are the short CLI spellings
    ``"min"`` and ``"max"``.
    """

    MINIMIZE = "min"
    MAXIMIZE = "max"

35 

36 

def _strip_quotes(s: str) -> str:
    """Strip surrounding whitespace and one matching pair of enclosing quotes.

    ``'abc'`` and ``"abc"`` become ``abc``; mismatched quotes are left intact.
    A lone quote character is returned unchanged: it is not a quoted (empty)
    string, and the first and last characters are the same single quote.
    """
    s = s.strip()
    # Require at least two characters so "'" alone is not mistaken for "''".
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ("'", '"'):
        return s[1:-1]
    return s

42 

# Decimal numeral: optional sign, integer and/or fractional part, optional exponent.
# (No inf/nan, no digit-group underscores.)
_num_re = re.compile(
    r"""
    ^\s*
    ([+-]?                              # sign
        (?:
            (?:\d+(?:\.\d*)?|\.\d+)     # 123, 123., .123, 123.456
            (?:[eE][+-]?\d+)?           # optional exponent
        )
    )
    \s*$
    """,
    re.VERBOSE,
)

# Plain integer literal; such strings are parsed with int() to keep full precision.
_int_literal_re = re.compile(r"\s*[+-]?\d+\s*")


def _is_intlike_str(s: str) -> bool:
    """Return True if *s* parses as a float with no fractional part.

    Non-numeric text, NaN and infinities return False.
    """
    try:
        f = float(s)
        # round() raises ValueError for NaN and OverflowError for +/-inf.
        return float(int(round(f))) == f
    except (TypeError, ValueError, OverflowError):
        return False


def _to_number(s: str):
    """Return int if int-like, else float, else raise ValueError.

    - Pure integer literals ("42", "-7") are parsed with ``int`` directly so
      values beyond 2**53 are not rounded through ``float``.
    - Other numerals ("3.0", "1e3", "2.5") are parsed with ``float`` and come
      back as ``int`` when they have no fractional part.

    Raises:
        ValueError: if *s* is not a decimal numeral matched by ``_num_re``.
    """
    if not _num_re.match(s):
        raise ValueError(f"Not a number: {s!r}")
    if _int_literal_re.fullmatch(s):
        return int(s)
    v = float(s)
    return int(round(v)) if _is_intlike_str(s) else v

70 

def _parse_list_like(s: str) -> list:
    """Split a comma-separated list into numbers and/or strings.

    The list may optionally be wrapped in ``[...]`` or ``(...)``. Blank items
    are dropped, each remaining item has one pair of enclosing quotes removed,
    and items that parse as numbers are converted by ``_to_number`` (int when
    int-like, else float). Anything else is kept as the unquoted string.
    """
    s = s.strip()
    # Drop one level of enclosing brackets/parentheses, if present.
    if s.startswith(("(", "[")) and s.endswith((")", "]")):
        s = s[1:-1]
    out: list = []
    for part in s.split(","):
        part = part.strip()
        if not part:
            continue
        item = _strip_quotes(part)
        try:
            out.append(_to_number(item))
        except ValueError:
            out.append(item)  # not numeric: keep as a raw (e.g. categorical) string
    return out

85 

def _parse_range_call(s: str) -> tuple:
    """Expand ``range(a, b[, step])`` into a tuple of ints.

    Differences from Python's ``range``:
      - the upper bound is inclusive;
      - the endpoints are sorted, so ``range(5, 1)`` behaves like ``range(1, 5)``;
      - the sign of the step is ignored.
    Endpoints and step go through ``_to_number`` and then ``int`` (so a
    non-integer value is truncated toward zero).

    Raises:
        ValueError: if *s* is not a 2- or 3-argument ``range(...)`` call, an
            argument is not numeric, or the step is zero.
    """
    match = re.fullmatch(r"\s*range\s*\(\s*([^)]*)\s*\)\s*", s)
    if match is None:
        raise ValueError
    tokens = [tok.strip() for tok in match.group(1).split(",")]
    tokens = [tok for tok in tokens if tok != ""]
    if len(tokens) != 2 and len(tokens) != 3:
        raise ValueError

    start, end = (int(_to_number(tok)) for tok in tokens[:2])
    stride = int(_to_number(tokens[2])) if len(tokens) == 3 else 1
    if stride == 0:
        raise ValueError

    low, high = min(start, end), max(start, end)
    return tuple(range(low, high + 1, abs(stride)))

106 

107 

def _parse_colon_or_dots(s: str) -> slice:
    """Parse ``a:b``, ``a..b`` or ``a:b:step`` into ``slice(start, stop, step)``.

    - Endpoints may be int or float; they are swapped if needed so start <= stop.
    - Any run of two or more dots is treated as a ``:`` separator.
    - The optional step may be int or float; an empty step (``a:b:``) gives None.

    Raises:
        ValueError: for the wrong number of parts, non-numeric endpoints, or a
            non-numeric or zero step.
    """
    pieces = [piece.strip() for piece in re.sub(r"\.\.+", ":", s.strip()).split(":")]
    if not 2 <= len(pieces) <= 3:
        raise ValueError(f"Not a range: {s!r}")

    low = _to_number(pieces[0])
    high = _to_number(pieces[1])
    if not (isinstance(low, (int, float)) and isinstance(high, (int, float))):
        raise ValueError(f"Non-numeric range endpoints: {s!r}")
    if low > high:
        low, high = high, low  # normalize to ascending

    if len(pieces) < 3 or pieces[2] == "":
        return slice(low, high, None)

    step = _to_number(pieces[2])
    if not isinstance(step, (int, float)) or step == 0:
        raise ValueError(f"Invalid step in range: {s!r}")
    return slice(low, high, step)

138 

def _parse_constraint_value(text: str):
    """Turn one raw CLI value into a constraint object.

    The forms are tried in this order:
      1. ``range(a, b[, s])``          -> tuple of ints (inclusive upper bound)
      2. bracketed / comma-separated   -> tuple of ints when every item is an int
                                          (an empty list gives ``()``), otherwise
                                          a list of ints, floats and/or strings
      3. ``a:b[:step]`` or ``a..b``    -> slice
      4. a single number               -> int or float
    Anything else is returned as the stripped, unquoted string (for example a
    categorical label such as ``"Linear A"``).
    """
    raw = _strip_quotes(str(text))

    try:
        return _parse_range_call(raw)
    except Exception:
        pass  # not a range(...) call; try the other forms

    is_bracketed = raw.startswith(("[", "(")) and raw.endswith(("]", ")"))
    looks_comma_separated = "," in raw and " " not in raw[:2]
    if is_bracketed or looks_comma_separated:
        items = _parse_list_like(raw)
        return tuple(items) if all(isinstance(item, int) for item in items) else items

    if ":" in raw or ".." in raw:
        try:
            return _parse_colon_or_dots(raw)
        except Exception:
            pass  # not a valid numeric range; fall through to scalar parsing

    try:
        return _to_number(raw)
    except Exception:
        return raw

175 

176 

def _parse_unknown_cli_kv_text(args: list[str]) -> dict[str, str]:
    """
    Extract unknown ``--key value`` pairs as raw strings (no coercion here).

    Supports ``--k=v`` and ``--k v``. A ``--k`` that is the last token, or is
    directly followed by another ``--option``, is a flag and gets ``"true"``;
    the following option is still parsed normally. Tokens that neither start
    with ``--`` nor are consumed as a value are ignored. Dashes in keys become
    underscores, keys that end up empty (bare ``--``) are skipped, and for
    repeated keys the last one wins.
    """
    out: dict[str, str] = {}
    i = 0
    n = len(args)
    while i < n:
        tok = args[i]
        i += 1
        if not tok.startswith("--"):
            continue
        key = tok[2:]
        if "=" in key:
            k, v = key.split("=", 1)
        elif i < n and not args[i].startswith("--"):
            k, v = key, args[i]
            i += 1  # consume the value token
        else:
            # Flag without a value. Do NOT consume the next token, so a
            # following --option is handled on the next iteration.
            k, v = key, "true"
        k = k.strip().replace("-", "_")
        if k:
            out[k] = v
    return out

203 

def _norm(s: str) -> str:
    """Lower-case *s*, then keep only the characters a-z and 0-9."""
    lowered = s.lower()
    return re.sub(r"[^a-z0-9]+", "", lowered)

206 

207 

def _canonicalize_feature_keys(
    model: xr.Dataset | Path, raw_map: dict[str, object]
) -> tuple[dict[str, object], dict[str, str]]:
    """
    Map user keys (any style) to either:
      - dataset *feature* names (numeric, including one-hot *member* names), OR
      - *categorical base* names (e.g., 'language') detected from one-hot blocks.

    Lookup order for each key:
      1. exact feature name;
      2. ``_norm``-normalized feature name;
      3. exact or ``_norm``-normalized categorical base (so
         '--language "Linear A"' is preserved as 'language': 'Linear A').

    Args:
        model: A loaded dataset, or a path opened with ``xr.load_dataset``.
        raw_map: User-supplied ``{key: value}``; ``None``/empty gives empty results.

    Returns:
        ``(mapped, alias)``: ``mapped`` is ``{canonical_key: value}`` and
        ``alias`` is ``{user_key: canonical_key}``.

    Raises:
        typer.BadParameter: if a key matches neither a feature nor a
            categorical base (unmatched keys are NOT silently dropped).
    """
    ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
    features = [str(x) for x in ds["feature"].values.tolist()]

    # Indexes for feature names and categorical bases
    feature_set = set(features)
    feature_norm_index = {_norm(f): f for f in features}
    bases, base_norm_index = _categorical_bases_from_features(features)

    mapped: dict[str, object] = {}
    alias: dict[str, str] = {}

    for k, v in (raw_map or {}).items():
        nk = _norm(k)
        if k in feature_set:
            canonical = k
        elif nk in feature_norm_index:
            canonical = feature_norm_index[nk]
        elif k in bases:
            canonical = k
        elif nk in base_norm_index:
            canonical = base_norm_index[nk]
        else:
            raise typer.BadParameter(f"Unknown feature key: {k}")
        mapped[canonical] = v
        alias[k] = canonical

    return mapped, alias

260 

261 

262 

def parse_constraints_from_ctx(ctx: typer.Context, model: xr.Dataset | Path) -> dict[str, object]:
    """
    Turn the unknown ``--key value`` arguments in ``ctx.args`` into constraints.

    Each value is parsed with ``_parse_constraint_value`` and each key is mapped
    to a feature name or categorical base by ``_canonicalize_feature_keys``
    (which raises ``typer.BadParameter`` for unknown keys). A non-empty result
    is echoed to the console before being returned.

    Values can be:
      - number (int/float)       -> fixed
      - slice(lo, hi[, step])    -> range
      - list/tuple               -> finite choices
      - tuple from range(...)    -> int choices
      - string                   -> categorical label (e.g., --language "Linear A")
    """
    parsed: dict[str, object] = {
        key: _parse_constraint_value(text)
        for key, text in _parse_unknown_cli_kv_text(ctx.args).items()
    }
    constraints, _alias = _canonicalize_feature_keys(model, parsed)

    # Defensive: turn any range objects into tuples of ints (choices).
    for key in list(constraints):
        value = constraints[key]
        if isinstance(value, range):
            constraints[key] = tuple(value)

    if not constraints:
        return constraints

    def _fmt_value(val: object) -> str:
        if isinstance(val, slice):
            # preview shows only start/stop; any step is omitted
            lo, hi = val.start, val.stop
            return f"[{lo},{hi}]"
        if isinstance(val, (list, tuple, range)):
            return f"{tuple(val)}"
        if isinstance(val, str):
            # quote labels containing whitespace or separator characters
            return f'"{val}"' if re.search(r'\s|[,=:\.]', val) else val
        return str(val)

    pretty = ", ".join(f"{k}={_fmt_value(v)}" for k, v in constraints.items())
    console.print(f"[cyan]Constraints:[/] {pretty}")
    return constraints

306 

307 

def _version_callback(value: bool) -> None:
    """Print the version and exit when ``--version`` is given.

    This runs as an eager Click parameter callback at parse time, so it also
    works when no subcommand follows. A check inside the group callback body
    would never run in that case, and Click would fail with "Missing command.".
    """
    if value:
        console.print(f"[bold]psyop[/] {__version__}")
        raise typer.Exit()


@app.callback()
def main(
    version: bool = typer.Option(
        False,
        "--version",
        "-v",
        help="Show version and exit.",
        is_eager=True,
        callback=_version_callback,
    ),
):
    # Group-level callback: --version is fully handled by _version_callback.
    pass

315 

316 

@app.command(help="Fit the model on a CSV and save a single model artifact.")
def model(
    input: Path = typer.Argument(..., help="Input CSV file."),
    output: Path = typer.Argument(..., help="Path to save model artifact [.psyop]."),
    target: str = typer.Option("loss", "--target", "-t", help="Target column name."),
    exclude: list[str] = typer.Option([], help="Feature columns to exclude."),
    direction: Direction = typer.Option(
        Direction.MINIMIZE,
        "--direction",
        "-d",
        help="Optimization direction for the target.",
    ),
    seed: int = typer.Option(0, "--seed", help="Random seed for fitting/sampling."),
    compress: bool = typer.Option(True, help="Apply compression inside the artifact."),
    prior_model: Path | None = typer.Option(None, help="Existing model artifact used to warm-start parameter optimization."),
):
    # Validate the input paths before calling build_model.
    if not input.exists():
        raise typer.BadParameter(f"Input CSV not found: {input.resolve()}")
    if input.suffix.lower() != ".csv":
        # Only a warning: the file is still passed on to build_model.
        console.print(":warning: [yellow]Input does not end with .csv[/]")
    if prior_model is not None and not prior_model.exists():
        raise typer.BadParameter(f"Prior model artifact not found: {prior_model.resolve()}")

    build_model(
        input=input,
        output=output,
        target=target,
        exclude=exclude,
        direction=direction.value,  # plain "min"/"max" string, not the enum
        seed=seed,
        compress=compress,
        prior_model=prior_model,
    )
    console.print(f"[green]Wrote model artifact →[/] {output}")

349 

350 

@app.command(
    help="Suggest BO candidates (constrained EI + exploration).",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def suggest(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Optional[Path] = typer.Option(None, "--output", "-o", help="Where to save candidates CSV (defaults relative to model)."),
    count: int = typer.Option(1, "--count", "-k", help="Number of candidates to propose."),
    success_threshold: float = typer.Option(0.8, help="Feasibility threshold for constrained EI."),
    explore: float = typer.Option(0.34, help="Fraction of suggestions reserved for exploration."),
    seed: int = typer.Option(0, help="Random seed for proposals."),
):
    # Extra --key value arguments (ctx.args) become feature constraints.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    model_ds = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model_ds)

    opt.suggest(
        model=model_ds,
        output=output,
        count=count,
        success_threshold=success_threshold,
        explore=explore,
        seed=seed,
        **constraints,
    )
    if output:
        console.print(f"[green]Wrote proposals →[/] {output}")

381 

382 

@app.command(
    help="Rank points by probability of being the best feasible minimum.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def optimal(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path | None = typer.Option(None, help="Where to save top candidates CSV (defaults relative to model)."),
    seed: int = typer.Option(0, help="Random seed for MC."),
):
    # Extra --key value arguments (ctx.args) become feature constraints.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    model_ds = xr.load_dataset(model)
    opt.optimal(
        model=model_ds,
        output=output,
        seed=seed,
        **parse_constraints_from_ctx(ctx, model_ds),
    )
    if output:
        console.print(f"[green]Wrote top probable minima →[/] {output}")

407 

408 

@app.command(
    help="Create a 2D Partial Dependence of Expected Target (Pairwise Features).",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def plot2d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    grid_size: int = typer.Option(70, help="Grid size per axis for 2D panels."),
    use_log_scale_for_target: bool = typer.Option(False, help="Log10 colors for target."),
    log_shift_epsilon: float = typer.Option(1e-9, help="Epsilon shift for log colors."),
    colorscale: str = typer.Option("RdBu", help="Colorscale name."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    n_contours: int = typer.Option(12, help="Number of contour levels."),
    optimal: bool = typer.Option(True, help="Include optimal points."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int = typer.Option(1000, help="Width of each panel in pixels."),
    height: int = typer.Option(1000, help="Height of each panel in pixels."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    model_ds = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model_ds)

    if show is None:
        # Default: open the figure only when no output file was requested.
        show = output is None

    viz.plot2d(
        model=model_ds,
        output=output,
        grid_size=grid_size,
        use_log_scale_for_target=use_log_scale_for_target,
        log_shift_epsilon=log_shift_epsilon,
        colorscale=colorscale,
        show=show,
        n_contours=n_contours,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )
    if output:
        console.print(f"[green]Wrote pairplot →[/] {output}")

455 

456 

@app.command(
    help="Create 1D Partial Dependence panels.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def plot1d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    csv_out: Path | None = typer.Option(None, help="Optional CSV export of tidy PD data."),
    grid_size: int = typer.Option(300, help="Points along 1D sweep."),
    line_color: str = typer.Option("blue", help="Line/band color (consistent across variables)."),
    band_alpha: float = typer.Option(0.25, help="Fill alpha for ±2σ."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    use_log_scale_for_target_y: bool = typer.Option(True, "--log-y/--no-log-y", help="Log scale for target (Y)."),
    log_y_epsilon: float = typer.Option(1e-9, "--log-y-eps", help="Clamp for log-Y."),
    optimal: bool = typer.Option(True, help="Include optimal points."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int = typer.Option(1000, help="Width of each panel in pixels."),
    height: int = typer.Option(1000, help="Height of each panel in pixels."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    if show is None:
        # Default: open the figure only when no output file was requested.
        show = output is None

    model_ds = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model_ds)

    viz.plot1d(
        model=model_ds,
        output=output,
        csv_out=csv_out,
        grid_size=grid_size,
        line_color=line_color,
        band_alpha=band_alpha,
        show=show,
        use_log_scale_for_target_y=use_log_scale_for_target_y,
        log_y_epsilon=log_y_epsilon,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )
    if output:
        console.print(f"[green]Wrote PD HTML →[/] {output}")
    if csv_out:
        console.print(f"[green]Wrote PD CSV →[/] {csv_out}")

507 

508 

@app.command(
    # Distinct from plot1d's help so `--help` tells the two commands apart;
    # wording mirrors optimum_plot2d.
    help="Create 1D Partial Dependence panels anchored at the optimum.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def optimum_plot1d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    csv_out: Path | None = typer.Option(None, help="Optional CSV export of tidy PD data."),
    grid_size: int = typer.Option(300, help="Points along 1D sweep."),
    line_color: str = typer.Option("blue", help="Line/band color (consistent across variables)."),
    band_alpha: float = typer.Option(0.25, help="Fill alpha for ±2σ."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    use_log_scale_for_target_y: bool = typer.Option(True, "--log-y/--no-log-y", help="Log scale for target (Y)."),
    log_y_epsilon: float = typer.Option(1e-9, "--log-y-eps", help="Clamp for log-Y."),
    optimal: bool = typer.Option(True, help="Include optimal points."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int = typer.Option(1000, help="Width of each panel in pixels."),
    height: int = typer.Option(1000, help="Height of each panel in pixels."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    show = show if show is not None else output is None  # default to True if no output file

    # Keep the Path argument and the loaded dataset in separate names.
    model_ds = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model_ds)

    viz.optimum_plot1d(
        model=model_ds,
        output=output,
        csv_out=csv_out,
        grid_size=grid_size,
        line_color=line_color,
        band_alpha=band_alpha,
        show=show,
        use_log_scale_for_target_y=use_log_scale_for_target_y,
        log_y_epsilon=log_y_epsilon,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )
    if output:
        console.print(f"[green]Wrote PD HTML →[/] {output}")
    if csv_out:
        console.print(f"[green]Wrote PD CSV →[/] {csv_out}")

559 

560 

@app.command(
    help="Create 2D Partial Dependence panels anchored at the optimum.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def optimum_plot2d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    grid_size: int = typer.Option(70, help="Grid size per axis for 2D panels."),
    use_log_scale_for_target: bool = typer.Option(False, help="Log10 colors for target."),
    log_shift_epsilon: float = typer.Option(1e-9, help="Epsilon shift for log colors."),
    colorscale: str = typer.Option("RdBu", help="Colorscale name."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    n_contours: int = typer.Option(12, help="Number of contour levels."),
    optimal: bool = typer.Option(True, help="Show optimal point overlay."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int = typer.Option(1000, help="Width of each panel in pixels."),
    height: int = typer.Option(1000, help="Height of each panel in pixels."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    if show is None:
        # Default: open the figure only when no output file was requested.
        show = output is None

    model_ds = xr.load_dataset(model)
    viz.optimum_plot2d(
        model=model_ds,
        output=output,
        grid_size=grid_size,
        use_log_scale_for_target=use_log_scale_for_target,
        log_shift_epsilon=log_shift_epsilon,
        colorscale=colorscale,
        show=show,
        n_contours=n_contours,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **parse_constraints_from_ctx(ctx, model_ds),
    )
    if output:
        console.print(f"[green]Wrote PD HTML →[/] {output}")

607 

608 

@app.command(
    help="Export CSV of data used to create the model.",
)
def export(
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path = typer.Argument(..., help="Output CSV of data used to create model."),
):
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")
    model_ds = xr.load_dataset(model)

    # Keep per-row variables only; anything that also has a "feature" dimension is skipped.
    vars_row = [
        name for name, da in model_ds.data_vars.items()
        if "row" in da.dims and "feature" not in da.dims
    ]
    df = model_ds[vars_row].to_dataframe().reset_index()
    # Drop the row index plus columns presumably derived during fitting;
    # any that are absent are ignored.
    df = df.drop(
        columns=['row', 'y_success', 'success_mask', 'pred_success_mu_train', 'pred_success_var_train'],
        errors='ignore',
    )
    df.to_csv(output, index=False)
    # Confirm the write, consistent with the other commands.
    console.print(f"[green]Wrote data CSV →[/] {output}")

626 

627 

def _categorical_bases_from_features(features: list[str]) -> tuple[set[str], dict[str, str]]:
    """
    Collect categorical base names from one-hot member feature names.

    A member looks like ``'language=Linear A'``; its base is the stripped text
    before the first ``=`` (``'language'``). Features without ``=`` and
    members with an empty base are ignored.

    Returns:
        ``(bases, base_norm_index)``: the set of base names and a mapping from
        ``_norm(base)`` to the base name.
    """
    heads = (name.split("=", 1)[0].strip() for name in features if "=" in name)
    bases: set[str] = {head for head in heads if head}
    return bases, {_norm(b): b for b in bases}

644 

645 

# Allow running this module directly as a script.
if __name__ == "__main__":
    app()