Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from pathlib import Path
# Optuna is an optional dependency: fail early with an actionable message,
# but keep the original import error chained so the real cause (e.g. a missing
# integration sub-package rather than optuna itself) is still visible.
try:
    import optuna
    from optuna import samplers
    from optuna.integration import FastAIV2PruningCallback
except ImportError as err:
    raise ImportError(
        "No module named 'optuna'. Please install this as an extra dependency or choose a different optimization engine."
    ) from err
12from ..util import call_func
def get_sampler(method, seed=0):
    """Return the Optuna sampler selected by *method*.

    Matching is case-insensitive and by prefix, so e.g. ``"TPE"`` selects the
    TPE sampler and ``"cmaes"`` selects CMA-ES. An empty or ``None`` method
    falls back to TPE.

    Args:
        method: Sampling strategy name (``"tpe"``, ``"cma..."`` or
            ``"random"``), or ``None``/``""`` for the TPE default.
        seed: Seed forwarded to the sampler constructor.

    Returns:
        An instance from ``optuna.samplers``.

    Raises:
        NotImplementedError: If *method* matches no supported sampler.
    """
    # Normalise before calling .lower() so None takes the TPE fallback
    # instead of raising AttributeError.
    method = (method or "").lower()
    if not method or method.startswith("tpe"):
        return samplers.TPESampler(seed=seed)
    if method.startswith("cma"):
        return samplers.CmaEsSampler(seed=seed)
    if method.startswith("random"):
        return samplers.RandomSampler(seed=seed)
    # Grid search is not offered: Optuna's GridSampler needs an explicit
    # search space, which this helper does not receive.
    raise NotImplementedError(f"Cannot interpret sampling method '{method}' using Optuna.")
def suggest(trial, name, param):
    """Ask *trial* for a value of hyperparameter *name* as described by *param*.

    Args:
        trial: The Optuna trial that records the suggestion.
        name: Name under which the parameter is registered with the trial.
        param: Tuning specification. Reads ``tune_choices`` (categorical
            options), ``annotation`` (``float`` or ``int``), ``tune_min``,
            ``tune_max`` and ``tune_log`` (passed as Optuna's ``log`` flag).

    Returns:
        The value suggested by the trial.

    Raises:
        NotImplementedError: If *param* has no choices and its annotation is
            neither ``float`` nor ``int``.
    """
    if param.tune_choices:
        return trial.suggest_categorical(name, param.tune_choices)
    if param.annotation == float:
        return trial.suggest_float(name, param.tune_min, param.tune_max, log=param.tune_log)
    if param.annotation == int:
        return trial.suggest_int(name, param.tune_min, param.tune_max, log=param.tune_log)
    # f-string so the offending name and spec actually appear in the message.
    raise NotImplementedError(f"Optuna Tuning Engine cannot understand param '{name}': {param}")
def optuna_tune(
    app,
    storage: str = "",
    name: str = None,
    method: str = "tpe",  # Should be enum
    runs: int = 1,
    seed: int = None,
    **kwargs,
):
    """Run an Optuna hyperparameter search over ``app.train``.

    Each trial samples every parameter from ``app.tuning_params()`` that the
    caller left unset in *kwargs* (absent or ``None``). It then trains via
    ``call_func(app.train, ...)`` into
    ``<output_dir>/<study_name>/trial-<number>`` and returns
    ``app.get_best_metric(learner)`` as the objective value. The optimisation
    direction comes from ``app.goal()``.

    Args:
        app: Object providing ``tuning_params``, ``train``,
            ``get_best_metric`` and ``goal``.
        storage: Optuna storage. Empty means no storage (``None`` is passed
            to Optuna). A value containing ``"://"`` is used verbatim as a
            database URL. Anything else is treated as the stem of an SQLite
            file ``<output_dir>/<storage>.sqlite3``.
        name: Study name. Because ``load_if_exists=True``, reusing a name with
            persistent storage continues that study.
        method: Sampler name understood by ``get_sampler``. ``None`` or empty
            falls back to ``"tpe"``.
        runs: Number of trials (``n_trials``).
        seed: Seed for the sampler.
        **kwargs: Forwarded to ``app.train`` on every trial. ``output_dir``
            (default ``"."``) is also the base directory for the SQLite file
            and per-trial outputs.

    Returns:
        The ``optuna.Study`` after ``study.optimize`` completes.
    """
    # ``or "."`` also covers output_dir=None. Unset options are passed as None
    # (the tuning loop below already treats None as "not provided"), and
    # Path(None) would raise TypeError.
    output_dir = Path(kwargs.get("output_dir") or ".")
    output_dir.mkdir(parents=True, exist_ok=True)

    def objective(trial: optuna.Trial):
        run_kwargs = dict(kwargs)

        # Only sample parameters the caller did not fix explicitly.
        for key, value in app.tuning_params().items():
            if kwargs.get(key) is None:
                run_kwargs[key] = suggest(trial, key, value)

        study_name = trial.study.study_name
        trial_name = f"trial-{trial.number}"

        # Give every trial its own output directory under the study's name.
        trial_base = Path(run_kwargs.get("output_dir") or ".")
        run_kwargs["output_dir"] = trial_base / study_name / trial_name
        run_kwargs["project_name"] = study_name
        run_kwargs["run_name"] = trial_name

        learner = call_func(app.train, **run_kwargs)
        return app.get_best_metric(learner)

    if not storage:
        storage = None
    elif "://" not in storage:
        # Bare name: SQLite file under output_dir. The path is resolved to an
        # absolute path so the URL does not depend on the working directory.
        storage_path = output_dir / f"{storage}.sqlite3"
        storage = f"sqlite:///{storage_path.resolve()}"

    study = optuna.create_study(
        study_name=name,
        storage=storage,
        sampler=get_sampler(method or "tpe", seed=seed),
        load_if_exists=True,
        direction=app.goal(),
    )
    study.optimize(objective, n_trials=runs)
    return study