Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

from pathlib import Path

# scikit-optimize is an optional extra dependency; fail fast with a clear
# message if it is not installed.
try:
    import skopt
    from skopt.space.space import Real, Integer, Categorical
    from skopt.callbacks import CheckpointSaver
    from skopt.plots import plot_convergence, plot_evaluations, plot_objective
except ImportError as err:
    # Catch only ImportError (a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit) and chain the original cause.
    raise Exception(
        "No module named 'skopt'. Please install this as an extra dependency or choose a different optimization engine."
    ) from err

from ..util import call_func

def get_optimizer(method):
    """
    Map a sampling-method name to the corresponding scikit-optimize minimizer.

    The name is matched case-insensitively by prefix; an empty string defaults
    to Bayesian optimization (``gp_minimize``).

    Raises:
        NotImplementedError: if the name matches no known sampling method.
    """
    method = method.lower()
    if not method or method.startswith(("bayes", "gp")):
        return skopt.gp_minimize
    if method.startswith("random"):
        return skopt.dummy_minimize
    if method.startswith("forest"):
        return skopt.forest_minimize
    if method.startswith(("gbrt", "gradientboost")):
        return skopt.gbrt_minimize
    raise NotImplementedError(f"Cannot interpret sampling method '{method}' using scikit-optimize.")

27 

28 

def get_param_search_space(param, name):
    """
    Build a scikit-optimize search-space dimension for one tunable parameter.

    Args:
        param: a tunable-parameter descriptor exposing ``tune_choices``,
            ``tune_log``, ``tune_min``, ``tune_max`` and ``annotation``.
        name: the dimension name reported to scikit-optimize.

    Returns:
        A ``Categorical``, ``Real`` or ``Integer`` dimension.

    Raises:
        NotImplementedError: if the parameter cannot be mapped to a dimension.
    """
    # Explicit choices always take precedence over a numeric range.
    if param.tune_choices:
        return Categorical(categories=param.tune_choices, name=name)

    prior = "uniform" if not param.tune_log else "log-uniform"
    if param.annotation == float:
        return Real(param.tune_min, param.tune_max, prior=prior, name=name)

    if param.annotation == int:
        return Integer(param.tune_min, param.tune_max, prior=prior, name=name)

    # BUG FIX: the original string lacked the `f` prefix, so '{name}' and
    # '{param}' appeared literally in the error message.
    raise NotImplementedError(f"scikit-optimize tuning engine cannot understand param '{name}': {param}")

41 

42 

class SkoptPlot(object):
    """
    scikit-optimize callback that renders diagnostic plots after each iteration.

    Writes convergence and evaluations plots — and, when the result carries
    surrogate models, an objective plot — into ``path``, one image file per
    plot, in the given ``format``.

    (The original docstring, copied from CheckpointSaver, incorrectly claimed
    this saved state with ``skopt.dump``.)
    """
    def __init__(self, path: Path, format):
        # Output directory for the plots; created eagerly so later saves
        # cannot fail on a missing directory.
        self.path = Path(path)
        self.path.mkdir(parents=True, exist_ok=True)
        # Image format / file extension (e.g. "svg", "png"), passed to matplotlib.
        self.format = format

    def __call__(self, result):
        # BUG FIX: select the non-interactive backend BEFORE importing pyplot;
        # calling matplotlib.use() after `import matplotlib.pyplot` may have no effect.
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        plot_convergence(result)
        plt.savefig(str(self.path/f"convergence.{self.format}"), format=self.format)

        plot_evaluations(result)
        plt.savefig(str(self.path/f"evaluations.{self.format}"), format=self.format)

        # The objective plot needs fitted surrogate models; skip when absent.
        if result.models:
            plot_objective(result)
            plt.savefig(str(self.path/f"objective.{self.format}"), format=self.format)

66 

67 

class SkoptObjective():
    """
    Callable objective wrapping an app's training run for scikit-optimize.

    Each invocation trains once in a fresh ``trial-N`` output directory and
    returns the app's best metric, negated when the goal is to maximize
    (scikit-optimize always minimizes).
    """
    def __init__(self, app, kwargs, used_tuning_params, name, base_output_dir):
        self.app = app
        self.kwargs = kwargs
        self.used_tuning_params = used_tuning_params
        self.name = name
        self.base_output_dir = Path(base_output_dir)

    def __call__(self, *args):
        # Merge the sampled values for the tuned parameters over the fixed kwargs.
        run_kwargs = dict(self.kwargs)
        run_kwargs.update(zip(self.used_tuning_params.keys(), *args))

        # Find the first unused trial directory: trial-0, trial-1, ...
        run_number = 0
        while (self.base_output_dir / f"trial-{run_number}").exists():
            run_number += 1
        trial_name = f"trial-{run_number}"
        output_dir = self.base_output_dir / trial_name

        run_kwargs["output_dir"] = output_dir
        run_kwargs["project_name"] = self.name
        run_kwargs["run_name"] = trial_name

        # Train and read back the metric being optimized.
        learner = call_func(self.app.train, **run_kwargs)
        metric = self.app.get_best_metric(learner)

        # scikit-optimize minimizes, so flip the sign when the goal is to maximize.
        if not self.app.goal().startswith("min"):
            metric = -metric

        return metric

    def __deepcopy__(self, memo):
        """Opt out of deep copies so checkpoints never serialize this objective."""
        return None

107 

108 

def skopt_tune(
    app,
    file: str = "",
    name: str = None,
    method: str = "",  # Should be enum
    runs: int = 1,
    seed: int = None,
    **kwargs,
):
    """
    Tune an app's hyperparameters with scikit-optimize.

    Args:
        app: application object providing ``tuning_params()``, ``train``,
            ``get_best_metric``, ``goal`` and ``project_name``.
        file: optional checkpoint path; existing results are resumed from it
            and new results are checkpointed to it after each iteration.
        name: project name; defaults to ``"<app project>-tuning"``.
        method: sampling-method name understood by ``get_optimizer``
            (empty string selects Bayesian optimization).
        runs: number of optimizer evaluations (``n_calls``).
        seed: random state for reproducibility.
        **kwargs: fixed run arguments; any tuning parameter pinned here (to a
            non-None value) is excluded from the search space.

    Returns:
        The scikit-optimize ``OptimizeResult``.

    Raises:
        IOError: if an existing checkpoint file cannot be read.
    """
    # Only tune parameters the caller did not pin via kwargs.
    tuning_params = app.tuning_params()
    used_tuning_params = {
        key: value
        for key, value in tuning_params.items()
        if key not in kwargs or kwargs[key] is None
    }

    # Build the search space, one dimension per tuned parameter.
    search_space = [get_param_search_space(param, name=key) for key, param in used_tuning_params.items()]

    optimizer = get_optimizer(method)

    if not name:
        name = f"{app.project_name()}-tuning"
    base_output_dir = Path(kwargs.get("output_dir", ".")) / name

    optimizer_kwargs = dict(n_calls=runs, random_state=seed, callback=[])

    # NOTE(review): a dead `if False:` branch that would have appended
    # SkoptPlot(base_output_dir, "svg") was removed; it never executed.
    # Re-enable plotting by appending that callback here.

    if file:
        file = Path(file)
        # If a checkpoint file is given, first resume any stored results from
        # it, then keep using it as a checkpoint after each iteration:
        # https://scikit-optimize.github.io/stable/auto_examples/interruptible-optimization.html
        if file.exists():
            try:
                checkpoint = skopt.load(file)
                optimizer_kwargs['x0'] = checkpoint.x_iters
                optimizer_kwargs['y0'] = checkpoint.func_vals
            except Exception as e:
                # Chain the cause so the original parse/unpickle error survives.
                raise IOError(f"Cannot read scikit-optimize checkpoint file '{file}': {e}") from e

        checkpoint_saver = CheckpointSaver(str(file), compress=9, store_objective=False)
        optimizer_kwargs['callback'].append(checkpoint_saver)

    objective = SkoptObjective(app, kwargs, used_tuning_params, name, base_output_dir)
    results = optimizer(objective, search_space, **optimizer_kwargs)

    return results