Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from pathlib import Path
try:
    import skopt
    from skopt.space.space import Real, Integer, Categorical
    from skopt.callbacks import CheckpointSaver
    from skopt.plots import plot_convergence, plot_evaluations, plot_objective
except ImportError as err:
    # Narrowed from a bare `except:` so that KeyboardInterrupt/SystemExit and
    # unrelated errors are not masked as a missing dependency; chain the cause
    # so the original traceback is preserved.
    raise Exception(
        "No module named 'skopt'. Please install this as an extra dependency or choose a different optimization engine."
    ) from err
13from ..util import call_func
def get_optimizer(method):
    """
    Map a sampling-method name to the corresponding scikit-optimize minimizer.

    An empty string (or a name beginning with "bayes"/"gp") selects Bayesian
    optimization; "random", "forest" and "gbrt"/"gradientboost" prefixes select
    the other skopt strategies. Matching is case-insensitive.

    Raises:
        NotImplementedError: if the name matches no known strategy.
    """
    method = method.lower()
    if not method or method.startswith(("bayes", "gp")):
        return skopt.gp_minimize
    if method.startswith("random"):
        return skopt.dummy_minimize
    if method.startswith("forest"):
        return skopt.forest_minimize
    if method.startswith(("gbrt", "gradientboost")):
        return skopt.gbrt_minimize
    raise NotImplementedError(f"Cannot interpret sampling method '{method}' using scikit-optimize.")
def get_param_search_space(param, name):
    """
    Build a scikit-optimize search dimension for a single tunable parameter.

    Args:
        param: a tunable parameter object exposing ``tune_choices``,
            ``tune_log``, ``tune_min``, ``tune_max`` and ``annotation``.
        name: the dimension name passed through to skopt.

    Returns:
        A ``Categorical``, ``Real`` or ``Integer`` skopt dimension.

    Raises:
        NotImplementedError: if the param is neither categorical nor an
            int/float with tuning bounds.
    """
    # Explicit choices take precedence over any numeric range.
    if param.tune_choices:
        return Categorical(categories=param.tune_choices, name=name)

    prior = "uniform" if not param.tune_log else "log-uniform"
    if param.annotation == float:
        return Real(param.tune_min, param.tune_max, prior=prior, name=name)

    if param.annotation == int:
        return Integer(param.tune_min, param.tune_max, prior=prior, name=name)

    # Bug fix: the message was missing its f-prefix, so '{name}'/'{param}'
    # were emitted literally instead of being interpolated.
    raise NotImplementedError(f"scikit-optimize tuning engine cannot understand param '{name}': {param}")
class SkoptPlot(object):
    """
    skopt callback that saves convergence/evaluation/objective plots after
    each optimization iteration.

    Plots are written into ``path`` (created if necessary) in the image
    ``format`` given (e.g. "svg", "png").
    """
    def __init__(self, path: Path, format):
        self.path = Path(path)
        self.path.mkdir(parents=True, exist_ok=True)
        self.format = format

    def __call__(self, result):
        # Bug fix: the backend must be selected *before* pyplot is imported;
        # calling matplotlib.use('Agg') after `import matplotlib.pyplot`
        # is not guaranteed to switch the backend.
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        plot_convergence(result)
        plt.savefig(str(self.path/f"convergence.{self.format}"), format=self.format)
        plt.close()  # release the figure; this callback runs every iteration

        plot_evaluations(result)
        plt.savefig(str(self.path/f"evaluations.{self.format}"), format=self.format)
        plt.close()

        # plot_objective needs fitted surrogate models; random search has none.
        if result.models:
            plot_objective(result)
            plt.savefig(str(self.path/f"objective.{self.format}"), format=self.format)
            plt.close()
class SkoptObjective():
    """
    Callable objective for scikit-optimize that trains the app once per trial.

    Each call merges the sampled tuning values into the fixed kwargs, picks a
    fresh ``trial-N`` output directory under ``base_output_dir``, trains the
    app, and returns its best metric (negated when the app's goal is to
    maximize, since skopt always minimizes).
    """
    def __init__(self, app, kwargs, used_tuning_params, name, base_output_dir):
        self.app = app
        self.kwargs = kwargs
        self.used_tuning_params = used_tuning_params
        self.name = name
        self.base_output_dir = Path(base_output_dir)

    def __call__(self, *args):
        run_kwargs = dict(self.kwargs)
        # skopt passes the sampled point as a single sequence; pair it with
        # the tuning-parameter names in declaration order.
        run_kwargs.update(zip(self.used_tuning_params.keys(), *args))

        # Find the first unused trial directory.
        run_number = 0
        while (self.base_output_dir / f"trial-{run_number}").exists():
            run_number += 1
        trial_name = f"trial-{run_number}"
        output_dir = self.base_output_dir / trial_name

        run_kwargs.update(
            output_dir=output_dir,
            project_name=self.name,
            run_name=trial_name,
        )

        # Train one trial and extract the metric to optimize.
        learner = call_func(self.app.train, **run_kwargs)
        metric = self.app.get_best_metric(learner)

        # skopt minimizes, so flip the sign when the app wants a maximum.
        if not self.app.goal().startswith("min"):
            metric = -metric
        return metric

    def __deepcopy__(self, memo):
        """ Returns None for deepcopy because this shouldn't be copied into the checkpoint. """
        return None
def skopt_tune(
    app,
    file: str = "",
    name: str = None,
    method: str = "",  # Should be enum
    runs: int = 1,
    seed: int = None,
    plot: bool = False,
    **kwargs,
):
    """
    Tune an app's hyperparameters with scikit-optimize.

    Args:
        app: the project app; must provide ``tuning_params()``,
            ``project_name()``, ``train()``, ``get_best_metric()`` and ``goal()``.
        file: optional path to a skopt checkpoint; if it exists, previous
            results are loaded as a warm start, and new results are saved to it.
        name: project name for the tuning runs; defaults to
            ``"<project_name>-tuning"``.
        method: sampling strategy understood by get_optimizer()
            (bayes/gp, random, forest, gbrt); empty selects Bayesian optimization.
        runs: number of optimization trials (skopt ``n_calls``).
        seed: random state for the optimizer.
        plot: if True, save convergence/evaluation/objective plots after each
            iteration via SkoptPlot. (Replaces an ``if False:`` dead-code
            guard; the default keeps the previous behavior of no plots.)
        **kwargs: fixed values for app parameters; any tuning parameter given
            a non-None value here is excluded from the search space.

    Returns:
        The optimization result object from the chosen skopt minimizer.
    """
    # Only search over parameters the caller has not pinned via kwargs.
    tuning_params = app.tuning_params()
    used_tuning_params = {}
    for key, value in tuning_params.items():
        if key not in kwargs or kwargs[key] is None:
            used_tuning_params[key] = value

    # Build the skopt search space from each param's tuning metadata.
    search_space = [get_param_search_space(param, name=key) for key, param in used_tuning_params.items()]

    optimizer = get_optimizer(method)

    if not name:
        name = f"{app.project_name()}-tuning"
    base_output_dir = Path(kwargs.get("output_dir", ".")) / name

    optimizer_kwargs = dict(n_calls=runs, random_state=seed, callback=[])

    if plot:
        optimizer_kwargs["callback"].append(SkoptPlot(base_output_dir, "svg"))

    if file:
        file = Path(file)
        # if a file is given, first try to read from that file the results and then use it as a checkpoint
        # https://scikit-optimize.github.io/stable/auto_examples/interruptible-optimization.html
        if file.exists():
            try:
                checkpoint = skopt.load(file)
                optimizer_kwargs["x0"] = checkpoint.x_iters
                optimizer_kwargs["y0"] = checkpoint.func_vals
            except Exception as e:
                # Chain the cause so the underlying load failure is visible.
                raise IOError(f"Cannot read scikit-optimize checkpoint file '{file}': {e}") from e

        # store_objective=False: the objective holds the app and is neither
        # picklable nor needed to resume the optimization.
        checkpoint_saver = CheckpointSaver(str(file), compress=9, store_objective=False)
        optimizer_kwargs["callback"].append(checkpoint_saver)

    objective = SkoptObjective(app, kwargs, used_tuning_params, name, base_output_dir)
    results = optimizer(objective, search_space, **optimizer_kwargs)

    return results