Coverage for psyop/main.py: 67.65%

238 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-10 06:02 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4# Make BLAS single-threaded to avoid oversubscription / macOS crashes 

5import os 

# Pin each BLAS/OpenMP-style thread-count variable to "1" unless the user has
# already exported a value; existing settings are left untouched.
for _env_var in (
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "OMP_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    if _env_var not in os.environ:
        os.environ[_env_var] = "1"

14 

15import re 

16from enum import Enum 

17from pathlib import Path 

18from typing import Optional 

19import xarray as xr 

20import typer 

21from rich.console import Console 

22 

23from .model import build_model 

24from . import viz, opt 

25 

# Version string printed by the top-level --version / -v flag.
__version__ = "0.1.0"

# Shared Rich console used for all user-facing messages in this module.
console = Console()

# Root Typer application that the commands below register themselves on.
app = typer.Typer(
    no_args_is_help=True,
    add_completion=False,
    rich_markup_mode="rich",
)

30 

31 

class Direction(str, Enum):
    """Optimization direction choices accepted by the ``--direction`` option.

    Members are also ``str`` instances, so their values ("min", "max",
    "auto") can be used directly wherever a plain string is expected.
    """

    # Lower target values are better.
    MINIMIZE = "min"
    # Higher target values are better.
    MAXIMIZE = "max"
    # Passed through as "auto"; how it is resolved is decided downstream.
    AUTO = "auto"

36 

37 

38def _strip_quotes(s: str) -> str: 

39 s = s.strip() 

40 if (s.startswith("'") and s.endswith("'")) or (s.startswith('"') and s.endswith('"')): 

41 return s[1:-1] 

42 return s 

43 

44_num_re = re.compile( 

45 r""" 

46 ^\s* 

47 ([+-]? # sign 

48 (?: 

49 (?:\d+(?:\.\d*)?|\.\d+) # 123, 123., .123, 123.456 

50 (?:[eE][+-]?\d+)? # optional exponent 

51 ) 

52 ) 

53 \s*$ 

54 """, 

55 re.VERBOSE, 

56) 

57 

58def _is_intlike_str(s: str) -> bool: 

59 try: 

60 f = float(s) 

61 return float(int(round(f))) == f 

62 except Exception: 

63 return False 

64 

def _to_number(s: str):
    """Convert a numeric literal to ``int`` when whole-valued, else ``float``.

    Raises:
        ValueError: if *s* does not match the module's numeric-literal pattern.
    """
    if _num_re.match(s) is None:
        raise ValueError
    value = float(s)
    if _is_intlike_str(s):
        # Whole-valued input such as "3" or "3.0" collapses to an int.
        return int(round(value))
    return value

71 

def _parse_list_like(s: str) -> list:
    """Split comma-separated text into a list of numbers and strings.

    One pair of outer brackets (``[]`` or ``()``, not required to match each
    other) is removed first. Empty items are skipped. Each item has its
    quotes stripped and becomes a number when it parses as one; otherwise it
    stays a string.
    """
    body = s.strip()
    if body.startswith(("(", "[")) and body.endswith((")", "]")):
        body = body[1:-1]

    def _coerce(token: str):
        # Numeric when possible, otherwise keep the raw string.
        try:
            return _to_number(token)
        except Exception:
            return token

    return [
        _coerce(_strip_quotes(piece.strip()))
        for piece in body.split(",")
        if piece.strip() != ""
    ]

86 

def _parse_range_call(s: str) -> tuple:
    """Expand ``'range(a, b[, step])'`` text into a tuple of ints.

    The larger bound is included in the result. The bounds are sorted and only
    the magnitude of ``step`` is used, so the result is always ascending.
    Arguments are passed through ``int()``, which truncates non-integer values.

    Raises:
        ValueError: if the text is not a ``range(...)`` call, has a number of
            arguments other than two or three, or has a zero step.
    """
    match = re.fullmatch(r"\s*range\s*\(\s*([^)]*)\s*\)\s*", s)
    if match is None:
        raise ValueError

    tokens = [tok.strip() for tok in match.group(1).split(",") if tok.strip() != ""]
    if len(tokens) not in (2, 3):
        raise ValueError

    numbers = [int(_to_number(tok)) for tok in tokens]
    first, second = numbers[0], numbers[1]
    step = numbers[2] if len(numbers) == 3 else 1
    if step == 0:
        raise ValueError

    low, high = min(first, second), max(first, second)
    # +1 makes the upper bound inclusive.
    return tuple(range(low, high + 1, abs(step)))

107 

108 

def _parse_colon_or_dots(s: str) -> slice:
    """Turn ``'a:b'``, ``'a..b'`` or ``'a:b:step'`` into ``slice(start, stop, step)``.

    - Runs of two or more dots are treated as a ``:`` separator.
    - Endpoints may be int or float and are swapped if needed so start <= stop.
    - The step is optional (``None`` when missing or empty) and must be non-zero.

    Raises:
        ValueError: if the text does not split into two or three parts, an
            endpoint or step is not numeric, or the step is zero.
    """
    normalized = re.sub(r"\.\.+", ":", s.strip())
    pieces = [piece.strip() for piece in normalized.split(":")]
    if not 2 <= len(pieces) <= 3:
        raise ValueError(f"Not a range: {s!r}")

    low = _to_number(pieces[0])
    high = _to_number(pieces[1])
    if not (isinstance(low, (int, float)) and isinstance(high, (int, float))):
        raise ValueError(f"Non-numeric range endpoints: {s!r}")
    if low > high:
        # Normalize to ascending order.
        low, high = high, low

    step = None
    has_step = len(pieces) == 3 and pieces[2] != ""
    if has_step:
        step = _to_number(pieces[2])
        if not isinstance(step, (int, float)) or step == 0:
            raise ValueError(f"Invalid step in range: {s!r}")

    return slice(low, high, step)

139 

def _parse_constraint_value(text: str):
    """Interpret one CLI value string as a constraint object.

    Parsers are tried in this order and the first that applies wins:
      1. ``range(a,b[,s])``        -> tuple of ints
      2. bracketed or comma list   -> tuple if every item is an int,
                                      otherwise a list (items may be numbers
                                      or strings)
      3. ``a:b[:step]`` / ``a..b`` -> slice
      4. plain number              -> int or float
    Anything else is returned as the quote-stripped string.
    """
    raw = _strip_quotes(str(text))

    # 1) range(...) call syntax.
    try:
        return _parse_range_call(raw)
    except Exception:
        pass

    # 2) List syntax: wrapped in brackets, or containing a comma with no
    #    space among the first two characters.
    is_bracketed = raw.startswith(("[", "(")) and raw.endswith(("]", ")"))
    is_bare_csv = "," in raw and " " not in raw[:2]
    if is_bracketed or is_bare_csv:
        items = _parse_list_like(raw)
        only_ints = all(isinstance(item, int) for item in items)
        return tuple(items) if only_ints else items

    # 3) Colon / dotted range syntax; on failure fall through to scalar parsing.
    if ":" in raw or ".." in raw:
        try:
            return _parse_colon_or_dots(raw)
        except Exception:
            pass

    # 4) Scalar number, else the raw string.
    try:
        return _to_number(raw)
    except Exception:
        return raw

176 

177 

178def _parse_unknown_cli_kv_text(args: list[str]) -> dict[str, str]: 

179 """ 

180 Extract unknown --key value pairs as raw strings (no coercion here). 

181 Supports: --k=v and --k v. Repeated keys -> last wins. 

182 """ 

183 out: dict[str, str] = {} 

184 it = iter(args) 

185 for tok in it: 

186 if not tok.startswith("--"): 

187 continue 

188 key = tok[2:] 

189 if "=" in key: 

190 k, v = key.split("=", 1) 

191 else: 

192 k = key 

193 try: 

194 nxt = next(it) 

195 except StopIteration: 

196 nxt = "true" 

197 if nxt.startswith("--"): 

198 # treat as flag without value; put it back by ignoring and storing "true" 

199 v = "true" 

200 else: 

201 v = nxt 

202 out[k.strip().replace("-", "_")] = v 

203 return out 

204 

205def _norm(s: str) -> str: 

206 return re.sub(r"[^a-z0-9]+", "", s.lower()) 

207 

208 

def _canonicalize_feature_keys(
    model: xr.Dataset | Path, raw_map: dict[str, object]
) -> tuple[dict[str, object], dict[str, str]]:
    """
    Resolve user-supplied keys (any spelling style) to canonical names.

    A key resolves to either:
      - a dataset *feature* name (including one-hot *member* names), or
      - a *categorical base* name (e.g. 'language') derived from feature
        names that contain '='.

    Resolution order, first hit wins:
      1. exact feature name
      2. normalized feature name
      3. categorical base, exact then normalized
         (so '--language "Linear A"' is kept as 'language': 'Linear A')

    Returns (mapped, alias): ``mapped`` is canonical name -> value and
    ``alias`` is user key -> canonical name.

    Raises:
        typer.BadParameter: if a key matches nothing.
    """
    if isinstance(model, xr.Dataset):
        ds = model
    else:
        ds = xr.load_dataset(model)
    features = [str(name) for name in ds["feature"].values.tolist()]

    # Lookup tables for feature names and categorical bases.
    known_features = set(features)
    feature_by_norm = {_norm(name): name for name in features}
    bases, base_by_norm = _categorical_bases_from_features(features)

    mapped: dict[str, object] = {}
    alias: dict[str, str] = {}

    for user_key, value in (raw_map or {}).items():
        if user_key in known_features:
            canonical = user_key
        else:
            normalized = _norm(user_key)
            if normalized in feature_by_norm:
                canonical = feature_by_norm[normalized]
            elif user_key in bases:
                canonical = user_key
            elif normalized in base_by_norm:
                canonical = base_by_norm[normalized]
            else:
                raise typer.BadParameter(f"Unknown feature key: {user_key}")

        mapped[canonical] = value
        alias[user_key] = canonical

    return mapped, alias

261 

262 

263 

def parse_constraints_from_ctx(ctx: typer.Context, model: xr.Dataset | Path) -> dict[str, object]:
    """
    Build the constraint mapping from the extra CLI arguments in ``ctx.args``.

    Each value ends up as one of:
      - number (int/float)       -> fixed
      - slice(lo, hi[, step])    -> range
      - list/tuple               -> finite choices
      - tuple from range(...)    -> int choices
      - string                   -> categorical label (e.g., --language "Linear A")

    Keys are canonicalized against the model's features. A one-line summary of
    any constraints is printed to the console before returning.
    """
    raw_kv = _parse_unknown_cli_kv_text(ctx.args)
    parsed: dict[str, object] = {}
    for key, text in raw_kv.items():
        parsed[key] = _parse_constraint_value(text)

    # Map keys onto feature names or categorical bases.
    constraints, _alias = _canonicalize_feature_keys(model, parsed)

    # Any range object is materialized as a tuple of int choices.
    for name in list(constraints):
        value = constraints[name]
        if isinstance(value, range):
            constraints[name] = tuple(value)

    if not constraints:
        return constraints

    def _render(val: object) -> str:
        # Slices are shown as [start,stop]; the step is left out of the preview.
        if isinstance(val, slice):
            return f"[{val.start},{val.stop}]"
        if isinstance(val, (list, tuple, range)):
            return f"{tuple(val)}"
        if isinstance(val, str):
            # Quote strings containing whitespace or , = : . characters.
            needs_quotes = re.search(r'\s|[,=:\.]', val)
            return f'"{val}"' if needs_quotes else val
        return str(val)

    pretty = ", ".join(f"{name}={_render(value)}" for name, value in constraints.items())
    console.print(f"[cyan]Constraints:[/] {pretty}")

    return constraints

307 

308 

@app.callback()
def main(
    version: bool = typer.Option(False, "--version", "-v", help="Show version and exit.", is_eager=True),
):
    # App-level callback: handles the global --version / -v flag.
    # Kept docstring-free on purpose: Typer would surface a callback docstring
    # as the application's help text.
    if not version:
        return
    console.print(f"[bold]psyop[/] {__version__}")
    raise typer.Exit()

316 

317 

@app.command(help="Fit the model on a CSV and save a single model artifact.")
def model(
    input: Path = typer.Argument(..., help="Input CSV file."),
    output: Path = typer.Argument(..., help="Path to save model artifact (.nc)."),
    target: str = typer.Option("loss", "--target", "-t", help="Target column name."),
    exclude: list[str] = typer.Option([], "--exclude", help="Feature columns to exclude."),
    direction: Direction = typer.Option(
        Direction.AUTO, "--direction", "-d",
        help="Optimization direction for the target."
    ),
    seed: int = typer.Option(0, "--seed", help="Random seed for fitting/sampling."),
    compress: bool = typer.Option(True, help="Apply compression inside the artifact."),
):
    # Validate the input path, then delegate fitting and saving to build_model.
    if not input.exists():
        raise typer.BadParameter(f"Input CSV not found: {input.resolve()}")

    # A non-.csv suffix only triggers a warning; fitting still proceeds.
    looks_like_csv = input.suffix.lower() == ".csv"
    if not looks_like_csv:
        console.print(":warning: [yellow]Input does not end with .csv[/]")

    fit_kwargs = dict(
        input=input,
        target=target,
        output=output,
        exclude=exclude,
        # build_model receives the plain string value, not the enum member.
        direction=direction.value,
        seed=seed,
        compress=compress,
    )
    build_model(**fit_kwargs)
    console.print(f"[green]Wrote model artifact →[/] {output}")

346 

347 

@app.command(
    help="Suggest BO candidates (constrained EI + exploration).",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def suggest(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact (.nc)."),
    output: Optional[Path] = typer.Option(None, "--output", "-o", help="Where to save candidates CSV (defaults relative to model)."),
    count: int = typer.Option(1, "--count", "-k", help="Number of candidates to propose."),
    success_threshold: float = typer.Option(0.8, help="Feasibility threshold for constrained EI."),
    explore: float = typer.Option(0.34, help="Fraction of suggestions reserved for exploration."),
    # candidates: int = typer.Option(5000, help="Random candidate pool size to score."),
    seed: int = typer.Option(0, "--seed", help="Random seed for proposals."),
):
    # Load the artifact, read feature constraints from the extra CLI options
    # (unknown --key value pairs), and delegate to opt.suggest.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    dataset = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, dataset)

    opt.suggest(
        model=dataset,
        output=output,
        count=count,
        success_threshold=success_threshold,
        explore=explore,
        # candidates=candidates,
        seed=seed,
        **constraints,
    )

    # Only announce a file when an explicit output path was given.
    if output:
        console.print(f"[green]Wrote proposals →[/] {output}")

380 

381 

@app.command(
    help="Rank points by probability of being the best feasible minimum.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def optimal(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact (.nc)."),
    output: Path | None = typer.Option(None, help="Where to save top candidates CSV (defaults relative to model)."),
    count: int = typer.Option(10, "--count", "-k", help="How many top rows to keep."),
    draws: int = typer.Option(2000, "--draws", help="Monte Carlo draws."),
    min_success_probability: float = typer.Option(0.0, "--min-p-success", help="Hard feasibility cutoff (0 disables)."),
    seed: int = typer.Option(0, "--seed", help="Random seed for MC."),
):
    # Load the artifact, read feature constraints from the extra CLI options,
    # and delegate to opt.optimal.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    dataset = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, dataset)

    # NOTE(review): --count, --draws and --min-p-success are accepted on the
    # command line but are not forwarded (the keyword arguments below are
    # commented out), so they currently have no effect. Confirm against
    # opt.optimal's signature before re-enabling them.
    opt.optimal(
        model=dataset,
        output=output,
        # count=count,
        # n_draws=draws,
        # min_success_probability=min_success_probability,
        seed=seed,
        **constraints,
    )

    # Only announce a file when an explicit output path was given.
    if output:
        console.print(f"[green]Wrote top probable minima →[/] {output}")

412 

413 

@app.command(
    help="Create a 2D Partial Dependence of Expected Target (Pairwise Features).",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def plot2d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact (.nc)."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    grid_size: int = typer.Option(70, help="Grid size per axis for 2D panels."),
    use_log_scale_for_target: bool = typer.Option(False, help="Log10 colors for target."),
    log_shift_epsilon: float = typer.Option(1e-9, help="Epsilon shift for log colors."),
    colorscale: str = typer.Option("RdBu", help="Colorscale name."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    n_contours: int = typer.Option(12, help="Number of contour levels."),
    optimal: bool = typer.Option(True, help="Include optimal points."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int | None = typer.Option(None, help="Width of each panel in pixels (default auto)."),
    height: int | None = typer.Option(None, help="Height of each panel in pixels (default auto)."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    # Load the artifact, read feature constraints from the extra CLI options,
    # and delegate rendering to viz.plot2d.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    dataset = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, dataset)

    # When --show/--no-show is not given, show only if no output file was given.
    if show is None:
        show = output is None

    viz.plot2d(
        model=dataset,
        output=output,
        grid_size=grid_size,
        use_log_scale_for_target=use_log_scale_for_target,
        log_shift_epsilon=log_shift_epsilon,
        colorscale=colorscale,
        show=show,
        n_contours=n_contours,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )

    if output:
        console.print(f"[green]Wrote pairplot →[/] {output}")

460 

461 

@app.command(
    help="Create 1D Partial Dependence panels.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def plot1d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact (.nc)."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    csv_out: Path | None = typer.Option(None, help="Optional CSV export of tidy PD data."),
    grid_size: int = typer.Option(300, help="Points along 1D sweep."),
    line_color: str = typer.Option("rgb(31,119,180)", help="Line/band color (consistent across variables)."),
    band_alpha: float = typer.Option(0.25, help="Fill alpha for ±2σ."),
    figure_height_per_row_px: int = typer.Option(320, help="Pixels per PD row."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    use_log_scale_for_target_y: bool = typer.Option(True, "--log-y/--no-log-y", help="Log scale for target (Y)."),
    log_y_epsilon: float = typer.Option(1e-9, "--log-y-eps", help="Clamp for log-Y."),
    optimal: bool = typer.Option(True, help="Include optimal points."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int | None = typer.Option(None, help="Width of each panel in pixels (default auto)."),
    height: int | None = typer.Option(None, help="Height of each panel in pixels (default auto)."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    # Load the artifact, read feature constraints from the extra CLI options,
    # and delegate rendering to viz.plot1d.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    # When --show/--no-show is not given, show only if no output file was given.
    if show is None:
        show = output is None

    dataset = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, dataset)

    viz.plot1d(
        model=dataset,
        output=output,
        csv_out=csv_out,
        grid_size=grid_size,
        line_color=line_color,
        band_alpha=band_alpha,
        figure_height_per_row_px=figure_height_per_row_px,
        show=show,
        use_log_scale_for_target_y=use_log_scale_for_target_y,
        log_y_epsilon=log_y_epsilon,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )

    # Report each artifact whose path was explicitly given.
    if output:
        console.print(f"[green]Wrote PD HTML →[/] {output}")
    if csv_out:
        console.print(f"[green]Wrote PD CSV →[/] {csv_out}")

514 

515 

def _categorical_bases_from_features(features: list[str]) -> tuple[set[str], dict[str, str]]:
    """
    Derive categorical base names from feature names of the form 'base=member'
    (for example 'language=Linear A' gives the base 'language').

    Returns:
      - bases: the set of non-empty, whitespace-stripped texts found before
        the first '=' in each feature name that contains one
      - base_norm_index: normalized base name -> base string, for lookup
    """
    split_names = (name.partition("=") for name in features)
    bases: set[str] = {
        head.strip()
        for head, separator, _member in split_names
        if separator and head.strip()
    }
    base_norm_index = {_norm(base): base for base in bases}
    return bases, base_norm_index

532 

533 

# Run the Typer application when this file is executed as a script.
if __name__ == "__main__":
    app()