Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from fastai.callback.core import Callback 

2import mlflow 

3import matplotlib 

4import pandas as pd 

5import pickle 

6from pathlib import Path 

7from typing import Union, Optional 

8from rich.console import Console 

9 

# Module-level rich Console used for status output (active run id, tracking URI).
console = Console()

11 

12 

def get_or_create_experiment(experiment_name: str) -> mlflow.entities.Experiment:
    """
    Returns an existing MLflow experiment if it exists, otherwise it creates a new one.

    Existing experiments are matched against ``experiment_name`` by name or by ID.
    When nothing matches, a new experiment *named* ``experiment_name`` is created,
    so an unknown ID string becomes the name of the new experiment.

    Args:
        experiment_name (str): The name or ID of the experiment.

    Returns:
        mlflow.entities.Experiment: The found or created experiment.
    """
    # `mlflow.list_experiments` was removed in MLflow 2.0 in favour of
    # `mlflow.search_experiments`; use whichever this MLflow version provides.
    search_experiments = getattr(mlflow, "search_experiments", None) or mlflow.list_experiments

    # Look for the experiment among the existing MLflow experiments.
    for experiment in search_experiments():
        # Experiment entities expose their ID as `experiment_id` (there is no `id`).
        if experiment_name in (experiment.name, experiment.experiment_id):
            return experiment

    # Not found: create a new one with that name. `create_experiment` returns only
    # the new experiment's ID, so fetch the entity to match the annotated return type.
    experiment_id = mlflow.create_experiment(name=experiment_name)
    return mlflow.get_experiment(experiment_id)

33 

34 

class TorchAppMlflowCallback(Callback):
    """fastai Callback that logs training to MLflow.

    On construction it optionally points MLflow at a local file store, selects
    (or creates) an experiment, starts a run and enables ``mlflow.fastai``
    autologging. The run is ended in ``after_fit``.

    Args:
        app: The application object; ``app.project_name()`` is used as the
            experiment name when ``experiment_name`` is not given.
        output_dir (Optional[Path]): If given, this directory is created (with
            parents) and used as a ``file://`` MLflow tracking URI.
        experiment_name (Optional[str]): Name or ID of the MLflow experiment.
        log_models (bool): Forwarded to ``mlflow.fastai.autolog``.
        **kwargs: Forwarded to the fastai ``Callback`` constructor.
    """

    def __init__(
        self,
        app,
        output_dir: Optional[Path] = None,
        experiment_name: Optional[str] = None,
        log_models: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.app = app

        # The tracking URI is only overridden when an output directory is given.
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(exist_ok=True, parents=True)
            # as_uri() yields a well-formed file URI on every platform
            # (e.g. 'file:///C:/...' on Windows) and percent-encodes special characters.
            mlflow.set_tracking_uri(output_dir.resolve().as_uri())

        # If no experiment_name is given, use the app's project name.
        if experiment_name is None:
            experiment_name = app.project_name()

        self.experiment = get_or_create_experiment(experiment_name)

        # MLflow Experiment entities expose their ID as `experiment_id` (there is no `id`).
        mlflow.start_run(experiment_id=self.experiment.experiment_id)

        mlflow.fastai.autolog(log_models=log_models)

        # Report which run and tracking store are in use.
        self.run = mlflow.active_run()
        console.print(f"Active run_id: {self.run.info.run_id}")

        tracking_uri = mlflow.get_tracking_uri()
        console.print(f"Current tracking URI: {tracking_uri}")

    def after_fit(self):
        """End the active MLflow run once fitting finishes."""
        # NOTE(review): the run is started in __init__ but ended here, so a second
        # fit() with this callback (e.g. via a method that presumably fits more than
        # once) would not log into the same run — confirm this is intended.
        mlflow.end_run()

    ################################################
    # NOTE(review): the helpers below were marked as unused by the original
    # author — confirm whether they should stay in this class.
    ################################################

    def log(self, param: dict, parameter_metric: bool = False, step: Optional[int] = None):
        """
        Log a dictionary to the active MLflow run, either as params or as metrics.

        Args:
            param (dict): Mapping of names to values to be logged.
            parameter_metric (bool, optional): If True, log the values as metrics
                (with ``step``); otherwise log them as params. Defaults to False.
            step (Optional[int], optional): Step passed to ``mlflow.log_metrics``;
                ignored when logging params. Defaults to None.
        """
        if parameter_metric:
            mlflow.log_metrics(param, step=step)
        else:
            mlflow.log_params(param)

    def log_artifact(
        self,
        artifact,
        artifact_path: Union[Path, str],
        **kwargs,
    ):
        """
        Save ``artifact`` to ``artifact_path``, choosing the method by the artifact's type.

        - pandas DataFrame: serialised as tab-separated text, logged with ``mlflow.log_text``.
        - dict: logged with ``mlflow.log_dict``.
        - str: logged with ``mlflow.log_text``.
        - matplotlib Figure: logged with ``mlflow.log_figure``.
        - anything else: pickled to the local file ``artifact_path``.

        Args:
            artifact: The object to be saved.
            artifact_path (Union[Path, str]): For the MLflow-logged types, the artifact
                file path within the run; for the pickle fallback, the local file written.
            **kwargs: Currently unused.
        """
        # A bare `import matplotlib` does not guarantee that the `matplotlib.figure`
        # submodule is loaded, so import it explicitly before the isinstance check.
        import matplotlib.figure

        if isinstance(artifact, pd.DataFrame):
            # to_csv(None, ...) returns the serialised text instead of writing a file.
            mlflow.log_text(artifact.to_csv(None, sep='\t'), artifact_path)
        elif isinstance(artifact, dict):
            mlflow.log_dict(artifact, artifact_path)
        elif isinstance(artifact, str):
            mlflow.log_text(artifact, artifact_path)
        elif isinstance(artifact, matplotlib.figure.Figure):
            mlflow.log_figure(artifact, artifact_path)
        # TODO: plotly figures could also be passed to mlflow.log_figure.
        else:
            # NOTE(review): unlike the branches above, this only writes a local file
            # and does not log it to MLflow — consider mlflow.log_artifact() on it.
            with open(artifact_path, 'wb') as handle:
                pickle.dump(artifact, handle)