Coverage for psyop/main.py: 67.65%
238 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-10 06:02 +0000
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Limit the BLAS / OpenMP / numexpr thread pools to a single thread to avoid
# CPU oversubscription and crashes seen on macOS. Native libraries usually
# read these variables when they are first loaded, which is why this runs
# before xarray (and therefore numpy) is imported further down. setdefault()
# keeps any value the user has already exported.
import os

_SINGLE_THREAD_ENV_VARS = (
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "OMP_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
)
for _env_var in _SINGLE_THREAD_ENV_VARS:
    os.environ.setdefault(_env_var, "1")
15import re
16from enum import Enum
17from pathlib import Path
18from typing import Optional
19import xarray as xr
20import typer
21from rich.console import Console
23from .model import build_model
24from . import viz, opt
__version__ = "0.1.0"

# Single Rich console shared by every command for user-facing output.
console = Console()

# Root Typer application; the commands below register themselves on it.
# Invoking the CLI with no arguments prints help instead of erroring.
app = typer.Typer(
    rich_markup_mode="rich",
    add_completion=False,
    no_args_is_help=True,
)
class Direction(str, Enum):
    """Optimization direction accepted by ``psyop model --direction``.

    Members subclass ``str`` so Typer can present the values as CLI choices;
    the command forwards ``direction.value`` (``"min"``, ``"max"`` or
    ``"auto"``) to ``build_model``.
    """

    MINIMIZE = "min"
    MAXIMIZE = "max"
    AUTO = "auto"  # presumably lets build_model infer the direction — confirm in model.py
38def _strip_quotes(s: str) -> str:
39 s = s.strip()
40 if (s.startswith("'") and s.endswith("'")) or (s.startswith('"') and s.endswith('"')):
41 return s[1:-1]
42 return s
# Strict decimal-number syntax used by _to_number: an optional sign, then
# digits with an optional fractional part (or a bare fraction such as ".5"),
# then an optional exponent; surrounding whitespace is allowed. Unlike
# float(), this rejects "inf", "nan" and underscore separators like "1_000".
_num_re = re.compile(
    r"""
    ^\s*
    ([+-]?                            # sign
      (?:
        (?:\d+(?:\.\d*)?|\.\d+)       # 123, 123., .123, 123.456
        (?:[eE][+-]?\d+)?             # optional exponent
      )
    )
    \s*$
    """,
    re.VERBOSE,
)
58def _is_intlike_str(s: str) -> bool:
59 try:
60 f = float(s)
61 return float(int(round(f))) == f
62 except Exception:
63 return False
def _to_number(s: str):
    """Return int if int-like, else float, else raise.

    *s* must match ``_num_re`` (optional sign, decimal digits, optional
    fraction and exponent, surrounding whitespace allowed).

    - Plain integer literals (``"42"``, ``"-7"``) are parsed exactly with
      ``int()``, so values above 2**53 keep full precision.
    - Other int-like values (``"3.0"``, ``"1e3"``) are returned as ``int``.
    - Everything else is returned as ``float``.

    Raises:
        ValueError: if *s* is not a plain decimal number.
    """
    if not _num_re.match(s):
        raise ValueError(f"Not a number: {s!r}")
    try:
        # Exact for plain integer literals; going through float() would round
        # large values (e.g. "12345678901234567891").
        return int(s)
    except ValueError:
        pass  # has a fraction or exponent: fall back to float parsing
    value = float(s)
    return int(value) if value.is_integer() else value
def _parse_list_like(s: str) -> list:
    """
    Split a comma-separated list into items, coercing numeric items.

    One outer ``[...]``/``(...)`` wrapper is removed first (the opening and
    closing characters are not required to match). Empty items are skipped.
    Each item has its whitespace and one pair of outer quotes removed.
    Numeric items become int/float via ``_to_number``; all others stay strings.

    Example: ``"[1, 2.5, 'abc']"`` -> ``[1, 2.5, 'abc']``
    """
    body = s.strip()
    if body.startswith(("(", "[")) and body.endswith((")", "]")):
        body = body[1:-1]

    def coerce(token: str):
        token = _strip_quotes(token)
        try:
            return _to_number(token)
        except Exception:
            return token  # not numeric: keep the raw string

    return [coerce(token.strip()) for token in body.split(",") if token.strip() != ""]
def _parse_range_call(s: str) -> tuple:
    """
    Parse 'range(a,b[,step])' inclusive on the upper bound for ints.
    Returns a tuple of ints, sorted ascending.

    The sequence starts at ``a`` and walks towards ``b`` in increments of
    ``abs(step)``. It includes ``b`` only when a step lands on it exactly:

    - ``range(1,5)``     -> (1, 2, 3, 4, 5)
    - ``range(0,10,3)``  -> (0, 3, 6, 9)
    - ``range(10,1,-2)`` -> (2, 4, 6, 8, 10)   (anchored at 10, not at 1)

    The sign of ``step`` is ignored because the direction comes from ``a``
    and ``b``. Non-integer arguments are truncated with ``int()``.

    Raises:
        ValueError: if *s* is not a ``range(...)`` call with 2 or 3 numeric
            arguments, or if ``step`` is 0.
    """
    m = re.fullmatch(r"\s*range\s*\(\s*([^)]*)\s*\)\s*", s)
    if not m:
        raise ValueError(f"Not a range(...) call: {s!r}")
    args = [t.strip() for t in m.group(1).split(",") if t.strip() != ""]
    if len(args) not in (2, 3):
        raise ValueError(f"range() takes 2 or 3 arguments: {s!r}")
    a = int(_to_number(args[0]))
    b = int(_to_number(args[1]))
    step = int(_to_number(args[2])) if len(args) == 3 else 1
    if step == 0:
        raise ValueError(f"range() step must not be zero: {s!r}")

    stride = abs(step)
    if a <= b:
        seq = range(a, b + 1, stride)
    else:
        # Count down from `a`. Re-anchoring at `b` (the old behaviour) turned
        # range(10,1,-2) into 1,3,5,7,9 instead of 10,8,6,4,2.
        seq = range(a, b - 1, -stride)
    return tuple(sorted(seq))
def _parse_colon_or_dots(s: str) -> slice:
    """
    Parse a numeric interval written as 'a:b', 'a..b' or 'a:b:step'.

    Returns ``slice(start, stop, step)`` with ``start <= stop``; endpoints
    given in descending order are swapped. Endpoints and the optional step
    may be int or float. ``step`` is ``None`` when omitted or empty.

    Raises:
        ValueError: if there are not 2 or 3 ``:``-separated parts, if an
            endpoint or the step is not numeric, or if the step is 0.
    """
    # Any run of two or more dots acts as a ':' separator ("1..5" -> "1:5").
    normalized = re.sub(r"\.\.+", ":", s.strip())
    pieces = [piece.strip() for piece in normalized.split(":")]
    if len(pieces) not in (2, 3):
        raise ValueError(f"Not a range: {s!r}")

    lo = _to_number(pieces[0])
    hi = _to_number(pieces[1])
    if not isinstance(lo, (int, float)) or not isinstance(hi, (int, float)):
        raise ValueError(f"Non-numeric range endpoints: {s!r}")
    if lo > hi:
        lo, hi = hi, lo  # always ascending

    step = None
    if len(pieces) == 3 and pieces[2]:
        step = _to_number(pieces[2])
        if not isinstance(step, (int, float)) or step == 0:
            raise ValueError(f"Invalid step in range: {s!r}")

    return slice(lo, hi, step)
def _parse_constraint_value(text: str):
    """
    Convert one raw CLI value into a constraint object.

    Outer quotes are stripped first, then these forms are tried in order:

      1. ``range(a,b[,s])``             -> tuple of ints (inclusive upper bound)
      2. ``[...]``/``(...)`` or ``a,b,c`` -> choices: a ``tuple`` when every item
                                           is an int, otherwise a ``list`` of the
                                           parsed items (ints, floats and/or
                                           strings)
      3. ``a:b[:step]`` or ``a..b``     -> ``slice(lo, hi, step)`` (float range)
      4. a plain number                 -> int/float (fixed value)

    Anything else comes back unchanged as a string, e.g. a categorical label.
    A malformed form 1 or 3 falls through to the later forms instead of raising.
    """
    raw = _strip_quotes(str(text))

    try:
        return _parse_range_call(raw)
    except Exception:
        pass  # not a range(...) call

    is_bracketed = raw.startswith(("[", "(")) and raw.endswith(("]", ")"))
    # NOTE(review): a bare comma-separated value only counts as a list when its
    # first two characters contain no space; the intent of this heuristic is
    # not evident from this file, so it is preserved exactly.
    is_comma_list = "," in raw and " " not in raw[:2]
    if is_bracketed or is_comma_list:
        choices = _parse_list_like(raw)
        every_int = all(isinstance(item, int) for item in choices)
        return tuple(choices) if every_int else choices

    if ":" in raw or ".." in raw:
        try:
            return _parse_colon_or_dots(raw)
        except Exception:
            pass  # e.g. non-numeric endpoints: fall through

    try:
        return _to_number(raw)
    except Exception:
        return raw  # fallback to string (rare; typically ignored downstream)
178def _parse_unknown_cli_kv_text(args: list[str]) -> dict[str, str]:
179 """
180 Extract unknown --key value pairs as raw strings (no coercion here).
181 Supports: --k=v and --k v. Repeated keys -> last wins.
182 """
183 out: dict[str, str] = {}
184 it = iter(args)
185 for tok in it:
186 if not tok.startswith("--"):
187 continue
188 key = tok[2:]
189 if "=" in key:
190 k, v = key.split("=", 1)
191 else:
192 k = key
193 try:
194 nxt = next(it)
195 except StopIteration:
196 nxt = "true"
197 if nxt.startswith("--"):
198 # treat as flag without value; put it back by ignoring and storing "true"
199 v = "true"
200 else:
201 v = nxt
202 out[k.strip().replace("-", "_")] = v
203 return out
205def _norm(s: str) -> str:
206 return re.sub(r"[^a-z0-9]+", "", s.lower())
def _canonicalize_feature_keys(
    model: xr.Dataset | Path, raw_map: dict[str, object]
) -> tuple[dict[str, object], dict[str, str]]:
    """
    Map user-supplied constraint keys (any style) to canonical names.

    Each key is resolved, first match wins, to:

    1. a dataset feature name, by exact match;
    2. a dataset feature name (numeric feature or one-hot *member* name), by
       normalized match (``_norm``: case-insensitive, non-alphanumerics
       ignored, so ``LearningRate`` matches ``learning_rate``);
    3. a *categorical base* name taken from one-hot members (``'language'``
       from ``'language=...'``), by exact or normalized match. For example,
       ``--language "Linear A"`` is kept as ``{'language': 'Linear A'}``.

    Args:
        model: Loaded model dataset, or a path that is loaded with
            ``xr.load_dataset``. Its ``"feature"`` entry lists the feature names.
        raw_map: User key -> already-parsed constraint value. ``None`` is
            treated as empty.

    Returns:
        ``(mapped, alias)``. ``mapped`` is canonical name -> value and
        ``alias`` is user key -> canonical name. If several user keys resolve
        to the same canonical name, the later one wins in ``mapped``.

    Raises:
        typer.BadParameter: if a key matches neither a feature nor a
            categorical base. Unknown keys are rejected, not silently dropped.
    """
    ds = model if isinstance(model, xr.Dataset) else xr.load_dataset(model)
    features = [str(x) for x in ds["feature"].values.tolist()]

    # Lookup indexes for feature names and categorical bases.
    feature_set = set(features)
    feature_norm_index = {_norm(f): f for f in features}
    bases, base_norm_index = _categorical_bases_from_features(features)

    mapped: dict[str, object] = {}
    alias: dict[str, str] = {}

    for k, v in (raw_map or {}).items():
        nk = _norm(k)
        if k in feature_set:
            canonical = k  # 1) exact feature match
        elif nk in feature_norm_index:
            canonical = feature_norm_index[nk]  # 2) normalized feature match
        elif k in bases:
            canonical = k  # 3a) exact categorical base
        elif nk in base_norm_index:
            canonical = base_norm_index[nk]  # 3b) normalized categorical base
        else:
            raise typer.BadParameter(f"Unknown feature key: {k}")
        mapped[canonical] = v
        alias[k] = canonical

    return mapped, alias
def parse_constraints_from_ctx(ctx: typer.Context, model: xr.Dataset | Path) -> dict[str, object]:
    """
    End-to-end: ctx.args → {key: constraint_object}.

    The extra ``--key value`` tokens that Typer leaves in ``ctx.args`` go through
    three steps:

    1. their values are converted with ``_parse_constraint_value``;
    2. their keys are canonicalized against the model's features and
       categorical bases;
    3. the result is echoed to the console.

    Values can be:
      - number (int/float)      -> fixed
      - slice(lo, hi)           -> float range (inclusive ends)
      - list/tuple              -> finite choices
      - tuple from range(...)   -> int choices
      - string                  -> categorical label (e.g., --language "Linear A")

    Raises:
        typer.BadParameter: if a key matches no model feature or categorical
            base (raised during key canonicalization).
    """
    # Keys and values are user input; escape them so Rich does not interpret
    # e.g. "[/x]" as markup (which raises MarkupError) or "[b]" as bold.
    from rich.markup import escape

    raw_kv = _parse_unknown_cli_kv_text(ctx.args)
    parsed: dict[str, object] = {k: _parse_constraint_value(v) for k, v in raw_kv.items()}

    # Canonicalize keys to either feature names OR categorical bases
    constraints, _ = _canonicalize_feature_keys(model, parsed)

    # Defensive: pass tuples of ints (choices) downstream rather than range objects.
    for k, v in list(constraints.items()):
        if isinstance(v, range):
            constraints[k] = tuple(v)

    if constraints:
        def _fmt_value(val: object) -> str:
            if isinstance(val, slice):
                # Preview shows start:stop only; a step, if any, is omitted.
                return f"[{val.start},{val.stop}]"
            if isinstance(val, (list, tuple, range)):
                return f"{tuple(val)}"
            if isinstance(val, str):
                # Quote labels that contain spaces or separator characters.
                if re.search(r'\s|[,=:\.]', val):
                    return f'"{val}"'
                return val
            return str(val)

        pretty = ", ".join(f"{k}={_fmt_value(v)}" for k, v in constraints.items())
        console.print(f"[cyan]Constraints:[/] {escape(pretty)}")

    return constraints
def _version_callback(value: bool) -> None:
    # Eager option callback. It runs while arguments are being parsed, so
    # `psyop --version` prints and exits before Click reports "Missing command."
    # (a group callback body only runs once a subcommand is being invoked).
    if value:
        console.print(f"[bold]psyop[/] {__version__}")
        raise typer.Exit()


@app.callback()
def main(
    version: bool = typer.Option(
        False,
        "--version",
        "-v",
        help="Show version and exit.",
        is_eager=True,
        callback=_version_callback,
    ),
):
    # Intentionally no docstring: Typer would show it as the app's --help text.
    # --version is fully handled by _version_callback; nothing else to do here.
    pass
@app.command(help="Fit the model on a CSV and save a single model artifact.")
def model(
    input: Path = typer.Argument(..., help="Input CSV file."),
    output: Path = typer.Argument(..., help="Path to save model artifact (.nc)."),
    target: str = typer.Option("loss", "--target", "-t", help="Target column name."),
    exclude: list[str] = typer.Option([], "--exclude", help="Feature columns to exclude."),
    direction: Direction = typer.Option(
        Direction.AUTO, "--direction", "-d",
        help="Optimization direction for the target."
    ),
    seed: int = typer.Option(0, "--seed", help="Random seed for fitting/sampling."),
    compress: bool = typer.Option(True, help="Apply compression inside the artifact."),
):
    """Validate the input CSV, fit it with ``build_model`` and report the artifact path."""
    csv_path = input
    if not csv_path.exists():
        raise typer.BadParameter(f"Input CSV not found: {csv_path.resolve()}")

    # A non-.csv extension only triggers a warning; the file is still used.
    looks_like_csv = csv_path.suffix.lower() == ".csv"
    if not looks_like_csv:
        console.print(":warning: [yellow]Input does not end with .csv[/]")

    fit_kwargs = dict(
        input=csv_path,
        target=target,
        output=output,
        exclude=exclude,
        direction=direction.value,
        seed=seed,
        compress=compress,
    )
    build_model(**fit_kwargs)
    console.print(f"[green]Wrote model artifact →[/] {output}")
@app.command(
    help="Suggest BO candidates (constrained EI + exploration).",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def suggest(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact (.nc)."),
    output: Optional[Path] = typer.Option(None, "--output", "-o", help="Where to save candidates CSV (defaults relative to model)."),
    count: int = typer.Option(1, "--count", "-k", help="Number of candidates to propose."),
    success_threshold: float = typer.Option(0.8, help="Feasibility threshold for constrained EI."),
    explore: float = typer.Option(0.34, help="Fraction of suggestions reserved for exploration."),
    # NOTE: a --candidates (random pool size) option used to live here and is
    # currently disabled; opt.suggest is called without it.
    seed: int = typer.Option(0, "--seed", help="Random seed for proposals."),
):
    """Load the model artifact and delegate candidate proposal to ``opt.suggest``.

    Unknown ``--key value`` options are treated as feature constraints.
    """
    artifact_path = model
    if not artifact_path.exists():
        raise typer.BadParameter(f"Model artifact not found: {artifact_path.resolve()}")

    dataset = xr.load_dataset(artifact_path)
    constraints = parse_constraints_from_ctx(ctx, dataset)

    opt.suggest(
        model=dataset,
        output=output,
        count=count,
        success_threshold=success_threshold,
        explore=explore,
        seed=seed,
        **constraints,
    )
    if output:
        console.print(f"[green]Wrote proposals →[/] {output}")
@app.command(
    help="Rank points by probability of being the best feasible minimum.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def optimal(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact (.nc)."),
    output: Path|None = typer.Option(None, help="Where to save top candidates CSV (defaults relative to model)."),
    count: int = typer.Option(10, "--count", "-k", help="How many top rows to keep."),
    draws: int = typer.Option(2000, "--draws", help="Monte Carlo draws."),
    min_success_probability: float = typer.Option(0.0, "--min-p-success", help="Hard feasibility cutoff (0 disables)."),
    seed: int = typer.Option(0, "--seed", help="Random seed for MC."),
):
    """Load the model artifact and delegate ranking to ``opt.optimal``.

    Unknown ``--key value`` options are treated as feature constraints.
    ``--count``, ``--draws`` and ``--min-p-success`` are accepted but not
    forwarded to ``opt.optimal`` at present. A warning is printed when the
    user changes them from their defaults.
    """
    if not model.exists():
        raise typer.BadParameter(f"Model artifact not found: {model.resolve()}")

    # Report instead of silently ignoring explicitly changed options.
    # Keep these defaults in sync with the typer.Option defaults above.
    ignored = [
        flag
        for flag, value, default in (
            ("--count", count, 10),
            ("--draws", draws, 2000),
            ("--min-p-success", min_success_probability, 0.0),
        )
        if value != default
    ]
    if ignored:
        console.print(
            f":warning: [yellow]Ignoring {', '.join(ignored)}: not currently passed to the optimizer.[/]"
        )

    model = xr.load_dataset(model)
    constraints = parse_constraints_from_ctx(ctx, model)

    opt.optimal(
        model=model,
        output=output,
        # NOTE: disabled; re-enable once opt.optimal accepts these (and drop the warning above).
        # count=count,
        # n_draws=draws,
        # min_success_probability=min_success_probability,
        seed=seed,
        **constraints,
    )
    if output:
        console.print(f"[green]Wrote top probable minima →[/] {output}")
@app.command(
    help="Create a 2D Partial Dependence of Expected Target (Pairwise Features).",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def plot2d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact (.nc)."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    grid_size: int = typer.Option(70, help="Grid size per axis for 2D panels."),
    use_log_scale_for_target: bool = typer.Option(False, help="Log10 colors for target."),
    log_shift_epsilon: float = typer.Option(1e-9, help="Epsilon shift for log colors."),
    colorscale: str = typer.Option("RdBu", help="Colorscale name."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    n_contours: int = typer.Option(12, help="Number of contour levels."),
    optimal: bool = typer.Option(True, help="Include optimal points."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int | None = typer.Option(None, help="Width of each panel in pixels (default auto)."),
    height: int | None = typer.Option(None, help="Height of each panel in pixels (default auto)."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    """Load the model artifact and render pairwise 2D partial-dependence panels via ``viz.plot2d``.

    Unknown ``--key value`` options are treated as feature constraints.
    """
    artifact_path = model
    if not artifact_path.exists():
        raise typer.BadParameter(f"Model artifact not found: {artifact_path.resolve()}")

    dataset = xr.load_dataset(artifact_path)
    constraints = parse_constraints_from_ctx(ctx, dataset)

    if show is None:
        # Default: open a browser only when no output file was requested.
        show = output is None

    viz.plot2d(
        model=dataset,
        output=output,
        grid_size=grid_size,
        use_log_scale_for_target=use_log_scale_for_target,
        log_shift_epsilon=log_shift_epsilon,
        colorscale=colorscale,
        show=show,
        n_contours=n_contours,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )
    if output:
        console.print(f"[green]Wrote pairplot →[/] {output}")
@app.command(
    help="Create 1D Partial Dependence panels.",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def plot1d(
    ctx: typer.Context,
    model: Path = typer.Argument(..., help="Path to the model artifact (.nc)."),
    output: Path | None = typer.Option(None, help="Output HTML (defaults relative to model)."),
    csv_out: Path | None = typer.Option(None, help="Optional CSV export of tidy PD data."),
    grid_size: int = typer.Option(300, help="Points along 1D sweep."),
    line_color: str = typer.Option("rgb(31,119,180)", help="Line/band color (consistent across variables)."),
    band_alpha: float = typer.Option(0.25, help="Fill alpha for ±2σ."),
    figure_height_per_row_px: int = typer.Option(320, help="Pixels per PD row."),
    show: bool | None = typer.Option(None, help="Open the figure in a browser."),
    use_log_scale_for_target_y: bool = typer.Option(True, "--log-y/--no-log-y", help="Log scale for target (Y)."),
    log_y_epsilon: float = typer.Option(1e-9, "--log-y-eps", help="Clamp for log-Y."),
    optimal: bool = typer.Option(True, help="Include optimal points."),
    suggest: int = typer.Option(0, help="Number of suggested points."),
    width: int | None = typer.Option(None, help="Width of each panel in pixels (default auto)."),
    height: int | None = typer.Option(None, help="Height of each panel in pixels (default auto)."),
    seed: int = typer.Option(42, help="Random seed for suggested points."),
):
    """Load the model artifact and render 1D partial-dependence panels via ``viz.plot1d``.

    Unknown ``--key value`` options are treated as feature constraints.
    """
    artifact_path = model
    if not artifact_path.exists():
        raise typer.BadParameter(f"Model artifact not found: {artifact_path.resolve()}")

    if show is None:
        # Default: open a browser only when no output file was requested.
        show = output is None

    dataset = xr.load_dataset(artifact_path)
    constraints = parse_constraints_from_ctx(ctx, dataset)

    viz.plot1d(
        model=dataset,
        output=output,
        csv_out=csv_out,
        grid_size=grid_size,
        line_color=line_color,
        band_alpha=band_alpha,
        figure_height_per_row_px=figure_height_per_row_px,
        show=show,
        use_log_scale_for_target_y=use_log_scale_for_target_y,
        log_y_epsilon=log_y_epsilon,
        optimal=optimal,
        suggest=suggest,
        width=width,
        height=height,
        seed=seed,
        **constraints,
    )
    if output:
        console.print(f"[green]Wrote PD HTML →[/] {output}")
    if csv_out:
        console.print(f"[green]Wrote PD CSV →[/] {csv_out}")
def _categorical_bases_from_features(features: list[str]) -> tuple[set[str], dict[str, str]]:
    """
    Collect categorical base names from one-hot member feature names.

    A feature such as ``'language=Linear A'`` contributes the base
    ``'language'``: the text before the first ``=``, whitespace-stripped.
    Features without ``=``, or with an empty base, contribute nothing.

    Returns:
        ``(bases, base_norm_index)``: the set of base names, and a mapping
        from ``_norm(base)`` to the base name. If two bases normalize to the
        same key, which one is kept depends on set iteration order.
    """
    split_names = (name.partition("=") for name in features)
    bases: set[str] = {head.strip() for head, sep, _ in split_names if sep and head.strip()}
    base_norm_index = {_norm(base): base for base in bases}
    return bases, base_norm_index
if __name__ == "__main__":
    # Run the Typer CLI when this module is executed directly as a script.
    app()