from fastai.callback.core import Callback
import mlflow
import matplotlib.figure
import pandas as pd
import pickle
from pathlib import Path
from typing import Union, Optional
from rich.console import Console

console = Console()


def get_or_create_experiment(experiment_name: str) -> mlflow.entities.Experiment:
    """
    Returns an existing MLflow experiment if it exists, otherwise creates a new one.

    Args:
        experiment_name (str): The name or ID of the experiment.

    Returns:
        mlflow.entities.Experiment: The found or created experiment.
    """
    # Look for the experiment among the existing MLflow experiments (by name or ID)
    for experiment in mlflow.list_experiments():
        if experiment_name in (experiment.name, experiment.experiment_id):
            return experiment

    # If not found, create a new experiment with that name and return the full entity
    experiment_id = mlflow.create_experiment(name=experiment_name)
    return mlflow.get_experiment(experiment_id)
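
# Example (sketch): looking up or creating an experiment by name.
# The experiment name below is illustrative only; in practice it comes from the app's project name.
#
#     experiment = get_or_create_experiment("torchapp-demo")
#     console.print(experiment.experiment_id, experiment.name)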


class TorchAppMlflowCallback(Callback):
    def __init__(
        self,
        app,
        output_dir: Optional[Path] = None,
        experiment_name: Optional[str] = None,
        log_models: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.app = app

        # The tracking URI is set if an output directory is given
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(exist_ok=True, parents=True)
            mlflow.set_tracking_uri(f'file://{output_dir.resolve()}')

        # If no experiment_name is given, use the app's project name
        if experiment_name is None:
            experiment_name = app.project_name()

        self.experiment = get_or_create_experiment(experiment_name)

        mlflow.start_run(experiment_id=self.experiment.experiment_id)

        mlflow.fastai.autolog(log_models=log_models)

        # Report the active run and tracking URI for debugging
        self.run = mlflow.active_run()
        console.print(f"Active run_id: {self.run.info.run_id}")

        tracking_uri = mlflow.get_tracking_uri()
        console.print(f"Current tracking URI: {tracking_uri}")

    def after_fit(self):
        mlflow.end_run()

    ################################################
    # The code below isn't currently used. Should it stay here?
    ################################################
    def log(self, param: dict, parameter_metric: bool = False, step: Optional[int] = None):
        """
        Log a dictionary of parameters.

        If `parameter_metric` is True, the dictionary is logged as a set of metrics, optionally at a given `step`.

        Args:
            param (dict): The dictionary of parameters to be logged.
            parameter_metric (bool, optional): If True, log as a set of metrics. Defaults to False.
            step (Optional[int], optional): The metric step at which to log. Defaults to None.
        """
        if parameter_metric:
            mlflow.log_metrics(param, step=step)
        else:
            mlflow.log_params(param)
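
    # Example (sketch): logging a dictionary either as parameters or as metrics.
    # The keys and values below are illustrative only.
    #
    #     callback.log({"batch_size": 32, "lr": 1e-3})                     # logged as params
    #     callback.log({"val_loss": 0.42}, parameter_metric=True, step=1)  # logged as a metric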

    def log_artifact(
        self,
        artifact,
        artifact_path: Union[Path, str],
        **kwargs,
    ):
        """
        Log an artifact (pandas DataFrame, dict, str, or matplotlib figure) to the active MLflow run.

        Any other type of artifact is pickled to a local file at `artifact_path` instead.

        Args:
            artifact: The artifact to be logged.
            artifact_path (Union[Path, str]): The destination path for the artifact.
        """
        if isinstance(artifact, pd.DataFrame):
            csv_file = artifact.to_csv(None, sep='\t')
            mlflow.log_text(csv_file, artifact_path)

        elif isinstance(artifact, dict):
            mlflow.log_dict(artifact, artifact_path)

        elif isinstance(artifact, str):
            mlflow.log_text(artifact, artifact_path)

        elif isinstance(artifact, matplotlib.figure.Figure):
            mlflow.log_figure(artifact, artifact_path)

        # elif isinstance(artifact, plotly.graph_objs._figure.Figure):
        #     mlflow.log_figure(artifact, artifact_path)

        # elif isinstance(artifact, ):
        #     mlflow.log_artifact()

        else:
            with open(artifact_path, 'wb') as file:
                pickle.dump(artifact, file)
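

# Example (sketch, assuming a torchapp-style `app` object that provides `project_name()`):
# `my_app`, `results_df` and the fastai learner below are illustrative only.
#
#     from fastai.vision.all import untar_data, URLs, ImageDataLoaders, vision_learner, resnet18
#
#     dls = ImageDataLoaders.from_folder(untar_data(URLs.MNIST_SAMPLE))
#     callback = TorchAppMlflowCallback(app=my_app, output_dir="mlflow-runs")
#     learner = vision_learner(dls, resnet18, cbs=[callback])
#     learner.fit(1)  # fastai autologging records params/metrics; `after_fit` ends the run
#
#     # Extra artifacts can be attached to the same run before training finishes:
#     callback.log_artifact(results_df, "results.csv")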