Coverage for psyop/main.py: 64.07%
270 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-10-29 03:44 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-10-29 03:44 +0000
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4# Make BLAS single-threaded to avoid oversubscription / macOS crashes
5import os
# Pin each BLAS / OpenMP / numexpr backend to one thread. setdefault() never
# overrides a value the caller already exported, so users can still opt in
# to more threads from their shell.
for _env_var in ("MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "OMP_NUM_THREADS",
                 "VECLIB_MAXIMUM_THREADS", "NUMEXPR_NUM_THREADS"):
    os.environ.setdefault(_env_var, "1")
15import re
16from enum import Enum
17from pathlib import Path
18from typing import Optional
19import xarray as xr
20import typer
21from rich.console import Console
23from .model import build_model
24from . import viz, opt
__version__ = "0.1.0"

# Shared Rich console for all user-facing output from this module.
console = Console()

# Root Typer application; every command below registers itself on it.
app = typer.Typer(
    no_args_is_help=True,  # invoking with no arguments prints help
    add_completion=False,
    rich_markup_mode="rich",
)
class Direction(str, Enum):
    """Optimization direction for the target, as accepted by ``--direction``.

    Members also subclass ``str``, so each compares equal to its CLI token.
    """

    MINIMIZE = "min"  # smaller target values are better
    MAXIMIZE = "max"  # larger target values are better
def _strip_quotes(s: str) -> str:
    """Trim whitespace and remove one pair of matching surrounding quotes.

    ``"'abc'"`` and ``'"abc"'`` both become ``abc``. Mismatched quotes and a
    lone quote character are returned unchanged (apart from the trimming).
    """
    s = s.strip()
    # Require at least two characters so a single quote character is not
    # treated as both the opening and closing quote (collapsing it to "").
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ("'", '"'):
        return s[1:-1]
    return s
# Strict decimal-number syntax used by _to_number: optional sign, digits with
# an optional fractional part (or a leading-dot fraction), and an optional
# exponent, with surrounding whitespace allowed. Words such as "inf"/"nan"
# and digit separators ("1_000", "1,000") do not match.
_num_re = re.compile(
    r"""
    ^\s*
    ([+-]?                              # sign
        (?:
            (?:\d+(?:\.\d*)?|\.\d+)     # 123, 123., .123, 123.456
            (?:[eE][+-]?\d+)?           # optional exponent
        )
    )
    \s*$
    """,
    re.VERBOSE,
)
def _is_intlike_str(s: str) -> bool:
    """Return True when *s* parses as a finite float with no fractional part.

    ``"3"``, ``"3.0"`` and ``"1e3"`` are int-like; ``"3.5"``, ``"inf"``,
    ``"nan"`` and non-numeric text are not.
    """
    try:
        value = float(s)
    except (TypeError, ValueError):
        return False
    # float.is_integer() is already False for inf and nan, so there is no
    # need for the round/int round-trip (and its Overflow/ValueError cases).
    return value.is_integer()
def _to_number(s: str):
    """Convert a strict decimal literal to ``int`` (if integral) or ``float``.

    Raises:
        ValueError: when *s* does not match ``_num_re``.
    """
    if _num_re.match(s) is None:
        raise ValueError
    parsed = float(s)
    if _is_intlike_str(s):
        return int(round(parsed))
    return parsed
def _parse_list_like(s: str) -> list:
    """Split a comma-separated value into a list of numbers and/or strings.

    One enclosing pair of ``[]`` / ``()`` is optional and removed. Empty items
    are skipped, each item has surrounding quotes stripped, and items that
    parse with ``_to_number`` become int/float; the rest stay strings.
    """
    text = s.strip()
    if text.startswith(("(", "[")) and text.endswith((")", "]")):
        text = text[1:-1]

    def _coerce(token: str):
        token = _strip_quotes(token)
        try:
            return _to_number(token)
        except Exception:
            return token  # non-numeric items are kept as raw strings

    return [_coerce(tok.strip()) for tok in text.split(",") if tok.strip() != ""]
def _parse_range_call(s: str) -> tuple:
    """
    Expand ``range(a, b[, step])`` into a tuple of ints, including the upper bound.

    Endpoints are ordered automatically (``range(5, 1)`` == ``range(1, 5)``),
    only the step's magnitude is used, and numeric arguments are truncated
    with ``int()``.

    Raises:
        ValueError: if *s* is not a ``range(...)`` call, has other than 2 or 3
        arguments, has a non-numeric argument, or has a zero step.
    """
    match = re.fullmatch(r"\s*range\s*\(\s*([^)]*)\s*\)\s*", s)
    if match is None:
        raise ValueError
    tokens = [tok.strip() for tok in match.group(1).split(",")]
    args = [tok for tok in tokens if tok != ""]
    if len(args) != 2 and len(args) != 3:
        raise ValueError
    first, second = (int(_to_number(tok)) for tok in args[:2])
    step = int(_to_number(args[2])) if len(args) == 3 else 1
    if step == 0:
        raise ValueError
    lo, hi = min(first, second), max(first, second)
    return tuple(range(lo, hi + 1, abs(step)))
def _parse_colon_or_dots(s: str) -> slice:
    """
    Parse 'a:b', 'a..b', or 'a:b:step' into ``slice(start, stop, step)``.

    - Endpoints may be int or float; they are sorted so start <= stop.
    - Any run of two or more dots acts as a ':' separator.
    - The step is optional (``None`` when absent or empty, as in 'a:b:').
      When present it may be int or float and must be non-zero. It is stored
      as a positive magnitude because the bounds are always ascending.

    Raises:
        ValueError: for anything other than 2 or 3 parts, a non-numeric
        endpoint or step, or a zero step.
    """
    s_norm = re.sub(r"\.\.+", ":", s.strip())
    parts = [p.strip() for p in s_norm.split(":")]
    if len(parts) not in (2, 3):
        raise ValueError(f"Not a range: {s!r}")

    # _to_number returns int/float or raises ValueError for non-numeric text.
    start, stop = _to_number(parts[0]), _to_number(parts[1])
    if start > stop:
        start, stop = stop, start  # normalize to ascending

    step = None
    if len(parts) == 3 and parts[2] != "":
        step_val = _to_number(parts[2])
        if step_val == 0:
            raise ValueError(f"Invalid step in range: {s!r}")
        # Bounds were normalized to ascending above, so a negative step would
        # describe an empty range. Keep only its magnitude, as
        # _parse_range_call does with abs(step).
        step = abs(step_val)

    return slice(start, stop, step)
def _parse_constraint_value(text: str):
    """
    Turn one raw CLI value into a constraint object.

    Interpretations are tried in this order; the first that succeeds wins:
      1. ``range(a, b[, step])``  -> tuple of ints (upper bound inclusive)
      2. a bracketed/parenthesized value, or any value containing a comma
         whose first two characters contain no space
                                  -> tuple when every item is an int,
                                     otherwise a list (items may be ints,
                                     floats or strings)
      3. ``a:b``, ``a:b:step`` or ``a..b``
                                  -> slice with ascending bounds
      4. a plain number           -> int or float
    If none applies, the unquoted string itself is returned (e.g. a
    categorical label).
    """
    raw = _strip_quotes(str(text))

    try:
        return _parse_range_call(raw)
    except Exception:
        pass  # not a range(...) call

    bracketed = raw.startswith(("[", "(")) and raw.endswith(("]", ")"))
    comma_list = "," in raw and " " not in raw[:2]
    if bracketed or comma_list:
        items = _parse_list_like(raw)
        return tuple(items) if all(isinstance(v, int) for v in items) else items

    if ":" in raw or ".." in raw:
        try:
            return _parse_colon_or_dots(raw)
        except Exception:
            pass  # looked like a range but did not parse; keep trying

    try:
        return _to_number(raw)
    except Exception:
        return raw
def _parse_unknown_cli_kv_text(args: list[str]) -> dict[str, str]:
    """
    Extract unknown ``--key value`` pairs as raw strings (no coercion here).

    Supports ``--k=v`` and ``--k v``. A ``--k`` that is the last token, or is
    followed by another ``--`` option, is recorded as the flag value
    ``"true"``, and that following option is still parsed. Tokens that do
    not start with ``--`` and are not consumed as a value are ignored.
    Dashes in keys become underscores; repeated keys keep the last value.
    """
    out: dict[str, str] = {}
    i, n = 0, len(args)
    while i < n:
        tok = args[i]
        i += 1
        if not tok.startswith("--"):
            continue
        key = tok[2:]
        if "=" in key:
            k, v = key.split("=", 1)
        elif i < n and not args[i].startswith("--"):
            # "--k v": consume the next token as the value.
            k, v = key, args[i]
            i += 1
        else:
            # Bare flag: leave the next "--" token in place so it is parsed
            # as its own option on the next iteration (previously it was
            # consumed and silently lost).
            k, v = key, "true"
        out[k.strip().replace("-", "_")] = v
    return out
def _norm(s: str) -> str:
    """Lower-case *s* and drop every character outside ``[a-z0-9]``."""
    lowered = s.lower()
    return re.sub(r"[^a-z0-9]+", "", lowered)
def _canonicalize_feature_keys(
    model: xr.Dataset | Path, raw_map: dict[str, object]
) -> tuple[dict[str, object], dict[str, str]]:
    """
    Map user-supplied keys (any spelling or casing) onto names the model knows.

    Each key in *raw_map* is resolved, in this order of precedence, to:
      1. an exact dataset feature name (including one-hot member names such
         as ``'language=Linear A'``);
      2. a feature name whose ``_norm`` form equals the key's ``_norm`` form;
      3. a categorical base name (e.g. ``'language'``) derived from one-hot
         feature names, matched exactly or after normalization. This is why
         ``--language "Linear A"`` is kept as ``'language': 'Linear A'``.

    Args:
        model: an opened dataset, or a path loaded with ``xr.load_dataset``.
            Feature names are read from its ``"feature"`` variable.
        raw_map: user key -> constraint value; ``None`` is treated as empty.

    Returns:
        ``(mapped, alias)``: *mapped* is canonical name -> value, and *alias*
        is original user key -> canonical name. If several user keys resolve
        to the same canonical name, the last one wins in *mapped*.

    Raises:
        typer.BadParameter: if a key matches no feature and no categorical
        base. Unknown keys are rejected, not dropped.
    """
    ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
    features = [str(x) for x in ds["feature"].values.tolist()]

    # Lookup indexes for feature names and categorical bases.
    feature_norm_index = {_norm(f): f for f in features}
    feature_set = set(features)
    bases, base_norm_index = _categorical_bases_from_features(features)

    mapped: dict[str, object] = {}
    alias: dict[str, str] = {}

    for k, v in (raw_map or {}).items():
        # 1) Exact feature match
        if k in feature_set:
            mapped[k] = v
            alias[k] = k
            continue

        nk = _norm(k)

        # 2) Normalized feature match (full feature/member name)
        if nk in feature_norm_index:
            canonical = feature_norm_index[nk]
            mapped[canonical] = v
            alias[k] = canonical
            continue

        # 3) Categorical base match (exact or normalized)
        if (k in bases) or (nk in base_norm_index):
            base = k if k in bases else base_norm_index[nk]
            mapped[base] = v
            alias[k] = base
            continue

        raise typer.BadParameter(f"Unknown feature key: {k}")

    return mapped, alias
def parse_constraints_from_ctx(ctx: typer.Context, model: xr.Dataset | Path) -> dict[str, object]:
    """
    End-to-end: ``ctx.args`` -> ``{canonical_key: constraint_object}``.

    The extra ``--key value`` options Typer leaves in ``ctx.args`` are handled
    in three steps:
      1. each value is converted with ``_parse_constraint_value``;
      2. keys are mapped to feature / categorical-base names with
         ``_canonicalize_feature_keys`` (which raises ``typer.BadParameter``
         for keys it cannot resolve);
      3. a non-empty result is echoed to the console.

    Values can be:
      - number (int/float)        -> fixed value
      - slice(lo, hi[, step])     -> range between lo and hi
      - list/tuple                -> finite choices
      - tuple from range(...)     -> int choices
      - string                    -> categorical label (e.g. --language "Linear A")
    """
    # Local import: rich is already a dependency of this module (Console).
    from rich.markup import escape

    raw_kv = _parse_unknown_cli_kv_text(ctx.args)
    parsed: dict[str, object] = {k: _parse_constraint_value(v) for k, v in raw_kv.items()}

    # Canonicalize keys to either feature names OR categorical bases.
    constraints, _ = _canonicalize_feature_keys(model, parsed)

    # Defensive: turn any range object into a tuple of int choices.
    for k, v in list(constraints.items()):
        if isinstance(v, range):
            constraints[k] = tuple(v)

    if constraints:
        def _fmt_value(val: object) -> str:
            if isinstance(val, slice):
                # Preview shows only the bounds; any step is omitted.
                return f"[{val.start},{val.stop}]"
            if isinstance(val, (list, tuple, range)):
                return f"{tuple(val)}"
            if isinstance(val, str):
                # Quote strings that contain whitespace or separator characters.
                if re.search(r'\s|[,=:\.]', val):
                    return f'"{val}"'
                return val
            return str(val)

        pretty = ", ".join(f"{k}={_fmt_value(v)}" for k, v in constraints.items())
        # Escape user-provided text so labels or keys such as "a[b]" are shown
        # literally instead of being consumed as Rich markup tags.
        console.print(f"[cyan]Constraints:[/] {escape(pretty)}")

    return constraints
@app.callback()
def main(
    version: bool = typer.Option(False, "--version", "-v", help="Show version and exit.", is_eager=True),
):
    # NOTE: deliberately no docstring. Typer would use it as the app's help text.
    # is_eager=True makes Click process --version ahead of other parameters.
    if not version:
        return
    console.print(f"[bold]psyop[/] {__version__}")
    raise typer.Exit()
@app.command(help="Fit the model on a CSV and save a single model artifact.")
def model(
    input: Path = typer.Argument(..., help="Input CSV file."),
    output: Path = typer.Argument(..., help="Path to save model artifact [.psyop]."),
    target: str = typer.Option("loss", "--target", "-t", help="Target column name."),
    exclude: list[str] = typer.Option([], help="Feature columns to exclude."),
    direction: Direction = typer.Option(
        Direction.MINIMIZE, "--direction", "-d",
        help="Optimization direction for the target."
    ),
    seed: int = typer.Option(0, "--seed", help="Random seed for fitting/sampling."),
    compress: bool = typer.Option(True, help="Apply compression inside the artifact."),
    prior_model: Path | None = typer.Option(None, help="Existing model artifact used to warm-start parameter optimization."),
):
    # Validate the CLI paths, then hand everything to build_model.
    if not input.exists():
        raise typer.BadParameter(f"Input CSV not found: {input.resolve()}")
    if input.suffix.lower() != ".csv":
        # Only a warning: the extension is not required, just unexpected.
        console.print(":warning: [yellow]Input does not end with .csv[/]")
    if prior_model is not None and not prior_model.exists():
        raise typer.BadParameter(f"Prior model artifact not found: {prior_model.resolve()}")

    fit_options = dict(
        input=input,
        target=target,
        output=output,
        exclude=exclude,
        direction=direction.value,  # plain "min" / "max" string
        seed=seed,
        compress=compress,
        prior_model=prior_model,
    )
    build_model(**fit_options)
    console.print(f"[green]Wrote model artifact →[/] {output}")
@app.command(
    help="Suggest BO candidates (constrained EI + exploration).",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def suggest(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Optional[Path] = typer.Option(None, "--output", "-o", help="Where to save candidates CSV (defaults relative to model)."),
    count: int = typer.Option(1, "--count", "-k", help="Number of candidates to propose."),
    success_threshold: float = typer.Option(0.8, help="Feasibility threshold for constrained EI."),
    explore: float = typer.Option(0.34, help="Fraction of suggestions reserved for exploration."),
    seed: int = typer.Option(0, help="Random seed for proposals."),
):
    # Unrecognized --key value options (ctx.args) become per-feature
    # constraints and are forwarded to opt.suggest as keyword arguments.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    model_ds = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model_ds)

    opt.suggest(
        model=model_ds,
        output=output,
        count=count,
        success_threshold=success_threshold,
        explore=explore,
        seed=seed,
        **constraints,
    )
    if output is not None:
        console.print(f"[green]Wrote proposals →[/] {output}")
@app.command(
    help="Rank points by probability of being the best feasible minimum.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def optimal(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path | None = typer.Option(None, help="Where to save top candidates CSV (defaults relative to model)."),
    seed: int = typer.Option(0, help="Random seed for MC."),
):
    # Extra --key value options (ctx.args) are parsed into constraints and
    # passed straight through to opt.optimal.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    model_ds = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model_ds)

    opt.optimal(model=model_ds, output=output, seed=seed, **constraints)
    if output is not None:
        console.print(f"[green]Wrote top probable minima →[/] {output}")
@app.command(
    help="Create a 2D Partial Dependence of Expected Target (Pairwise Features).",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def plot2d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    grid_size: int = typer.Option(70, help="Grid size per axis for 2D panels."),
    use_log_scale_for_target: bool = typer.Option(False, help="Log10 colors for target."),
    log_shift_epsilon: float = typer.Option(1e-9, help="Epsilon shift for log colors."),
    colorscale: str = typer.Option("RdBu", help="Colorscale name."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    n_contours: int = typer.Option(12, help="Number of contour levels."),
    optimal: bool = typer.Option(True, help="Include optimal points."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int = typer.Option(1000, help="Width of each panel in pixels."),
    height: int = typer.Option(1000, help="Height of each panel in pixels."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    # Extra --key value options (ctx.args) become constraints for viz.plot2d.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    model_ds = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model_ds)

    if show is None:
        # Not given explicitly: open a browser only when nothing is written.
        show = output is None

    viz.plot2d(
        model=model_ds,
        output=output,
        grid_size=grid_size,
        use_log_scale_for_target=use_log_scale_for_target,
        log_shift_epsilon=log_shift_epsilon,
        colorscale=colorscale,
        show=show,
        n_contours=n_contours,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )
    if output is not None:
        console.print(f"[green]Wrote pairplot →[/] {output}")
@app.command(
    help="Create 1D Partial Dependence panels.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def plot1d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    csv_out: Path | None = typer.Option(None, help="Optional CSV export of tidy PD data."),
    grid_size: int = typer.Option(300, help="Points along 1D sweep."),
    line_color: str = typer.Option("blue", help="Line/band color (consistent across variables)."),
    band_alpha: float = typer.Option(0.25, help="Fill alpha for ±2σ."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    use_log_scale_for_target_y: bool = typer.Option(True, "--log-y/--no-log-y", help="Log scale for target (Y)."),
    log_y_epsilon: float = typer.Option(1e-9, "--log-y-eps", help="Clamp for log-Y."),
    optimal: bool = typer.Option(True, help="Include optimal points."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int = typer.Option(1000, help="Width of each panel in pixels."),
    height: int = typer.Option(1000, help="Height of each panel in pixels."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    # Extra --key value options (ctx.args) become constraints for viz.plot1d.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    if show is None:
        # Not given explicitly: open a browser only when no HTML is written.
        show = output is None

    model_ds = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model_ds)

    viz.plot1d(
        model=model_ds,
        output=output,
        csv_out=csv_out,
        grid_size=grid_size,
        line_color=line_color,
        band_alpha=band_alpha,
        show=show,
        use_log_scale_for_target_y=use_log_scale_for_target_y,
        log_y_epsilon=log_y_epsilon,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )
    if output is not None:
        console.print(f"[green]Wrote PD HTML →[/] {output}")
    if csv_out is not None:
        console.print(f"[green]Wrote PD CSV →[/] {csv_out}")
@app.command(
    # Previously a verbatim copy of plot1d's help, which made the two commands
    # indistinguishable in --help. Wording now parallels optimum_plot2d.
    help="Create 1D Partial Dependence panels anchored at the optimum.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def optimum_plot1d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    csv_out: Path | None = typer.Option(None, help="Optional CSV export of tidy PD data."),
    grid_size: int = typer.Option(300, help="Points along 1D sweep."),
    line_color: str = typer.Option("blue", help="Line/band color (consistent across variables)."),
    band_alpha: float = typer.Option(0.25, help="Fill alpha for ±2σ."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    use_log_scale_for_target_y: bool = typer.Option(True, "--log-y/--no-log-y", help="Log scale for target (Y)."),
    log_y_epsilon: float = typer.Option(1e-9, "--log-y-eps", help="Clamp for log-Y."),
    optimal: bool = typer.Option(True, help="Include optimal points."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int = typer.Option(1000, help="Width of each panel in pixels."),
    height: int = typer.Option(1000, help="Height of each panel in pixels."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    # Extra --key value options (ctx.args) become constraints that are
    # forwarded to viz.optimum_plot1d.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    show = show if show is not None else output is None  # default to True if no output file

    model_ds = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model_ds)

    viz.optimum_plot1d(
        model=model_ds,
        output=output,
        csv_out=csv_out,
        grid_size=grid_size,
        line_color=line_color,
        band_alpha=band_alpha,
        show=show,
        use_log_scale_for_target_y=use_log_scale_for_target_y,
        log_y_epsilon=log_y_epsilon,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )
    if output:
        console.print(f"[green]Wrote PD HTML →[/] {output}")
    if csv_out:
        console.print(f"[green]Wrote PD CSV →[/] {csv_out}")
@app.command(
    help="Create 2D Partial Dependence panels anchored at the optimum.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def optimum_plot2d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    grid_size: int = typer.Option(70, help="Grid size per axis for 2D panels."),
    use_log_scale_for_target: bool = typer.Option(False, help="Log10 colors for target."),
    log_shift_epsilon: float = typer.Option(1e-9, help="Epsilon shift for log colors."),
    colorscale: str = typer.Option("RdBu", help="Colorscale name."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    n_contours: int = typer.Option(12, help="Number of contour levels."),
    optimal: bool = typer.Option(True, help="Show optimal point overlay."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int = typer.Option(1000, help="Width of each panel in pixels."),
    height: int = typer.Option(1000, help="Height of each panel in pixels."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    # Extra --key value options (ctx.args) become constraints that are
    # forwarded to viz.optimum_plot2d.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    if show is None:
        # Not given explicitly: open a browser only when no HTML is written.
        show = output is None

    model_ds = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model_ds)

    viz.optimum_plot2d(
        model=model_ds,
        output=output,
        grid_size=grid_size,
        use_log_scale_for_target=use_log_scale_for_target,
        log_shift_epsilon=log_shift_epsilon,
        colorscale=colorscale,
        show=show,
        n_contours=n_contours,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )
    if output is not None:
        console.print(f"[green]Wrote PD HTML →[/] {output}")
@app.command(
    help="Export CSV of data used to create the model.",
)
def export(
    model: Path = typer.Argument(..., help="Path to the model artifact [.psyop]."),
    output: Path = typer.Argument(..., help="Output CSV of data used to create model."),
):
    # Writes the artifact's per-row data variables to a CSV file.
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")
    # Separate name instead of rebinding the Path parameter to a Dataset.
    model_ds = xr.load_dataset(model)
    # Keep variables indexed by "row" but not by "feature".
    vars_row = [
        v for v, da in model_ds.data_vars.items()
        if "row" in da.dims and "feature" not in da.dims
    ]
    df = model_ds[vars_row].to_dataframe().reset_index()
    # Drop the row index and the success/prediction columns below.
    # errors="ignore" tolerates artifacts that lack some of them.
    df = df.drop(
        columns=["row", "y_success", "success_mask", "pred_success_mu_train", "pred_success_var_train"],
        errors="ignore",
    )
    df.to_csv(output, index=False)
    # Confirm the write, like every other command that produces a file.
    console.print(f"[green]Wrote model data CSV →[/] {output}")
def _categorical_bases_from_features(features: list[str]) -> tuple[set[str], dict[str, str]]:
    """
    Collect categorical base names from one-hot feature names.

    A feature such as ``'language=Linear A'`` contributes the base
    ``'language'`` (the text before the first ``'='``, whitespace-trimmed).
    Features without ``'='`` and empty bases are ignored.

    Returns:
        ``(bases, base_norm_index)``: the set of base names, and a mapping
        from each base's ``_norm`` form back to the base itself.
    """
    bases: set[str] = {
        head
        for feature in features
        if "=" in feature
        for head in (feature.split("=", 1)[0].strip(),)
        if head
    }
    base_norm_index = {_norm(b): b for b in bases}
    return bases, base_norm_index
# Run the Typer CLI when this module is executed directly as a script.
if __name__ == "__main__":
    app()