Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python3 

2from pathlib import Path 

3import numpy as np 

4from fastai.tabular.data import TabularDataLoaders 

5from fastai.tabular.all import tabular_learner, accuracy, error_rate 

6from sklearn.datasets import load_iris 

7import torchapp as ta 

8 

9 

class IrisApp(ta.TorchApp):
    """
    A classification app to predict the type of iris from sepal and petal lengths and widths.

    A classic dataset published in:
    Fisher, R.A. “The Use of Multiple Measurements in Taxonomic Problems” Annals of Eugenics, 7, Part II, 179–188 (1936).
    For more information about the dataset, see:
    https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-plants-dataset
    """

    def dataloaders(
        self,
        batch_size: int = ta.Param(default=32, tune_min=8, tune_max=128, log=True, tune=True),
    ):
        """
        Build fastai tabular dataloaders for the scikit-learn iris dataset.

        Args:
            batch_size: Number of rows per mini-batch. It is declared via ``ta.Param`` with
                tuning bounds 8..128 on a log scale.

        Returns:
            The object produced by ``TabularDataLoaders.from_df`` for the iris frame.
        """
        # ``load_iris(as_frame=True)`` returns a Bunch, not a DataFrame; the DataFrame
        # itself lives under the "frame" key.
        iris = load_iris(as_frame=True)
        frame = iris["frame"]

        # Replace the integer class codes with their species names. The target column is
        # then non-numeric, presumably so fastai infers a categorical (classification)
        # target. NOTE(review): confirm against fastai's y_block inference.
        frame["target_name"] = np.take(iris["target_names"], iris["target"])

        # Only the named feature columns are inputs. No validation split is specified, so
        # fastai's default splitting applies. NOTE(review): this is presumably an unseeded
        # random split; consider passing ``valid_idx`` for reproducibility.
        return TabularDataLoaders.from_df(
            frame,
            cont_names=iris["feature_names"],
            y_names="target_name",
            bs=batch_size,
        )

    def metrics(self) -> list:
        """Return the metrics reported during training: accuracy and error rate."""
        return [accuracy, error_rate]

    def model(self):
        """
        Return ``None`` instead of a concrete model.

        NOTE(review): presumably this lets the learner function (``tabular_learner``, see
        ``build_learner_func``) construct the model itself; confirm against torchapp.
        """
        return None

    def build_learner_func(self):
        """Return fastai's ``tabular_learner`` as the function used to build the learner."""
        return tabular_learner

    def get_bibtex_files(self):
        """
        Return the BibTeX files to cite for this app.

        The result is the parent class's files followed by ``iris.bib``, located in the same
        directory as this source file.
        """
        # Build a fresh list instead of appending in place. A list owned or cached by the
        # parent class is never mutated, so repeated calls cannot accumulate duplicates.
        return [*super().get_bibtex_files(), Path(__file__).parent / "iris.bib"]

48 

49 

if __name__ == "__main__":
    # Script entry point: hand control to the app's command-line interface.
    app_class = IrisApp
    app_class.main()