Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from pathlib import Path 

2import wandb 

3from rich.console import Console 

4from rich.pretty import pprint 

5import math 

6from ..util import call_func 

7 

# Module-wide rich console used for the status messages printed by the tuning helpers below.
console: Console = Console()

9 

10 

def get_parameter_config(param) -> dict:
    """
    Build the wandb sweep ``parameters`` entry for a single tunable parameter.

    Args:
        param: A tuning-parameter description. The attributes read here are
            ``tune_choices``, ``annotation``, ``tune_min``, ``tune_max`` and ``tune_log``.

    Returns:
        dict: Either a categorical distribution over ``tune_choices``, or a uniform
        distribution between ``tune_min`` and ``tune_max``. For numeric ranges the
        distribution is log-scaled when ``tune_log`` is set and quantized (``q_`` prefix)
        when the annotation is ``int``.

    Raises:
        ValueError: If a numeric parameter is missing ``tune_min``/``tune_max``, or a
            log-scaled range has a non-positive ``tune_min``.
        NotImplementedError: If the parameter has no choices and is neither int nor float.
    """
    # An explicit list of choices takes precedence over any numeric range.
    if param.tune_choices:
        return dict(
            distribution="categorical",
            values=param.tune_choices,
        )

    if param.annotation in (int, float):
        # Raise rather than assert: assert statements are stripped when Python runs with -O.
        if param.tune_min is None or param.tune_max is None:
            raise ValueError(
                f"Tuning a numeric parameter of type {param.annotation.__name__} requires "
                "both 'tune_min' and 'tune_max' to be set."
            )
        # A log-uniform distribution is only defined for strictly positive bounds.
        if param.tune_log and param.tune_min <= 0:
            raise ValueError(
                f"Log-scaled tuning requires 'tune_min' > 0, got {param.tune_min!r}."
            )

        distribution = "log_uniform_values" if param.tune_log else "uniform"
        if param.annotation is int:
            # Integers use wandb's quantized variants (q_uniform / q_log_uniform_values).
            distribution = f"q_{distribution}"

        return dict(
            distribution=distribution,
            min=param.tune_min,
            max=param.tune_max,
        )

    raise NotImplementedError(
        f"Cannot build a wandb sweep distribution for a parameter of type "
        f"{param.annotation!r} without 'tune_choices'."
    )

33 

34 

def get_sweep_config(
    app,
    name: str = None,
    method: str = "",  # Should be enum
    min_iter: int = None,
    **kwargs,
):
    """
    Assemble the wandb sweep configuration dictionary for ``app``.

    Every entry of ``app.tuning_params()`` that is not pinned to a non-None value in
    ``kwargs`` becomes a sweep parameter built by ``get_parameter_config``.

    Args:
        app: The app whose ``tuning_params()``, ``monitor()`` and ``goal()`` are consulted.
        name (str, optional): Value stored under the config's ``"name"`` key.
        method (str, optional): Sampling method, one of 'grid', 'random' or 'bayes'.
            An empty value selects 'bayes'.
        min_iter (int, optional): When truthy, enables hyperband early termination
            with this ``min_iter``.
        **kwargs: Fixed argument values. Tuning parameters given here with a non-None
            value are left out of the sweep.

    Returns:
        dict: The sweep configuration (also pretty-printed to the console).

    Raises:
        NotImplementedError: If ``method`` is not a supported sampling method.
    """
    # Only parameters the caller left unset (missing or None) are swept.
    parameters_config = {
        key: get_parameter_config(param)
        for key, param in app.tuning_params().items()
        if kwargs.get(key) is None
    }

    method = method if method else "bayes"
    if method not in ("grid", "random", "bayes"):
        raise NotImplementedError(f"Cannot interpret sampling method '{method}' using wandb.")

    sweep_config = dict(name=name, method=method, parameters=parameters_config)

    # Attach the optimisation target only when the app declares a metric to monitor.
    metric_name = app.monitor()
    if metric_name:
        sweep_config["metric"] = dict(name=metric_name, goal=app.goal())

    if min_iter:
        sweep_config["early_terminate"] = dict(type="hyperband", min_iter=min_iter)

    console.print("Configuration for hyper-parameter tuning:", style="bold red")
    pprint(sweep_config)
    return sweep_config

66 

67 

def wandb_tune(
    app,
    sweep_id: str = None,
    name: str = None,
    method: str = "random",  # Should be enum
    runs: int = 1,
    min_iter: int = None,
    **kwargs,
) -> str:
    """
    Performs hyperparameter tuning using 'weights and biases' sweeps.

    If no ``sweep_id`` is given, a new sweep is created from ``app``'s tuning parameters.
    A wandb agent then executes ``runs`` training runs, each calling ``app.train`` with
    the hyperparameters sampled for that run.

    Args:
        app: The app to tune. Its tuning parameters, monitored metric and goal define the
            sweep, and ``app.train`` is called once per run.
        sweep_id(str, optional): The sweep ID, only necessary if sweep has already been
            generated for the project, defaults to None.
        name(str, optional): Used both as the sweep's name and as the wandb project for the
            sweep and its agent. If None, it is passed through unchanged and wandb applies
            its own defaults; no name is derived here.
        method(str, optional): The hyperparameter sweep method, can be 'random' for random,
            'grid' for grid search,
            and 'bayes' for bayes optimisation.
            Defaults to 'random'.
        runs(int, optional): The number of runs the agent executes. Defaults to 1.
        min_iter(int, optional): The minimum number of iterations if using early termination.
            If left empty, then early termination is not used.
        **kwargs: Arguments forwarded to ``app.train`` in every run. When a new sweep is
            created, tuning parameters given here with a non-None value are held fixed and
            excluded from the sweep. Values sampled by the sweep override entries here. If
            ``output_dir`` is set, each run writes into a subdirectory named after the run.

    Returns:
        str: The sweep id. This can be used by other runs in the same sweep.
    """
    # Create a sweep id if it hasn't been given as an argument
    if not sweep_id:
        sweep_config = get_sweep_config(
            app=app,
            name=name,
            method=method,
            min_iter=min_iter,
            **kwargs,
        )
        sweep_id = wandb.sweep(sweep_config, project=name)
        console.print(f"The wandb sweep id is: {sweep_id}", style="bold red")

    def agent_train():
        """Train ``app`` once with the hyperparameters wandb assigned to this run."""
        with wandb.init() as run:
            # Fixed arguments first; values sampled by the sweep take precedence.
            run_kwargs = dict(kwargs)
            run_kwargs.update(wandb.config)

            # Give each run its own subdirectory so runs in one sweep don't overwrite each
            # other. A None output_dir is left alone (Path(None) would raise TypeError), and
            # the run id is used if the run has no display name.
            output_dir = run_kwargs.get("output_dir")
            if output_dir is not None:
                run_kwargs["output_dir"] = Path(output_dir) / (run.name or run.id)

            console.print("Training with parameters:", style="bold red")
            pprint(run_kwargs)

            return call_func(app.train, **run_kwargs)

    wandb.agent(sweep_id, function=agent_train, count=runs, project=name)

    return sweep_id