Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from pathlib import Path 

2 

3try: 

4 import optuna 

5 from optuna.integration import FastAIV2PruningCallback 

6 from optuna import samplers 

7except: 

8 raise Exception( 

9 "No module named 'optuna'. Please install this as an extra dependency or choose a different optimization engine." 

10 ) 

11 

12from ..util import call_func 

13 

14 

def get_sampler(method, seed=0):
    """Return an Optuna sampler for the named sampling ``method``.

    Matching is case-insensitive and by prefix, so ``"TPE"`` and
    ``"tpe-anything"`` both select the TPE sampler.

    Args:
        method: Name of the sampling strategy: ``"tpe"`` (also used when
            ``method`` is empty or ``None``), ``"cma"`` (CMA-ES) or
            ``"random"``.
        seed: Random seed forwarded to the sampler constructor.

    Returns:
        A sampler from ``optuna.samplers``.

    Raises:
        NotImplementedError: If ``method`` does not name a supported sampler.
    """
    # Normalise None to "" so it falls back to the default TPE sampler,
    # exactly like the empty string, instead of crashing on ``None.lower()``.
    method = (method or "").lower()
    if not method or method.startswith("tpe"):
        return samplers.TPESampler(seed=seed)
    if method.startswith("cma"):
        return samplers.CmaEsSampler(seed=seed)
    if method.startswith("random"):
        return samplers.RandomSampler(seed=seed)
    # NOTE(review): grid sampling is deliberately not offered; optuna's
    # GridSampler needs an explicit search_space argument.
    raise NotImplementedError(f"Cannot interpret sampling method '{method}' using Optuna.")

27 

28 

def suggest(trial, name, param):
    """Ask ``trial`` for a value of the hyperparameter ``name``.

    Args:
        trial: The Optuna trial that records the suggestion.
        name: Hyperparameter name registered with the trial.
        param: Tuning-parameter description; this function reads its
            ``tune_choices``, ``annotation``, ``tune_min``, ``tune_max`` and
            ``tune_log`` attributes.

    Returns:
        One of ``param.tune_choices`` when choices are given. Otherwise a
        float or int between ``tune_min`` and ``tune_max`` (on a log scale
        when ``tune_log`` is true), depending on ``param.annotation``.

    Raises:
        NotImplementedError: If ``param`` has no choices and its annotation is
            neither ``float`` nor ``int``.
    """
    if param.tune_choices:
        return trial.suggest_categorical(name, param.tune_choices)
    if param.annotation == float:
        return trial.suggest_float(name, param.tune_min, param.tune_max, log=param.tune_log)
    if param.annotation == int:
        return trial.suggest_int(name, param.tune_min, param.tune_max, log=param.tune_log)
    # Must be an f-string: without the prefix the placeholders were emitted literally.
    raise NotImplementedError(f"Optuna Tuning Engine cannot understand param '{name}': {param}")

38 

39 

def optuna_tune(
    app,
    storage: str = "",
    name: str = None,
    method: str = "tpe",  # TODO: should be an enum
    runs: int = 1,
    seed: int = None,
    **kwargs,
):
    """Run an Optuna hyperparameter search over ``app``'s tuning parameters.

    Each trial copies ``kwargs`` and fills every tuning parameter the caller
    left unset (missing or ``None``) with a value suggested by Optuna. It then
    trains via ``call_func(app.train, ...)`` and returns
    ``app.get_best_metric(learner)`` as the objective value.

    Args:
        app: Application object providing ``tuning_params()``, ``train``,
            ``get_best_metric(learner)`` and ``goal()`` (used as the study
            direction).
        storage: Optuna storage. Empty means no storage (``None``). A bare
            name without ``"://"`` becomes the SQLite file
            ``<output_dir>/<storage>.sqlite3``. Anything else is passed to
            Optuna unchanged as a storage URL.
        name: Study name. An existing study with this name in ``storage`` is
            loaded and continued (``load_if_exists=True``).
        method: Sampler name understood by ``get_sampler`` (e.g. ``"tpe"``).
        runs: Number of trials to run in this call.
        seed: Seed forwarded to the sampler.
        **kwargs: Fixed arguments forwarded to ``app.train``. ``output_dir``
            (default ``"."``) is the root under which trial ``n`` writes to
            ``<output_dir>/<study_name>/trial-<n>``.

    Returns:
        The Optuna study after ``runs`` trials have been optimised.
    """
    # ``or "."`` (rather than a ``.get`` default) also covers an explicit
    # ``output_dir=None``, which this function otherwise treats as "unset";
    # ``Path(None)`` would raise TypeError.
    output_dir = Path(kwargs.get("output_dir") or ".")
    output_dir.mkdir(parents=True, exist_ok=True)

    def objective(trial: optuna.Trial):
        run_kwargs = dict(kwargs)

        # Only ask Optuna for parameters the caller did not pin to a value.
        for key, param in app.tuning_params().items():
            if kwargs.get(key) is None:
                run_kwargs[key] = suggest(trial, key, param)

        study_name = trial.study.study_name
        trial_name = f"trial-{trial.number}"

        # Read from run_kwargs so a suggested ``output_dir`` (if it is ever a
        # tuning parameter) is honoured; same None-safe fallback as above.
        base_dir = Path(run_kwargs.get("output_dir") or ".")
        # Give each trial its own directory under the study.
        run_kwargs["output_dir"] = base_dir / study_name / trial_name
        run_kwargs["project_name"] = study_name
        run_kwargs["run_name"] = trial_name

        learner = call_func(app.train, **run_kwargs)
        return app.get_best_metric(learner)

    if not storage:
        storage = None
    elif "://" not in storage:
        # A bare name is shorthand for a SQLite file next to the outputs.
        storage_path = output_dir / f"{storage}.sqlite3"
        storage = f"sqlite:///{storage_path.resolve()}"

    study = optuna.create_study(
        study_name=name,
        storage=storage,
        sampler=get_sampler(method, seed=seed),
        load_if_exists=True,
        direction=app.goal(),
    )
    study.optimize(objective, n_trials=runs)
    return study