from contextlib import nullcontext
from pathlib import Path
from types import MethodType
from typing import List, Optional, Union, Dict
import inspect
import hashlib
from appdirs import user_cache_dir
import torch
from torch import nn
from fastai.learner import Learner, load_learner, load_model
from fastai.data.core import DataLoaders
from fastai.callback.schedule import fit_one_cycle
# from fastai.distributed import distrib_ctx
from fastai.callback.tracker import SaveModelCallback
from fastai.callback.progress import CSVLogger
import click
import typer
from typer.main import get_params_convertors_ctx_param_name_from_function
from rich.console import Console
from rich.traceback import install
from rich.table import Table
from rich.box import SIMPLE

install()
console = Console()

from .citations import Citable
from .util import copy_func, call_func, change_typer_to_defaults, add_kwargs
from .params import Param
from .callbacks import TorchAppWandbCallback, TorchAppMlflowCallback
from .download import cached_download

bibtex_dir = Path(__file__).parent / "bibtex"


class TorchAppInitializationError(Exception):
    pass


class TorchApp(Citable):
    torchapp_initialized = False
    fine_tune = False

    def __init__(self):
        super().__init__()

        # Make deep copies of methods so that we can change the function signatures dynamically
        self.fit = self.copy_method(self.fit)
        self.train = self.copy_method(self.train)
        self.dataloaders = self.copy_method(self.dataloaders)
        self.model = self.copy_method(self.model)
        self.pretrained_location = self.copy_method(self.pretrained_location)
        self.show_batch = self.copy_method(self.show_batch)
        self.tune = self.copy_method(self.tune)
        self.pretrained_local_path = self.copy_method(self.pretrained_local_path)
        self.learner_kwargs = self.copy_method(self.learner_kwargs)
        self.learner = self.copy_method(self.learner)
        self.export = self.copy_method(self.export)
        self.__call__ = self.copy_method(self.__call__)
        self.validate = self.copy_method(self.validate)
        self.callbacks = self.copy_method(self.callbacks)
        self.extra_callbacks = self.copy_method(self.extra_callbacks)
        self.inference_callbacks = self.copy_method(self.inference_callbacks)
        self.one_batch_size = self.copy_method(self.one_batch_size)
        self.one_batch_output = self.copy_method(self.one_batch_output)
        self.one_batch_output_size = self.copy_method(self.one_batch_output_size)
        self.one_batch_loss = self.copy_method(self.one_batch_loss)
        self.loss_func = self.copy_method(self.loss_func)
        self.metrics = self.copy_method(self.metrics)
        self.lr_finder = self.copy_method(self.lr_finder)
        self.inference_dataloader = self.copy_method(self.inference_dataloader)
        self.output_results = self.copy_method(self.output_results)

        # Add keyword arguments to the signatures of the methods used in the CLI
        add_kwargs(to_func=self.learner_kwargs, from_funcs=[self.metrics, self.loss_func])
        add_kwargs(to_func=self.learner, from_funcs=[self.learner_kwargs, self.dataloaders, self.model])
        add_kwargs(to_func=self.callbacks, from_funcs=[self.extra_callbacks])
        add_kwargs(to_func=self.export, from_funcs=[self.learner, self.callbacks])
        add_kwargs(to_func=self.train, from_funcs=[self.learner, self.fit, self.callbacks])
        add_kwargs(to_func=self.show_batch, from_funcs=self.dataloaders)
        add_kwargs(to_func=self.tune, from_funcs=self.train)
        add_kwargs(to_func=self.pretrained_local_path, from_funcs=self.pretrained_location)
        add_kwargs(
            to_func=self.__call__,
            from_funcs=[self.pretrained_local_path, self.inference_dataloader, self.output_results, self.inference_callbacks],
        )
        add_kwargs(to_func=self.validate, from_funcs=[self.pretrained_local_path, self.dataloaders])
        add_kwargs(to_func=self.one_batch_size, from_funcs=self.dataloaders)
        add_kwargs(to_func=self.one_batch_output, from_funcs=self.learner)
        add_kwargs(to_func=self.one_batch_loss, from_funcs=self.learner)
        add_kwargs(to_func=self.lr_finder, from_funcs=self.learner)
        add_kwargs(to_func=self.one_batch_output_size, from_funcs=self.one_batch_output)

        # Make copies of methods to use just for the CLI
        self.export_cli = self.copy_method(self.export)
        self.train_cli = self.copy_method(self.train)
        self.show_batch_cli = self.copy_method(self.show_batch)
        self.tune_cli = self.copy_method(self.tune)
        self.pretrained_local_path_cli = self.copy_method(self.pretrained_local_path)
        self.infer_cli = self.copy_method(self.__call__)
        self.validate_cli = self.copy_method(self.validate)
        self.lr_finder_cli = self.copy_method(self.lr_finder)

        # Remove params from defaults in methods not used for the CLI
        change_typer_to_defaults(self.fit)
        change_typer_to_defaults(self.model)
        change_typer_to_defaults(self.learner_kwargs)
        change_typer_to_defaults(self.loss_func)
        change_typer_to_defaults(self.metrics)
        change_typer_to_defaults(self.export)
        change_typer_to_defaults(self.learner)
        change_typer_to_defaults(self.callbacks)
        change_typer_to_defaults(self.extra_callbacks)
        change_typer_to_defaults(self.train)
        change_typer_to_defaults(self.show_batch)
        change_typer_to_defaults(self.tune)
        change_typer_to_defaults(self.pretrained_local_path)
        change_typer_to_defaults(self.__call__)
        change_typer_to_defaults(self.validate)
        change_typer_to_defaults(self.dataloaders)
        change_typer_to_defaults(self.pretrained_location)
        change_typer_to_defaults(self.one_batch_size)
        change_typer_to_defaults(self.one_batch_output_size)
        change_typer_to_defaults(self.one_batch_output)
        change_typer_to_defaults(self.one_batch_loss)
        change_typer_to_defaults(self.lr_finder)
        change_typer_to_defaults(self.inference_dataloader)
        change_typer_to_defaults(self.inference_callbacks)
        change_typer_to_defaults(self.output_results)

        # Store a bool to let the app know later on (in self.assert_initialized)
        # that __init__ has been called on this parent class
        self.torchapp_initialized = True
        self.learner_obj = None
        # self.console = console

    def __str__(self):
        return self.__class__.__name__

    def get_bibtex_files(self):
        return [
            bibtex_dir / "fastai.bib",
            bibtex_dir / "torchapp.bib",
        ]

    def copy_method(self, method):
        return MethodType(copy_func(method.__func__), self)

    def pretrained_location(self) -> Union[str, Path]:
        """
        The location of a pretrained model.

        It can be a URL, in which case it will need to be downloaded.
        Alternatively, it can be part of the package bundle, in which case
        it needs to be a path relative to the directory that contains the code defining the app.

        By default this method returns an empty string.
        Subclasses need to override this method to use pretrained models.

        Returns:
            Union[str, Path]: The location of the pretrained model.
        """
        return ""
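
    # Example (illustrative sketch, not part of the library): a subclass might override
    # `pretrained_location` to point at a released checkpoint. The class name and URL below
    # are hypothetical.
    #
    #   class MyApp(TorchApp):
    #       def pretrained_location(self) -> str:
    #           return "https://example.com/releases/my-app-weights.pkl"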

    def pretrained_local_path(
        self,
        pretrained: str = Param(default=None, help="The location (URL or filepath) of a pretrained model."),
        reload: bool = Param(
            default=False,
            help="Whether to download the pretrained model again if it is online and already present locally.",
        ),
        **kwargs,
    ) -> Path:
        """
        The local path of the pretrained model.

        If it is a URL, then it is downloaded.
        If it is a relative path, then this method returns the absolute path to it.

        Args:
            pretrained (str, optional): The location (URL or filepath) of a pretrained model. If it is a relative path, then it is relative to the current working directory. Defaults to using the result of the `pretrained_location` method.
            reload (bool, optional): Whether to download the pretrained model again if it is online and already present locally. Defaults to False.

        Raises:
            FileNotFoundError: If the file cannot be located in the local environment.

        Returns:
            Path: The absolute path to the model on the local filesystem.
        """
        if pretrained:
            location = pretrained
            base_dir = Path.cwd()
        else:
            location = str(call_func(self.pretrained_location, **kwargs))
            module = inspect.getmodule(self)
            base_dir = Path(module.__file__).parent.resolve()

        if not location:
            raise FileNotFoundError("Please pass in a pretrained model.")

        # Check if the model needs to be downloaded
        location = str(location)
        if location.startswith("http"):
            name = location.split("/")[-1]
            extension_location = name.rfind(".")
            if extension_location != -1:
                name_stem = name[:extension_location]
                extension = name[extension_location:]
            else:
                name_stem = name
                extension = ".dat"
            url_hash = hashlib.md5(location.encode()).hexdigest()
            path = self.cache_dir() / f"{name_stem}-{url_hash}{extension}"
            cached_download(location, path, force=reload)
        else:
            path = Path(location)
            if not path.is_absolute():
                path = base_dir / path

        if not path or not path.is_file():
            raise FileNotFoundError(f"Cannot find pretrained model at '{path}'")

        return path

    def prepare_source(self, data):
        return data

    def inference_dataloader(self, learner, **kwargs):
        dataloader = learner.dls.test_dl(**kwargs)
        return dataloader

    def validate(
        self,
        gpu: bool = Param(True, help="Whether or not to use a GPU for processing if available."),
        **kwargs,
    ):
        path = call_func(self.pretrained_local_path, **kwargs)

        # Check if CUDA is available
        gpu = gpu and torch.cuda.is_available()

        try:
            learner = load_learner(path, cpu=not gpu)
        except Exception:
            import dill
            learner = load_learner(path, cpu=not gpu, pickle_module=dill)

        # Build the dataloaders and run validation on the validation set
        dataloaders = call_func(self.dataloaders, **kwargs)

        table = Table(title="Validation", box=SIMPLE)

        values = learner.validate(dl=dataloaders.valid)
        names = [learner.recorder.loss.name] + [metric.name for metric in learner.metrics]
        result = {name: value for name, value in zip(names, values)}

        table.add_column("Metric", justify="right", style="cyan", no_wrap=True)
        table.add_column("Value", style="magenta")

        for name, value in result.items():
            table.add_row(name, str(value))

        console.print(table)

        return result

    def __call__(
        self,
        gpu: bool = Param(True, help="Whether or not to use a GPU for processing if available."),
        **kwargs
    ):
        # Check if CUDA is available
        gpu = gpu and torch.cuda.is_available()

        # Open the exported learner from a pickle file
        path = call_func(self.pretrained_local_path, **kwargs)
        learner = self.learner_obj = load_learner(path, cpu=not gpu)

        # Create a dataloader for inference
        dataloader = call_func(self.inference_dataloader, learner, **kwargs)

        inference_callbacks = call_func(self.inference_callbacks, **kwargs)

        results = learner.get_preds(
            dl=dataloader, reorder=False, with_decoded=False, act=self.activation(), cbs=inference_callbacks
        )

        # Output results
        output_results = call_func(self.output_results, results, **kwargs)
        return output_results if output_results is not None else results

    def inference_callbacks(self):
        return None

    @classmethod
    def main(cls, inference_only: bool = False):
        """
        Creates an instance of this class and runs the command-line interface.
        """
        cli = cls.click(inference_only=inference_only)
        return cli()

    @classmethod
    def inference_only_main(cls):
        """
        Creates an instance of this class and runs the command-line interface for only the inference command.
        """
        return cls.main(inference_only=True)

    @classmethod
    def click(cls, inference_only: bool = False):
        """
        Creates an instance of this class and returns the click object for the command-line interface.
        """
        self = cls()
        cli = self.cli(inference_only=inference_only)
        return cli

    @classmethod
    def inference_only_click(cls):
        """
        Creates an instance of this class and returns the click object for the inference-only command-line interface.
        """
        return cls.click(inference_only=True)
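
    # Example (illustrative sketch): a subclass module typically exposes its CLI by calling
    # `main` (or `inference_only_main`) from a script entry point. `MyApp` is a hypothetical
    # subclass; the console_scripts target shown in the comment is likewise an assumption.
    #
    #   if __name__ == "__main__":
    #       MyApp.main()
    #
    #   # or point a console_scripts entry point (e.g. "my-app = mypackage.apps:MyApp.main")
    #   # at the classmethod so the command is installed with the package.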

    def assert_initialized(self):
        """
        Asserts that this app has been initialized.

        All sub-classes of TorchApp need to call super().__init__() if overriding the __init__() function.

        Raises:
            TorchAppInitializationError: If the app has not been properly initialized.
        """
        if not self.torchapp_initialized:
            raise TorchAppInitializationError(
                """The initialization function for this TorchApp has not been called.
                Please ensure sub-classes of TorchApp call 'super().__init__()'"""
            )

    def version(self, verbose: bool = False):
        """
        Prints the version of the package that defines this app.

        Used in the command-line interface.

        Args:
            verbose (bool, optional): Whether or not to print the version to stdout. Defaults to False.

        Raises:
            Exception: If the package cannot be found.
        """
        if verbose:
            from importlib import metadata

            module = inspect.getmodule(self)
            package = ""
            if module.__package__:
                package = module.__package__.split('.')[0]
            else:
                path = Path(module.__file__).parent
                while path.name:
                    try:
                        if metadata.distribution(path.name):
                            package = path.name
                            break
                    except Exception:
                        pass
                    path = path.parent

            if package:
                version = metadata.version(package)
                print(version)
            else:
                raise Exception("Cannot find package.")

            raise typer.Exit()

    def cli(self, inference_only: bool = False):
        """
        Returns a 'Click' object which defines the command-line interface of the app.
        """
        self.assert_initialized()

        cli = typer.Typer()

        @cli.callback()
        def base_callback(
            version: Optional[bool] = typer.Option(
                None,
                "--version",
                "-v",
                callback=self.version,
                is_eager=True,
                help="Prints the current version.",
            ),
        ):
            pass

        typer_click_object = typer.main.get_command(cli)

        params, _, _ = get_params_convertors_ctx_param_name_from_function(self.infer_cli)
        command = click.Command(
            name="infer",
            callback=self.infer_cli,
            params=params,
        )
        if inference_only:
            return command
        typer_click_object.add_command(command)

        train_params, _, _ = get_params_convertors_ctx_param_name_from_function(self.train_cli)
        train_command = click.Command(
            name="train",
            callback=self.train_cli,
            params=train_params,
        )
        typer_click_object.add_command(train_command)

        export_params, _, _ = get_params_convertors_ctx_param_name_from_function(self.export_cli)
        export_command = click.Command(
            name="export",
            callback=self.export_cli,
            params=export_params,
        )
        typer_click_object.add_command(export_command)

        show_batch_params, _, _ = get_params_convertors_ctx_param_name_from_function(self.show_batch_cli)
        command = click.Command(
            name="show-batch",
            callback=self.show_batch_cli,
            params=show_batch_params,
        )
        typer_click_object.add_command(command)

        params, _, _ = get_params_convertors_ctx_param_name_from_function(self.tune_cli)
        tuning_params = self.tuning_params()
        for param in params:
            if param.name in tuning_params:
                param.default = None
        command = click.Command(
            name="tune",
            callback=self.tune_cli,
            params=params,
        )
        typer_click_object.add_command(command)

        params, _, _ = get_params_convertors_ctx_param_name_from_function(self.validate_cli)
        command = click.Command(
            name="validate",
            callback=self.validate_cli,
            params=params,
        )
        typer_click_object.add_command(command)

        params, _, _ = get_params_convertors_ctx_param_name_from_function(self.lr_finder_cli)
        command = click.Command(
            name="lr-finder",
            callback=self.lr_finder_cli,
            params=params,
        )
        typer_click_object.add_command(command)

        command = click.Command(
            name="bibliography",
            callback=self.print_bibliography,
        )
        typer_click_object.add_command(command)

        command = click.Command(
            name="bibtex",
            callback=self.print_bibtex,
        )
        typer_click_object.add_command(command)

        return typer_click_object

    def tuning_params(self):
        tuning_params = {}
        signature = inspect.signature(self.tune_cli)

        for key, value in signature.parameters.items():
            default_value = value.default
            if isinstance(default_value, Param) and default_value.tune == True:
                # Override annotation if given in typing hints
                if value.annotation:
                    default_value.annotation = value.annotation

                default_value.check_choices()

                tuning_params[key] = default_value

        return tuning_params

    def dataloaders(self):
        raise NotImplementedError(
            f"Please ensure that the 'dataloaders' method is implemented in {self.__class__.__name__}."
        )

    def model(self) -> nn.Module:
        raise NotImplementedError(f"Please ensure that the 'model' method is implemented in {self.__class__.__name__}.")
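
    # Example (illustrative sketch): the minimal subclass implements `dataloaders` and `model`.
    # The class name, parameters, and placeholder data below are hypothetical; fastai's
    # `DataLoaders.from_dsets` is assumed for building dataloaders from in-memory datasets.
    #
    #   class MyApp(TorchApp):
    #       def dataloaders(self, batch_size: int = Param(default=32, help="The batch size.")):
    #           train_dataset = [(torch.randn(10), torch.randn(1)) for _ in range(64)]  # placeholder data
    #           valid_dataset = [(torch.randn(10), torch.randn(1)) for _ in range(16)]
    #           return DataLoaders.from_dsets(train_dataset, valid_dataset, bs=batch_size)
    #
    #       def model(self) -> nn.Module:
    #           return nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))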

    def build_learner_func(self):
        return Learner

    def learner(
        self,
        fp16: bool = Param(
            default=True,
            help="Whether or not the floating-point precision of learner should be set to 16 bit.",
        ),
        **kwargs,
    ) -> Learner:
        """
        Creates a fastai learner object.
        """
        console.print("Building dataloaders", style="bold")
        dataloaders = call_func(self.dataloaders, **kwargs)

        # Allow the dataloaders to go to GPU so long as it hasn't explicitly been set as a different device
        if dataloaders.device is None:
            dataloaders.cuda()  # This will revert to CPU if cuda is not available

        console.print("Building model", style="bold")
        model = call_func(self.model, **kwargs)

        console.print("Building learner", style="bold")
        learner_kwargs = call_func(self.learner_kwargs, **kwargs)
        build_learner_func = self.build_learner_func()
        learner = build_learner_func(
            dataloaders,
            model,
            **learner_kwargs,
        )

        learner.training_kwargs = kwargs

        if fp16:
            console.print("Setting floating-point precision of learner to 16 bit", style="red")
            learner = learner.to_fp16()

        # Save a pointer to the learner
        self.learner_obj = learner

        return learner

    def learner_kwargs(
        self,
        output_dir: Path = Param("./outputs", help="The location of the output directory."),
        weight_decay: float = Param(
            None, help="The amount of weight decay. If None then it uses the default amount of weight decay in fastai."
        ),
        # l2_regularization: bool = Param(False, help="Whether to add decay to the gradients (L2 regularization) instead of to the weights directly (weight decay)."),
        **kwargs,
    ):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)

        return dict(
            loss_func=call_func(self.loss_func, **kwargs),
            metrics=call_func(self.metrics, **kwargs),
            path=self.output_dir,
            wd=weight_decay,
        )

    def loss_func(self, **kwargs):
        """The loss function. If None, then fastai will use the default loss function if it exists for this model."""
        return None

    def activation(self):
        """The activation for the last layer. If None, then fastai will use the default activation of the loss if it exists."""
        return None

    def metrics(self) -> List:
        """
        The list of metrics to use with this app.

        By default this list is empty. This method should be subclassed to add metrics in child classes of TorchApp.

        Returns:
            List: The list of metrics.
        """
        return []
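
    # Example (illustrative sketch): a classification app might report accuracy and also use it
    # for model selection by overriding `monitor`. `accuracy` from fastai.metrics is assumed to
    # be available; `MyApp` is hypothetical.
    #
    #   from fastai.metrics import accuracy
    #
    #   class MyApp(TorchApp):
    #       def metrics(self) -> List:
    #           return [accuracy]
    #
    #       def monitor(self) -> str:
    #           return "accuracy"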

    def monitor(self) -> str:
        """
        The metric to optimize for when performing hyperparameter tuning.

        By default it returns 'valid_loss'.
        """
        return "valid_loss"

    def goal(self) -> str:
        """
        Sets the optimality direction when evaluating the metric from `monitor`.

        By default it produces the same behaviour as fastai callbacks (fastai.callback.tracker),
        i.e. it is set to "minimize" if the monitored metric name contains 'loss' or 'err', otherwise it is "maximize".

        If the monitor is empty then this function returns None.
        """
        monitor = self.monitor()
        if not monitor or not isinstance(monitor, str):
            return None

        return "minimize" if ("loss" in monitor) or ("err" in monitor) else "maximize"

    def callbacks(
        self,
        project_name: str = Param(default=None, help="The name for this project for logging purposes."),
        run_name: str = Param(default=None, help="The name for this particular run for logging purposes."),
        run_id: str = Param(default=None, help="A unique ID for this particular run for logging purposes."),
        notes: str = Param(None, help="A longer description of the run for logging purposes."),
        tag: List[str] = Param(
            None, help="A tag for logging purposes. Multiple tags can be added, each introduced with --tag."
        ),
        wandb: bool = Param(default=False, help="Whether or not to use 'Weights and Biases' for logging."),
        wandb_mode: str = Param(default="online", help="The mode for 'Weights and Biases'."),
        wandb_dir: Path = Param(None, help="The location for 'Weights and Biases' output."),
        wandb_entity: str = Param(None, help="An entity is a username or team name where you're sending runs."),
        wandb_group: str = Param(None, help="Specify a group to organize individual runs into a larger experiment."),
        wandb_job_type: str = Param(
            None,
            help="Specify the type of run, which is useful when you're grouping runs together into larger experiments using group.",
        ),
        mlflow: bool = Param(default=False, help="Whether or not to use MLflow for logging."),
        **kwargs,
    ) -> List:
        """
        The list of callbacks to use with this app in the fastai training loop.

        Args:
            project_name (str): The name for this project for logging purposes. If no name is given then the name of the app is used.

        Returns:
            list: The list of callbacks.
        """
        callbacks = [CSVLogger()]
        monitor = self.monitor()
        if monitor:
            callbacks.append(SaveModelCallback(monitor=monitor))

        if wandb:
            callback = TorchAppWandbCallback(
                app=self,
                project_name=project_name,
                id=run_id,
                name=run_name,
                mode=wandb_mode,
                dir=wandb_dir,
                entity=wandb_entity,
                group=wandb_group,
                job_type=wandb_job_type,
                notes=notes,
                tags=tag,
            )
            callbacks.append(callback)
            self.add_bibtex_file(bibtex_dir / "wandb.bib")  # this should be in the callback

        if mlflow:
            callbacks.append(TorchAppMlflowCallback(app=self, experiment_name=project_name))
            self.add_bibtex_file(bibtex_dir / "mlflow.bib")  # this should be in the callback

        extra_callbacks = call_func(self.extra_callbacks, **kwargs)
        if extra_callbacks:
            callbacks += extra_callbacks

        return callbacks

    def extra_callbacks(self):
        return None

    def show_batch(
        self, output_path: Path = Param(None, help="A location to save the HTML which summarizes the batch."), **kwargs
    ):
        dataloaders = call_func(self.dataloaders, **kwargs)

        # Patch the display function of IPython so we can capture the HTML
        def mock_display(html_object):
            self.batch_html = html_object

        import IPython.display

        ipython_display = IPython.display.display
        IPython.display.display = mock_display

        dataloaders.show_batch()
        batch_html = getattr(self, "batch_html", None)

        if not batch_html:
            console.print("Cannot display batch as HTML")
            return

        html = batch_html.data

        # Write output
        if output_path:
            console.print(f"Writing batch HTML to '{output_path}'")
            with open(output_path, 'w') as f:
                f.write(html)
        else:
            console.print(html)
            console.print("To write this HTML output to a file, give an output path.")

        # Restore the IPython display function
        IPython.display.display = ipython_display
        return self.batch_html

    def train(
        self,
        distributed: bool = Param(default=False, help="If the learner is distributed."),
        **kwargs,
    ) -> Learner:
        """
        Trains a model for this app.

        Args:
            distributed (bool, optional): If the learner is distributed. Defaults to False.

        Returns:
            Learner: The fastai Learner object created for training.
        """
        self.training_kwargs = kwargs
        self.training_kwargs['distributed'] = distributed

        callbacks = call_func(self.callbacks, **kwargs)
        learner = call_func(self.learner, **kwargs)

        self.print_bibliography(verbose=True)

        # with learner.distrib_ctx() if distributed == True else nullcontext():
        call_func(self.fit, learner, callbacks, **kwargs)

        learner.export()

        return learner

    def export(self, model_path: Path, **kwargs):
        """
        Generates a learner, loads model weights from a file and exports the learner so that it can be used for inference.

        This is useful if a run has not reached completion but the model weights have still been saved.
        """
        # Run the callbacks to ensure that everything is initialized the same as running the training loop
        call_func(self.callbacks, **kwargs)
        learner = call_func(self.learner, **kwargs)
        load_model(model_path, learner.model, opt=None, with_opt=False, device=learner.dls.device, strict=True)
        learner.export()
        return learner

    def fit(
        self,
        learner,
        callbacks,
        epochs: int = Param(default=20, help="The number of epochs."),
        freeze_epochs: int = Param(
            default=3,
            help="The number of epochs to train when the learner is frozen and the last layer is trained by itself. Only if `fine_tune` is set on the app.",
        ),
        learning_rate: float = Param(
            default=1e-4, help="The base learning rate (when fine tuning) or the max learning rate otherwise."
        ),
        **kwargs,
    ):
        if self.fine_tune:
            return learner.fine_tune(
                epochs, freeze_epochs=freeze_epochs, base_lr=learning_rate, cbs=callbacks, **kwargs
            )  # hack

        return learner.fit_one_cycle(epochs, lr_max=learning_rate, cbs=callbacks, **kwargs)
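
    # Example (illustrative sketch): setting the class attribute `fine_tune = True` on a subclass
    # makes `fit` call `learner.fine_tune` (training the frozen learner for `freeze_epochs` before
    # unfreezing) instead of `learner.fit_one_cycle`. `MyTransferApp` is a hypothetical subclass.
    #
    #   class MyTransferApp(TorchApp):
    #       fine_tune = True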

    def project_name(self) -> str:
        """
        The name to use for a project for logging purposes.

        The default is to use the class name.
        """
        return self.__class__.__name__

    def tune(
        self,
        runs: int = Param(default=1, help="The number of runs to attempt to train the model."),
        engine: str = Param(
            default="skopt",
            help="The optimizer to use to perform the hyperparameter tuning. Options: wandb, optuna, skopt.",
        ),  # should be enum
        id: str = Param(
            default="",
            help="The ID of this hyperparameter tuning job. "
            "If using wandb, then this is the sweep id. "
            "If using optuna, then this is the storage. "
            "If using skopt, then this is the file to store the results. ",
        ),
        name: str = Param(
            default="",
            help="An informative name for this hyperparameter tuning job. If empty, then it creates a name from the project name.",
        ),
        method: str = Param(
            default="", help="The sampling method to use to perform the hyperparameter tuning. By default it chooses the default method of the engine."
        ),  # should be enum
        min_iter: int = Param(
            default=None,
            help="The minimum number of iterations if using early termination. If left empty, then early termination is not used.",
        ),
        seed: int = Param(
            default=None,
            help="A seed for the random number generator.",
        ),
        **kwargs,
    ):
        if not name:
            name = f"{self.project_name()}-tuning"

        if engine == "wandb":
            from .tuning.wandb import wandb_tune

            self.add_bibtex_file(bibtex_dir / "wandb.bib")

            return wandb_tune(
                self,
                runs=runs,
                sweep_id=id,
                name=name,
                method=method,
                min_iter=min_iter,
                **kwargs,
            )
        elif engine == "optuna":
            from .tuning.optuna import optuna_tune

            self.add_bibtex_file(bibtex_dir / "optuna.bib")

            return optuna_tune(
                self,
                runs=runs,
                storage=id,
                name=name,
                method=method,
                seed=seed,
                **kwargs,
            )
        elif engine in ["skopt", "scikit-optimize"]:
            from .tuning.skopt import skopt_tune

            self.add_bibtex_file(bibtex_dir / "skopt.bib")

            return skopt_tune(
                self,
                runs=runs,
                file=id,
                name=name,
                method=method,
                seed=seed,
                **kwargs,
            )
        else:
            raise NotImplementedError(f"Optimizer engine {engine} not implemented.")
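
    # Example (illustrative sketch): hyperparameters are marked as tunable by setting `tune=True`
    # on their `Param` defaults; `tuning_params` then collects them from the signature of `tune_cli`.
    # The subclass, parameter name, and the `tune_min`/`tune_max` range arguments below are all
    # assumptions about `Param`, not confirmed by this file.
    #
    #   class MyApp(TorchApp):
    #       def model(
    #           self,
    #           hidden_size: int = Param(default=256, tune=True, tune_min=64, tune_max=1024, help="The hidden size."),
    #       ) -> nn.Module:
    #           return nn.Sequential(nn.Linear(10, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1))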

    def get_best_metric(self, learner):
        # The slice is there because 'epoch' is prepended to the list but it isn't included in the values
        metric_index = learner.recorder.metric_names[1:].index(self.monitor())
        metric_values = list(map(lambda row: row[metric_index], learner.recorder.values))
        metric_function = min if self.goal()[:3] == "min" else max
        metric_value = metric_function(metric_values)
        return metric_value

    def one_batch_size(self, **kwargs):
        dls = call_func(self.dataloaders, **kwargs)
        batch = dls.train.one_batch()
        return batch[0].size()

    def one_batch_output(self, **kwargs):
        learner = call_func(self.learner, **kwargs)
        batch = learner.dls.train.one_batch()
        n_inputs = getattr(learner.dls, 'n_inp', 1 if len(batch) == 1 else len(batch) - 1)
        batch_x = batch[:n_inputs]

        learner.model.to(batch_x[0].device)
        with torch.no_grad():
            output = learner.model(*batch_x)
        return output

    def one_batch_output_size(self, **kwargs):
        output = self.one_batch_output(**kwargs)
        return output.size()

    def one_batch_loss(self, **kwargs):
        learner = call_func(self.learner, **kwargs)
        batch = learner.dls.train.one_batch()
        n_inputs = getattr(learner.dls, 'n_inp', 1 if len(batch) == 1 else len(batch) - 1)
        batch_x = batch[:n_inputs]
        batch_y = batch[n_inputs:]

        learner.model.to(batch_x[0].device)
        with torch.no_grad():
            output = learner.model(*batch_x)
            loss = learner.loss_func(output, *batch_y)

        return loss

    def lr_finder(
        self, plot_filename: Path = None, start_lr: float = 1e-07, end_lr: float = 10, iterations: int = 100, **kwargs
    ):
        learner = call_func(self.learner, **kwargs)

        from matplotlib import pyplot as plt
        from fastai.callback.schedule import SuggestionMethod

        suggest_funcs = (
            SuggestionMethod.Valley,
            SuggestionMethod.Minimum,
            SuggestionMethod.Slide,
            SuggestionMethod.Steep,
        )

        result = learner.lr_find(
            stop_div=False,
            num_it=iterations,
            start_lr=start_lr,
            end_lr=end_lr,
            show_plot=plot_filename is not None,
            suggest_funcs=suggest_funcs,
        )

        if plot_filename is not None:
            plt.savefig(str(plot_filename))

        print("\n")
        table = Table(title="Suggested Learning Rates", box=SIMPLE)

        table.add_column("Method", style="cyan", no_wrap=True)
        table.add_column("Learning Rate", style="magenta")
        table.add_column("Explanation")

        for method, value in zip(suggest_funcs, result):
            table.add_row(method.__name__, str(value), method.__doc__)

        console.print(table)

        return result

    def cache_dir(self) -> Path:
        """Returns a path to a directory where data files for this app can be cached."""
        cache_dir = Path(user_cache_dir("torchapps")) / self.__class__.__name__
        cache_dir.mkdir(exist_ok=True, parents=True)
        return cache_dir

    def output_results(
        self,
        results,
        **kwargs,
    ):
        print(results)
        return results