from contextlib import nullcontext
from pathlib import Path
from types import MethodType
from typing import List, Optional, Union, Dict
import inspect
import hashlib
from appdirs import user_cache_dir
import torch
from torch import nn
from fastai.learner import Learner, load_learner, load_model
from fastai.data.core import DataLoaders
from fastai.callback.schedule import fit_one_cycle

# from fastai.distributed import distrib_ctx
from fastai.callback.tracker import SaveModelCallback
from fastai.callback.progress import CSVLogger
import click
import typer
from typer.main import get_params_convertors_ctx_param_name_from_function
from rich.console import Console
from rich.traceback import install
from rich.table import Table
from rich.box import SIMPLE


install()
console = Console()

from .citations import Citable
from .util import copy_func, call_func, change_typer_to_defaults, add_kwargs
from .params import Param
from .callbacks import TorchAppWandbCallback, TorchAppMlflowCallback
from .download import cached_download

bibtex_dir = Path(__file__).parent / "bibtex"

class TorchAppInitializationError(Exception):
    pass


class TorchApp(Citable):
    torchapp_initialized = False
    fine_tune = False

    def __init__(self):
        super().__init__()

        # Make deep copies of methods so that we can change the function signatures dynamically
        self.fit = self.copy_method(self.fit)
        self.train = self.copy_method(self.train)
        self.dataloaders = self.copy_method(self.dataloaders)
        self.model = self.copy_method(self.model)
        self.pretrained_location = self.copy_method(self.pretrained_location)
        self.show_batch = self.copy_method(self.show_batch)
        self.tune = self.copy_method(self.tune)
        self.pretrained_local_path = self.copy_method(self.pretrained_local_path)
        self.learner_kwargs = self.copy_method(self.learner_kwargs)
        self.learner = self.copy_method(self.learner)
        self.export = self.copy_method(self.export)
        self.__call__ = self.copy_method(self.__call__)
        self.validate = self.copy_method(self.validate)
        self.callbacks = self.copy_method(self.callbacks)
        self.extra_callbacks = self.copy_method(self.extra_callbacks)
        self.inference_callbacks = self.copy_method(self.inference_callbacks)
        self.one_batch_size = self.copy_method(self.one_batch_size)
        self.one_batch_output = self.copy_method(self.one_batch_output)
        self.one_batch_output_size = self.copy_method(self.one_batch_output_size)
        self.one_batch_loss = self.copy_method(self.one_batch_loss)
        self.loss_func = self.copy_method(self.loss_func)
        self.metrics = self.copy_method(self.metrics)
        self.lr_finder = self.copy_method(self.lr_finder)
        self.inference_dataloader = self.copy_method(self.inference_dataloader)
        self.output_results = self.copy_method(self.output_results)

        # Add keyword arguments to the signatures of the methods used in the CLI
        add_kwargs(to_func=self.learner_kwargs, from_funcs=[self.metrics, self.loss_func])
        add_kwargs(to_func=self.learner, from_funcs=[self.learner_kwargs, self.dataloaders, self.model])
        add_kwargs(to_func=self.callbacks, from_funcs=[self.extra_callbacks])
        add_kwargs(to_func=self.export, from_funcs=[self.learner, self.callbacks])
        add_kwargs(to_func=self.train, from_funcs=[self.learner, self.fit, self.callbacks])
        add_kwargs(to_func=self.show_batch, from_funcs=self.dataloaders)
        add_kwargs(to_func=self.tune, from_funcs=self.train)
        add_kwargs(to_func=self.pretrained_local_path, from_funcs=self.pretrained_location)
        add_kwargs(
            to_func=self.__call__,
            from_funcs=[self.pretrained_local_path, self.inference_dataloader, self.output_results, self.inference_callbacks],
        )
        add_kwargs(to_func=self.validate, from_funcs=[self.pretrained_local_path, self.dataloaders])
        add_kwargs(to_func=self.one_batch_size, from_funcs=self.dataloaders)
        add_kwargs(to_func=self.one_batch_output, from_funcs=self.learner)
        add_kwargs(to_func=self.one_batch_loss, from_funcs=self.learner)
        add_kwargs(to_func=self.lr_finder, from_funcs=self.learner)
        add_kwargs(to_func=self.one_batch_output_size, from_funcs=self.one_batch_output)

        # Make copies of methods to use just for the CLI
        self.export_cli = self.copy_method(self.export)
        self.train_cli = self.copy_method(self.train)
        self.show_batch_cli = self.copy_method(self.show_batch)
        self.tune_cli = self.copy_method(self.tune)
        self.pretrained_local_path_cli = self.copy_method(self.pretrained_local_path)
        self.infer_cli = self.copy_method(self.__call__)
        self.validate_cli = self.copy_method(self.validate)
        self.lr_finder_cli = self.copy_method(self.lr_finder)

        # Replace the Param defaults with their plain default values in the methods not used directly by the CLI
        change_typer_to_defaults(self.fit)
        change_typer_to_defaults(self.model)
        change_typer_to_defaults(self.learner_kwargs)
        change_typer_to_defaults(self.loss_func)
        change_typer_to_defaults(self.metrics)
        change_typer_to_defaults(self.export)
        change_typer_to_defaults(self.learner)
        change_typer_to_defaults(self.callbacks)
        change_typer_to_defaults(self.extra_callbacks)
        change_typer_to_defaults(self.train)
        change_typer_to_defaults(self.show_batch)
        change_typer_to_defaults(self.tune)
        change_typer_to_defaults(self.pretrained_local_path)
        change_typer_to_defaults(self.__call__)
        change_typer_to_defaults(self.validate)
        change_typer_to_defaults(self.dataloaders)
        change_typer_to_defaults(self.pretrained_location)
        change_typer_to_defaults(self.one_batch_size)
        change_typer_to_defaults(self.one_batch_output_size)
        change_typer_to_defaults(self.one_batch_output)
        change_typer_to_defaults(self.one_batch_loss)
        change_typer_to_defaults(self.lr_finder)
        change_typer_to_defaults(self.inference_dataloader)
        change_typer_to_defaults(self.inference_callbacks)
        change_typer_to_defaults(self.output_results)

        # Store a bool to let the app know later on (in self.assert_initialized)
        # that __init__ has been called on this parent class
        self.torchapp_initialized = True
        self.learner_obj = None
        # self.console = console

    def __str__(self):
        return self.__class__.__name__

    def get_bibtex_files(self):
        return [
            bibtex_dir / "fastai.bib",
            bibtex_dir / "torchapp.bib",
        ]

    def copy_method(self, method):
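        """
        Returns a copy of the given bound method, bound to this instance.

        Making a copy allows the signature of the method to be modified dynamically
        (e.g. by `add_kwargs` and `change_typer_to_defaults`) without affecting the
        class or other instances.
        """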

        return MethodType(copy_func(method.__func__), self)

    def pretrained_location(self) -> Union[str, Path]:
        """
        The location of a pretrained model.

        It can be a URL, in which case it will need to be downloaded.
        Or it can be part of the package bundle, in which case it needs to be
        a relative path from the directory containing the code that defines the app.

        This function returns an empty string by default.
        Inherited classes need to override this method to use pretrained models.

        Returns:
            Union[str, Path]: The location of the pretrained model.
        """
        return ""

    def pretrained_local_path(
        self,
        pretrained: str = Param(default=None, help="The location (URL or filepath) of a pretrained model."),
        reload: bool = Param(
            default=False,
            help="Should the pretrained model be downloaded again if it is online and already present locally.",
        ),
        **kwargs,
    ) -> Path:
        """
        The local path of the pretrained model.

        If it is a URL, then it is downloaded.
        If it is a relative path, then this method returns the absolute path to it.

        Args:
            pretrained (str, optional): The location (URL or filepath) of a pretrained model. If it is a relative path, then it is relative to the current working directory. Defaults to using the result of the `pretrained_location` method.
            reload (bool, optional): Should the pretrained model be downloaded again if it is online and already present locally. Defaults to False.

        Raises:
            FileNotFoundError: If the file cannot be located in the local environment.

        Returns:
            Path: The absolute path to the model on the local filesystem.
        """
        if pretrained:
            location = pretrained
            base_dir = Path.cwd()
        else:
            location = str(call_func(self.pretrained_location, **kwargs))
            module = inspect.getmodule(self)
            base_dir = Path(module.__file__).parent.resolve()

        if not location:
            raise FileNotFoundError("Please pass in a pretrained model.")

        # Check if it needs to be downloaded
        location = str(location)
        if location.startswith("http"):
            name = location.split("/")[-1]
            extension_location = name.rfind(".")
            if extension_location != -1:
                name_stem = name[:extension_location]
                extension = name[extension_location:]
            else:
                name_stem = name
                extension = ".dat"
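            # Include an md5 hash of the URL in the cached filename so that different
            # URLs which happen to share the same filename do not collide in the cache directory.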

            url_hash = hashlib.md5(location.encode()).hexdigest()
            path = self.cache_dir() / f"{name_stem}-{url_hash}{extension}"
            cached_download(location, path, force=reload)
        else:
            path = Path(location)
            if not path.is_absolute():
                path = base_dir / path

        if not path or not path.is_file():
            raise FileNotFoundError(f"Cannot find pretrained model at '{path}'")

        return path

    def prepare_source(self, data):
        return data

    def inference_dataloader(self, learner, **kwargs):
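        """
        Creates the dataloader used for inference.

        By default the keyword arguments are passed through to fastai's `test_dl`
        on the exported learner's dataloaders. Subclasses can override this method
        to prepare the inference data differently.
        """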

        dataloader = learner.dls.test_dl(**kwargs)
        return dataloader

    def validate(
        self,
        gpu: bool = Param(True, help="Whether or not to use a GPU for processing if available."),
        **kwargs,
    ):
        path = call_func(self.pretrained_local_path, **kwargs)

        # Check if CUDA is available
        gpu = gpu and torch.cuda.is_available()

        try:
            learner = load_learner(path, cpu=not gpu)
        except Exception:
            import dill
            learner = load_learner(path, cpu=not gpu, pickle_module=dill)

        # Create the dataloaders for validation
        dataloaders = call_func(self.dataloaders, **kwargs)

        table = Table(title="Validation", box=SIMPLE)

        values = learner.validate(dl=dataloaders.valid)
        names = [learner.recorder.loss.name] + [metric.name for metric in learner.metrics]
        result = {name: value for name, value in zip(names, values)}

        table.add_column("Metric", justify="right", style="cyan", no_wrap=True)
        table.add_column("Value", style="magenta")

        for name, value in result.items():
            table.add_row(name, str(value))

        console.print(table)

        return result

    def __call__(
        self,
        gpu: bool = Param(True, help="Whether or not to use a GPU for processing if available."),
        **kwargs
    ):
        # Check if CUDA is available
        gpu = gpu and torch.cuda.is_available()

        # Open the exported learner from a pickle file
        path = call_func(self.pretrained_local_path, **kwargs)
        learner = self.learner_obj = load_learner(path, cpu=not gpu)

        # Create a dataloader for inference
        dataloader = call_func(self.inference_dataloader, learner, **kwargs)

        inference_callbacks = call_func(self.inference_callbacks, **kwargs)

        results = learner.get_preds(dl=dataloader, reorder=False, with_decoded=False, act=self.activation(), cbs=inference_callbacks)

        # Output results
        output_results = call_func(self.output_results, results, **kwargs)
        return output_results if output_results is not None else results

    def inference_callbacks(self):
        return None

    @classmethod
    def main(cls, inference_only: bool = False):
        """
        Creates an instance of this class and runs the command-line interface.
        """
        cli = cls.click(inference_only=inference_only)
        return cli()

    @classmethod
    def inference_only_main(cls):
        """
        Creates an instance of this class and runs the command-line interface for only the inference command.
        """
        return cls.main(inference_only=True)

    @classmethod
    def click(cls, inference_only: bool = False):
        """
        Creates an instance of this class and returns the click object for the command-line interface.
        """
        self = cls()
        cli = self.cli(inference_only=inference_only)
        return cli

    @classmethod
    def inference_only_click(cls):
        """
        Creates an instance of this class and returns the click object for only the inference command.
        """
        return cls.click(inference_only=True)

    def assert_initialized(self):
        """
        Asserts that this app has been initialized.

        All sub-classes of TorchApp need to call super().__init__() if overriding the __init__() function.

        Raises:
            TorchAppInitializationError: If the app has not been properly initialized.
        """
        if not self.torchapp_initialized:
            raise TorchAppInitializationError(
                """The initialization function for this TorchApp has not been called.
                Please ensure sub-classes of TorchApp call 'super().__init__()'"""
            )

    def version(self, verbose: bool = False):
        """
        Prints the version of the package that defines this app.

        Used in the command-line interface.

        Args:
            verbose (bool, optional): Whether or not to print to stdout. Defaults to False.

        Raises:
            Exception: If it cannot find the package.
        """
        if verbose:
            from importlib import metadata

            module = inspect.getmodule(self)
            package = ""
            if module.__package__:
                package = module.__package__.split('.')[0]
            else:
                path = Path(module.__file__).parent
                while path.name:
                    try:
                        if metadata.distribution(path.name):
                            package = path.name
                            break
                    except Exception:
                        pass
                    path = path.parent

            if package:
                version = metadata.version(package)
                print(version)
            else:
                raise Exception("Cannot find package.")

            raise typer.Exit()

    def cli(self, inference_only: bool = False):
        """
        Returns a 'Click' object which defines the command-line interface of the app.
        """
        self.assert_initialized()

        cli = typer.Typer()

        @cli.callback()
        def base_callback(
            version: Optional[bool] = typer.Option(
                None,
                "--version",
                "-v",
                callback=self.version,
                is_eager=True,
                help="Prints the current version.",
            ),
        ):
            pass

        typer_click_object = typer.main.get_command(cli)
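
        # Build a click command for each of the CLI method copies below:
        # typer inspects each method's signature and converts its Param defaults
        # into command-line options for the corresponding subcommand.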

        params, _, _ = get_params_convertors_ctx_param_name_from_function(self.infer_cli)
        command = click.Command(
            name="infer",
            callback=self.infer_cli,
            params=params,
        )
        if inference_only:
            return command
        typer_click_object.add_command(command)

        train_params, _, _ = get_params_convertors_ctx_param_name_from_function(self.train_cli)
        train_command = click.Command(
            name="train",
            callback=self.train_cli,
            params=train_params,
        )
        typer_click_object.add_command(train_command)

        export_params, _, _ = get_params_convertors_ctx_param_name_from_function(self.export_cli)
        export_command = click.Command(
            name="export",
            callback=self.export_cli,
            params=export_params,
        )
        typer_click_object.add_command(export_command)

        show_batch_params, _, _ = get_params_convertors_ctx_param_name_from_function(self.show_batch_cli)
        command = click.Command(
            name="show-batch",
            callback=self.show_batch_cli,
            params=show_batch_params,
        )
        typer_click_object.add_command(command)

        params, _, _ = get_params_convertors_ctx_param_name_from_function(self.tune_cli)
        tuning_params = self.tuning_params()
        for param in params:
            if param.name in tuning_params:
                param.default = None
        command = click.Command(
            name="tune",
            callback=self.tune_cli,
            params=params,
        )
        typer_click_object.add_command(command)

        params, _, _ = get_params_convertors_ctx_param_name_from_function(self.validate_cli)
        command = click.Command(
            name="validate",
            callback=self.validate_cli,
            params=params,
        )
        typer_click_object.add_command(command)

        params, _, _ = get_params_convertors_ctx_param_name_from_function(self.lr_finder_cli)
        command = click.Command(
            name="lr-finder",
            callback=self.lr_finder_cli,
            params=params,
        )
        typer_click_object.add_command(command)

        command = click.Command(
            name="bibliography",
            callback=self.print_bibliography,
        )
        typer_click_object.add_command(command)

        command = click.Command(
            name="bibtex",
            callback=self.print_bibtex,
        )
        typer_click_object.add_command(command)

        return typer_click_object

    def tuning_params(self):
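        """
        Returns a dictionary of the hyperparameters that can be tuned.

        These are the parameters in the signature of `tune` whose `Param` default
        has `tune=True` set.
        """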

        tuning_params = {}
        signature = inspect.signature(self.tune_cli)

        for key, value in signature.parameters.items():
            default_value = value.default
            if isinstance(default_value, Param) and default_value.tune == True:

                # Override annotation if given in typing hints
                if value.annotation:
                    default_value.annotation = value.annotation

                default_value.check_choices()

                tuning_params[key] = default_value

        return tuning_params

    def dataloaders(self):
        raise NotImplementedError(
            f"Please ensure that the 'dataloaders' method is implemented in {self.__class__.__name__}."
        )

    def model(self) -> nn.Module:
        raise NotImplementedError(f"Please ensure that the 'model' method is implemented in {self.__class__.__name__}.")

    def build_learner_func(self):
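        """
        Returns the callable used to construct the learner in `self.learner`.

        By default this is the fastai `Learner` class. Subclasses can override this
        method to build the learner with a custom class or factory function.
        """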

        return Learner

    def learner(
        self,
        fp16: bool = Param(
            default=True,
            help="Whether or not the floating-point precision of learner should be set to 16 bit.",
        ),
        **kwargs,
    ) -> Learner:
        """
        Creates a fastai learner object.
        """
        console.print("Building dataloaders", style="bold")
        dataloaders = call_func(self.dataloaders, **kwargs)

        # Allow the dataloaders to go to the GPU so long as a different device hasn't been explicitly set
        if dataloaders.device is None:
            dataloaders.cuda()  # This will revert to CPU if cuda is not available

        console.print("Building model", style="bold")
        model = call_func(self.model, **kwargs)

        console.print("Building learner", style="bold")
        learner_kwargs = call_func(self.learner_kwargs, **kwargs)
        build_learner_func = self.build_learner_func()
        learner = build_learner_func(
            dataloaders,
            model,
            **learner_kwargs,
        )

        learner.training_kwargs = kwargs

        if fp16:
            console.print("Setting floating-point precision of learner to 16 bit", style="red")
            learner = learner.to_fp16()

        # Save a pointer to the learner
        self.learner_obj = learner

        return learner

    def learner_kwargs(
        self,
        output_dir: Path = Param("./outputs", help="The location of the output directory."),
        weight_decay: float = Param(
            None, help="The amount of weight decay. If None then it uses the default amount of weight decay in fastai."
        ),
        # l2_regularization: bool = Param(False, help="Whether to add decay to the gradients (L2 regularization) instead of to the weights directly (weight decay)."),
        **kwargs,
    ):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)

        return dict(
            loss_func=call_func(self.loss_func, **kwargs),
            metrics=call_func(self.metrics, **kwargs),
            path=self.output_dir,
            wd=weight_decay,
        )

    def loss_func(self, **kwargs):
        """The loss function. If None, then fastai will use the default loss function if it exists for this model."""
        return None

    def activation(self):
        """The activation for the last layer. If None, then fastai will use the default activation of the loss if it exists."""
        return None

    def metrics(self) -> List:
        """
        The list of metrics to use with this app.

        By default this list is empty. This method should be subclassed to add metrics in child classes of TorchApp.

        Returns:
            List: The list of metrics.
        """
        return []

    def monitor(self) -> str:
        """
        The metric to optimize for when performing hyperparameter tuning.

        By default it returns 'valid_loss'.
        """
        return "valid_loss"

    def goal(self) -> str:
        """
        Sets the optimality direction when evaluating the metric from `monitor`.

        By default it produces the same behaviour as the fastai tracker callbacks (fastai.callback.tracker),
        i.e. it is set to "minimize" if the monitored metric name contains the string 'loss' or 'err',
        otherwise it is "maximize".

        If the monitor is empty then this function returns None.
        """
        monitor = self.monitor()
        if not monitor or not isinstance(monitor, str):
            return None

        return "minimize" if ("loss" in monitor) or ("err" in monitor) else "maximize"

    def callbacks(
        self,
        project_name: str = Param(default=None, help="The name for this project for logging purposes."),
        run_name: str = Param(default=None, help="The name for this particular run for logging purposes."),
        run_id: str = Param(default=None, help="A unique ID for this particular run for logging purposes."),
        notes: str = Param(None, help="A longer description of the run for logging purposes."),
        tag: List[str] = Param(
            None, help="A tag for logging purposes. Multiple tags can be added each introduced with --tag."
        ),
        wandb: bool = Param(default=False, help="Whether or not to use 'Weights and Biases' for logging."),
        wandb_mode: str = Param(default="online", help="The mode for 'Weights and Biases'."),
        wandb_dir: Path = Param(None, help="The location for 'Weights and Biases' output."),
        wandb_entity: str = Param(None, help="An entity is a username or team name where you're sending runs."),
        wandb_group: str = Param(None, help="Specify a group to organize individual runs into a larger experiment."),
        wandb_job_type: str = Param(
            None,
            help="Specify the type of run, which is useful when you're grouping runs together into larger experiments using group.",
        ),
        mlflow: bool = Param(default=False, help="Whether or not to use MLflow for logging."),
        **kwargs,
    ) -> List:
        """
        The list of callbacks to use with this app in the fastai training loop.

        Args:
            project_name (str): The name for this project for logging purposes. If no name is given then the name of the app is used.

        Returns:
            list: The list of callbacks.
        """
        callbacks = [CSVLogger()]
        monitor = self.monitor()
        if monitor:
            callbacks.append(SaveModelCallback(monitor=monitor))

        if wandb:
            callback = TorchAppWandbCallback(
                app=self,
                project_name=project_name,
                id=run_id,
                name=run_name,
                mode=wandb_mode,
                dir=wandb_dir,
                entity=wandb_entity,
                group=wandb_group,
                job_type=wandb_job_type,
                notes=notes,
                tags=tag,
            )
            callbacks.append(callback)
            self.add_bibtex_file(bibtex_dir / "wandb.bib")  # this should be in the callback

        if mlflow:
            callbacks.append(TorchAppMlflowCallback(app=self, experiment_name=project_name))
            self.add_bibtex_file(bibtex_dir / "mlflow.bib")  # this should be in the callback

        extra_callbacks = call_func(self.extra_callbacks, **kwargs)
        if extra_callbacks:
            callbacks += extra_callbacks

        return callbacks

    def extra_callbacks(self):
        return None

    def show_batch(
        self, output_path: Path = Param(None, help="A location to save the HTML which summarizes the batch."), **kwargs
    ):
        dataloaders = call_func(self.dataloaders, **kwargs)

        # patch the display function of ipython so we can capture the HTML
        def mock_display(html_object):
            self.batch_html = html_object

        import IPython.display

        ipython_display = IPython.display.display
        IPython.display.display = mock_display

        dataloaders.show_batch()
        batch_html = getattr(self, "batch_html", None)

        # restore the ipython display function
        IPython.display.display = ipython_display

        if not batch_html:
            console.print("Cannot display batch as HTML")
            return

        html = batch_html.data

        # Write output
        if output_path:
            console.print(f"Writing batch HTML to '{output_path}'")
            with open(output_path, 'w') as f:
                f.write(html)
        else:
            console.print(html)
            console.print("To write this HTML output to a file, give an output path.")

        return self.batch_html

    def train(
        self,
        distributed: bool = Param(default=False, help="If the learner is distributed."),
        **kwargs,
    ) -> Learner:
        """
        Trains a model for this app.

        Args:
            distributed (bool, optional): If the learner is distributed. Defaults to False.

        Returns:
            Learner: The fastai Learner object created for training.
        """
        self.training_kwargs = kwargs
        self.training_kwargs['distributed'] = distributed

        callbacks = call_func(self.callbacks, **kwargs)
        learner = call_func(self.learner, **kwargs)

        self.print_bibliography(verbose=True)

        # with learner.distrib_ctx() if distributed == True else nullcontext():
        call_func(self.fit, learner, callbacks, **kwargs)

        learner.export()

        return learner

    def export(self, model_path: Path, **kwargs):
        """
        Generates a learner, loads model weights from a file, and exports the learner so that it can be used for inference.

        This is useful if a run has not reached completion but the model weights have still been saved.
        """
        # Run the callbacks to ensure that everything is initialized the same as running the training loop
        call_func(self.callbacks, **kwargs)
        learner = call_func(self.learner, **kwargs)
        load_model(model_path, learner.model, opt=None, with_opt=False, device=learner.dls.device, strict=True)
        learner.export()
        return learner

    def fit(
        self,
        learner,
        callbacks,
        epochs: int = Param(default=20, help="The number of epochs."),
        freeze_epochs: int = Param(
            default=3,
            help="The number of epochs to train when the learner is frozen and the last layer is trained by itself. Only used if `fine_tune` is set on the app.",
        ),
        learning_rate: float = Param(
            default=1e-4, help="The base learning rate (when fine tuning) or the max learning rate otherwise."
        ),
        **kwargs,
    ):
        if self.fine_tune:
            return learner.fine_tune(
                epochs, freeze_epochs=freeze_epochs, base_lr=learning_rate, cbs=callbacks, **kwargs
            )  # hack

        return learner.fit_one_cycle(epochs, lr_max=learning_rate, cbs=callbacks, **kwargs)

    def project_name(self) -> str:
        """
        The name to use for a project for logging purposes.

        The default is to use the class name.
        """
        return self.__class__.__name__

    def tune(
        self,
        runs: int = Param(default=1, help="The number of runs to attempt to train the model."),
        engine: str = Param(
            default="skopt",
            help="The optimizer to use to perform the hyperparameter tuning. Options: wandb, optuna, skopt.",
        ),  # should be enum
        id: str = Param(
            default="",
            help="The ID of this hyperparameter tuning job. "
            "If using wandb, then this is the sweep id. "
            "If using optuna, then this is the storage. "
            "If using skopt, then this is the file to store the results.",
        ),
        name: str = Param(
            default="",
            help="An informative name for this hyperparameter tuning job. If empty, then it creates a name from the project name.",
        ),
        method: str = Param(
            default="", help="The sampling method to use to perform the hyperparameter tuning. By default it chooses the default method of the engine."
        ),  # should be enum
        min_iter: int = Param(
            default=None,
            help="The minimum number of iterations if using early termination. If left empty, then early termination is not used.",
        ),
        seed: int = Param(
            default=None,
            help="A seed for the random number generator.",
        ),
        **kwargs,
    ):
        if not name:
            name = f"{self.project_name()}-tuning"

        if engine == "wandb":
            from .tuning.wandb import wandb_tune

            self.add_bibtex_file(bibtex_dir / "wandb.bib")

            return wandb_tune(
                self,
                runs=runs,
                sweep_id=id,
                name=name,
                method=method,
                min_iter=min_iter,
                **kwargs,
            )
        elif engine == "optuna":
            from .tuning.optuna import optuna_tune

            self.add_bibtex_file(bibtex_dir / "optuna.bib")

            return optuna_tune(
                self,
                runs=runs,
                storage=id,
                name=name,
                method=method,
                seed=seed,
                **kwargs,
            )
        elif engine in ["skopt", "scikit-optimize"]:
            from .tuning.skopt import skopt_tune

            self.add_bibtex_file(bibtex_dir / "skopt.bib")

            return skopt_tune(
                self,
                runs=runs,
                file=id,
                name=name,
                method=method,
                seed=seed,
                **kwargs,
            )
        else:
            raise NotImplementedError(f"Optimizer engine {engine} not implemented.")

    def get_best_metric(self, learner):
        # The slice is there because 'epoch' is prepended to the list but it isn't included in the values
        metric_index = learner.recorder.metric_names[1:].index(self.monitor())
        metric_values = list(map(lambda row: row[metric_index], learner.recorder.values))
        metric_function = min if self.goal()[:3] == "min" else max
        metric_value = metric_function(metric_values)
        return metric_value
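
    # The following `one_batch_*` methods are debugging helpers: each pulls a single
    # batch from the training dataloader and reports its input size, the raw model
    # output, the output size, or the loss. They are useful for sanity-checking a new app.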

    def one_batch_size(self, **kwargs):
        dls = call_func(self.dataloaders, **kwargs)
        batch = dls.train.one_batch()
        return batch[0].size()

    def one_batch_output(self, **kwargs):
        learner = call_func(self.learner, **kwargs)
        batch = learner.dls.train.one_batch()
        n_inputs = getattr(learner.dls, 'n_inp', 1 if len(batch) == 1 else len(batch) - 1)
        batch_x = batch[:n_inputs]

        learner.model.to(batch_x[0].device)
        with torch.no_grad():
            output = learner.model(*batch_x)
        return output

    def one_batch_output_size(self, **kwargs):
        output = self.one_batch_output(**kwargs)
        return output.size()

    def one_batch_loss(self, **kwargs):
        learner = call_func(self.learner, **kwargs)
        batch = learner.dls.train.one_batch()
        n_inputs = getattr(learner.dls, 'n_inp', 1 if len(batch) == 1 else len(batch) - 1)
        batch_x = batch[:n_inputs]
        batch_y = batch[n_inputs:]

        learner.model.to(batch_x[0].device)
        with torch.no_grad():
            output = learner.model(*batch_x)
            loss = learner.loss_func(output, *batch_y)

        return loss

    def lr_finder(
        self, plot_filename: Path = None, start_lr: float = 1e-07, end_lr: float = 10, iterations: int = 100, **kwargs
    ):
        learner = call_func(self.learner, **kwargs)

        from matplotlib import pyplot as plt
        from fastai.callback.schedule import SuggestionMethod

        suggest_funcs = (
            SuggestionMethod.Valley,
            SuggestionMethod.Minimum,
            SuggestionMethod.Slide,
            SuggestionMethod.Steep,
        )

        result = learner.lr_find(
            stop_div=False,
            num_it=iterations,
            start_lr=start_lr,
            end_lr=end_lr,
            show_plot=plot_filename is not None,
            suggest_funcs=suggest_funcs,
        )

        if plot_filename is not None:
            plt.savefig(str(plot_filename))

        print("\n")
        table = Table(title="Suggested Learning Rates", box=SIMPLE)

        table.add_column("Method", style="cyan", no_wrap=True)
        table.add_column("Learning Rate", style="magenta")
        table.add_column("Explanation")

        for method, value in zip(suggest_funcs, result):
            table.add_row(method.__name__, str(value), method.__doc__)

        console.print(table)

        return result

    def cache_dir(self) -> Path:
        """ Returns a path to a directory where data files for this app can be cached. """
        cache_dir = Path(user_cache_dir("torchapps")) / self.__class__.__name__
        cache_dir.mkdir(exist_ok=True, parents=True)
        return cache_dir

    def output_results(
        self,
        results,
        **kwargs
    ):
        print(results)
        return results
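

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): a downstream app
# typically subclasses TorchApp, overrides `dataloaders` and `model`, and exposes
# the command-line interface through `main()`. The module name, class name and
# CSV columns below are hypothetical placeholders.
#
#   # myapp.py
#   from pathlib import Path
#   from torch import nn
#   from fastai.data.core import DataLoaders
#   import torchapp as ta
#
#   class RegressionApp(ta.TorchApp):
#       def dataloaders(
#           self,
#           csv: Path = ta.Param(default=None, help="A CSV file with the training data."),
#           batch_size: int = ta.Param(default=32, help="The batch size."),
#       ) -> DataLoaders:
#           ...  # build and return a fastai DataLoaders object from the CSV
#
#       def model(self) -> nn.Module:
#           return nn.Linear(1, 1)
#
#   if __name__ == "__main__":
#       RegressionApp.main()
#
# The subcommands registered in `cli()` above (train, infer, tune, validate,
# show-batch, lr-finder, etc.) then become available, e.g.:
#   python myapp.py train --csv data.csv --epochs 20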