from pathlib import Path
from typing import Optional

import wandb
from rich.console import Console
from rich.pretty import pprint

from ..util import call_func

console = Console()


def get_parameter_config(param) -> dict:
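    """
    Builds the wandb sweep distribution specification for a single tunable parameter.

    Categorical choices take precedence; numeric parameters are swept over a
    uniform or log-uniform range, quantized (the 'q_' prefix) for integers.
    """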
    if param.tune_choices:
        return dict(
            distribution="categorical",
            values=param.tune_choices,
        )
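
    # Numeric parameters are swept over an explicit [tune_min, tune_max] range.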
    if param.annotation in [int, float]:
        assert param.tune_min is not None
        assert param.tune_max is not None

        distribution = "log_uniform_values" if param.tune_log else "uniform"
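        # wandb uses a 'q_' prefix for quantized (integer-valued) distributions.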
        if param.annotation == int:
            distribution = f"q_{distribution}"

        return dict(
            distribution=distribution,
            min=param.tune_min,
            max=param.tune_max,
        )

    raise NotImplementedError(
        f"Cannot create a tuning configuration for parameter with annotation '{param.annotation}'."
    )


def get_sweep_config(
    app,
    name: Optional[str] = None,
    method: str = "",  # Should be enum
    min_iter: Optional[int] = None,
    **kwargs,
):
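    """
    Builds the configuration dictionary for a wandb hyperparameter sweep.

    Any tuning parameter given an explicit value through ``kwargs`` is treated
    as fixed and excluded from the sweep's search space.
    """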
    parameters_config = dict()
    tuning_params = app.tuning_params()
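
    # Parameters fixed explicitly via kwargs are excluded from the search space.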
    for key, value in tuning_params.items():
        if key not in kwargs or kwargs[key] is None:
            parameters_config[key] = get_parameter_config(value)

    method = method or "bayes"
    if method not in ["grid", "random", "bayes"]:
        raise NotImplementedError(f"Cannot interpret sampling method '{method}' using wandb.")

    sweep_config = {
        "name": name,
        "method": method,
        "parameters": parameters_config,
    }
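    # Optimise towards the metric the app monitors, in the direction given by app.goal().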
    if app.monitor():
        sweep_config["metric"] = dict(name=app.monitor(), goal=app.goal())
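
    # Hyperband early termination can prune unpromising runs; `min_iter` sets
    # the earliest iteration at which pruning is considered.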
    if min_iter:
        sweep_config["early_terminate"] = dict(type="hyperband", min_iter=min_iter)

    console.print("Configuration for hyper-parameter tuning:", style="bold red")
    pprint(sweep_config)
    return sweep_config


def wandb_tune(
    app,
    sweep_id: Optional[str] = None,
    name: Optional[str] = None,
    method: str = "random",  # Should be enum
    runs: int = 1,
    min_iter: Optional[int] = None,
    **kwargs,
) -> str:
77 """
78 Performs hyperparameter tuning using 'weights and biases' sweeps.
80 Args:
81 sweep_id(str, optional): The sweep ID, only necessary if sweep has already been generated for the project, defaults to None.
82 name(str, optional): The name of the sweep run. This defaults to the project name with the suffix '-tuning' if left as None.
83 method(str, optional): The hyperparameter sweep method, can be 'random' for random,
84 'grid' for grid search,
85 and 'bayes' for bayes optimisation.
86 Defaults to 'random'.
87 runs(int, optional): The number of runs. Defaults to 1.
88 min_iter(int, optional): The minimum number of iterations if using early termination. If left empty, then early termination is not used.
89 **kwargs:
91 Returns:
92 str: The sweep id. This can be used by other runs in the same sweep.
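
    Example (illustrative sketch; assumes ``app`` exposes the ``tuning_params``,
    ``monitor``, ``goal`` and ``train`` interface used by this module):

        sweep_id = wandb_tune(app, name="my-project", method="bayes", runs=10)
        # Additional workers can join the same sweep by passing its ID:
        wandb_tune(app, sweep_id=sweep_id, name="my-project", runs=10)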
93 """
    # Create a sweep ID if it hasn't been given as an argument
    if not sweep_id:
        sweep_config = get_sweep_config(
            app=app,
            name=name,
            method=method,
            min_iter=min_iter,
            **kwargs,
        )
        sweep_id = wandb.sweep(sweep_config, project=name)
    console.print(f"The wandb sweep id is: {sweep_id}", style="bold red")
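
    # Each trial merges the fixed kwargs with the hyperparameter values wandb
    # samples for this run, then hands them to the app's training function.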
    def agent_train():
        with wandb.init() as run:
            run_kwargs = dict(kwargs)
            run_kwargs.update(wandb.config)
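            # Give each run its own output subdirectory so that parallel runs
            # in the sweep do not overwrite each other's results.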
110 if "output_dir" in run_kwargs:
111 run_kwargs["output_dir"] = Path(run_kwargs["output_dir"]) / run.name
113 console.print("Training with parameters:", style="bold red")
114 pprint(run_kwargs)
116 return call_func(app.train, **run_kwargs)
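
    # Launch an agent that executes `runs` trials from this sweep.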
    wandb.agent(sweep_id, function=agent_train, count=runs, project=name)

    return sweep_id