Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from typing import List 

2from pathlib import Path 

3import torch 

4from fastai.data.block import DataBlock, CategoryBlock 

5from fastai.data.transforms import ColReader, RandomSplitter, DisplayedTransform, ColSplitter, get_image_files 

6from fastai.metrics import accuracy 

7from fastai.vision.data import ImageBlock 

8from fastai.vision.augment import Resize, ResizeMethod 

9import pandas as pd 

10import torchapp as ta 

11from fastai.vision.augment import aug_transforms 


13from torchapp.vision import VisionApp 

14from rich.console import Console 

# Module-level rich console; ImageClassifier.output_results prints predictions through it.
console = Console()



class PathColReader(DisplayedTransform):
    """
    Transform that pulls a file path out of one column of a row.

    Absolute paths are returned as-is; relative paths are joined onto ``base_dir``.
    """

    def __init__(self, column_name: str, base_dir: Path):
        # Column of the row that holds the (possibly relative) path.
        self.column_name = column_name
        # Directory that relative paths are resolved against.
        self.base_dir = base_dir

    def __call__(self, row, **kwargs):
        candidate = Path(row[self.column_name])
        return candidate if candidate.is_absolute() else self.base_dir / candidate



class ImageClassifier(VisionApp):
    """
    A TorchApp for classifying images.

    For training, it expects a CSV with image paths and categories.
    """

    def dataloaders(
        self,
        csv: Path = ta.Param(default=None, help="A CSV with image paths and categories."),
        image_column: str = ta.Param(default="image", help="The name of the column with the image paths."),
        category_column: str = ta.Param(
            default="category", help="The name of the column with the category of the image."
        ),
        base_dir: Path = ta.Param(default=None, help="The base directory for images with relative paths. If not given, then it is relative to the csv directory."),
        validation_column: str = ta.Param(
            default="validation",
            help="The column in the dataset to use for validation. "
            "If the column is not in the dataset, then a validation set will be chosen randomly according to `validation_proportion`.",
        ),
        validation_value: str = ta.Param(
            default=None,
            help="If set, then the value in the `validation_column` must equal this string for the item to be in the validation set. "
        ),
        validation_proportion: float = ta.Param(
            default=0.2,
            help="The proportion of the dataset to keep for validation. Used if `validation_column` is not in the dataset.",
        ),
        batch_size: int = ta.Param(default=16, help="The number of items to use in each batch."),
        width: int = ta.Param(default=224, help="The width to resize all the images to."),
        height: int = ta.Param(default=224, help="The height to resize all the images to."),
        resize_method: str = ta.Param(default="squish", help="The method to resize images."),
        max_lighting: float = 0.0,
        max_rotate: float = 0.0,
        max_warp: float = 0.0,
        max_zoom: float = 1.0,
        do_flip: bool = False,
        p_affine: float = 0.75,
        p_lighting: float = 0.75,
    ):
        """
        Build the training/validation dataloaders from a CSV of image paths and categories.

        Images are read from ``image_column`` (relative paths are resolved against
        ``base_dir``, or the CSV's directory if ``base_dir`` is not given). Labels come from
        ``category_column``. Images are resized to ``(height, width)`` per item, and the
        fastai ``aug_transforms`` configured by the ``max_*``/``p_*``/``do_flip``
        arguments are applied per batch.

        Validation split:
          * if ``validation_value`` is set and ``validation_column`` exists, rows whose value
            in that column (as a string) equals ``validation_value`` are validation rows;
          * otherwise, if ``validation_column`` exists, it is used directly by ``ColSplitter``;
          * otherwise a random ``validation_proportion`` of rows is used.

        Raises:
            ValueError: If no CSV is given.
        """
        if csv is None:
            raise ValueError("No CSV given. Training requires a CSV with image paths and categories.")

        df = pd.read_csv(csv)

        # Relative image paths are resolved against this directory.
        base_dir = Path(base_dir) if base_dir else Path(csv).parent

        # Create splitter for training/validation images
        if validation_value is not None:
            if validation_column and validation_column in df:
                # Derive a boolean column marking rows whose value matches `validation_value`.
                validation_column_new = f"{validation_column} is {validation_value}"
                df[validation_column_new] = df[validation_column].astype(str) == validation_value
                validation_column = validation_column_new
            else:
                # Documented behaviour for a missing column is a random split; tell the user
                # their `validation_value` is being ignored rather than failing with a KeyError.
                console.print(
                    f"Warning: validation column '{validation_column}' not found in '{csv}'. "
                    f"Ignoring validation value '{validation_value}' and using a random split."
                )

        if validation_column and validation_column in df:
            splitter = ColSplitter(validation_column)
        else:
            splitter = RandomSplitter(validation_proportion)

        batch_transforms = aug_transforms(
            p_lighting=p_lighting,
            p_affine=p_affine,
            max_rotate=max_rotate,
            do_flip=do_flip,
            max_lighting=max_lighting,
            max_zoom=max_zoom,
            max_warp=max_warp,
            pad_mode='zeros',
        )

        datablock = DataBlock(
            blocks=[ImageBlock, CategoryBlock],
            get_x=PathColReader(column_name=image_column, base_dir=base_dir),
            get_y=ColReader(category_column),
            splitter=splitter,
            item_tfms=Resize((height, width), method=resize_method),
            batch_tfms=batch_transforms,
        )

        return datablock.dataloaders(df, bs=batch_size)

    def metrics(self):
        """Return the metrics computed during training: classification accuracy."""
        return [accuracy]

    def monitor(self):
        """Return the name of the metric to monitor (presumably used by torchapp to select the best model)."""
        return "accuracy"

    def inference_dataloader(
        self,
        learner,
        items: List[Path] = None,
        csv: Path = ta.Param(default=None, help="A CSV with image paths."),
        image_column: str = ta.Param(default="image", help="The name of the column with the image paths."),
        base_dir: Path = ta.Param(default="./", help="The base directory for images with relative paths."),
        **kwargs
    ):
        """
        Build a dataloader over the images to predict on.

        Images can come from ``items`` and/or from ``csv``:
          * ``items`` is a single path or an iterable of paths. Any path that is a directory
            is expanded to the image files it contains.
          * ``csv``: every entry of ``image_column`` is used as an image path.

        Relative paths are resolved against ``base_dir`` (when it is truthy) *before* any
        directory check, so directories given relative to ``base_dir`` are found.

        The resolved paths are stored in ``self.items`` so that ``output_results`` can pair
        them with predictions. Extra ``kwargs`` are passed to ``learner.dls.test_dl``.

        Raises:
            ValueError: If ``items`` cannot be interpreted as paths, or no items were found.
        """
        base = Path(base_dir) if base_dir else None

        def resolve(path) -> Path:
            # Join relative paths onto the base directory; absolute paths pass through unchanged.
            path = Path(path)
            if base is not None and not path.is_absolute():
                path = base / path
            return path

        # Treat a single path like a one-element list so that a directory is expanded too.
        if isinstance(items, (Path, str)):
            items = [items]

        self.items = []
        if items is not None:
            try:
                candidates = [resolve(item) for item in items]
            except TypeError as err:
                # `items` is not iterable, or contains something that is not path-like.
                raise ValueError("Cannot interpret list of items.") from err
            for candidate in candidates:
                # If the item is a directory then get all images in that directory
                if candidate.is_dir():
                    self.items.extend(get_image_files(candidate))
                else:
                    self.items.append(candidate)

        # Read CSV if available
        if csv is not None:
            df = pd.read_csv(csv)
            self.items.extend(resolve(path) for path in df[image_column])

        if not self.items:
            raise ValueError("No items found.")

        return learner.dls.test_dl(self.items, **kwargs)

    def output_results(
        self,
        results,
        output_csv: Path = ta.Param(None, help="Path to write predictions in CSV format"),
        verbose: bool = True,
        **kwargs
    ):
        """
        Turn raw model scores into predicted categories and per-category probabilities.

        A softmax is applied to each item's scores. The prediction is the vocab entry with
        the highest probability.

        Assumptions (not visible in this class, to confirm against torchapp):
          * ``results[0]`` holds one score vector per item, in the same order as
            ``self.items``, which ``inference_dataloader`` sets;
          * ``self.learner_obj`` has been set by the framework before this is called.

        If ``verbose``, each prediction and the final table are printed.

        Returns:
            A DataFrame with columns ``path``, ``prediction`` and one probability column per
            vocab category. It is also written to ``output_csv`` when that is given.
        """
        vocab = self.learner_obj.dls.vocab
        rows = []
        for item, scores in zip(self.items, results[0]):
            probabilities = torch.softmax(torch.as_tensor(scores), dim=-1)
            # `.item()` converts the 0-d index tensor to a plain int before indexing the vocab.
            prediction = vocab[torch.argmax(probabilities).item()]
            if verbose:
                console.print(f"'{item}': '{prediction}'")
            rows.append([item, prediction] + probabilities.tolist())

        df = pd.DataFrame(rows, columns=["path", "prediction"] + list(vocab))
        if output_csv:
            df.to_csv(output_csv)

        if verbose:
            console.print(df)

        return df



if __name__ == "__main__":
    # When run as a script, hand off to the torchapp-provided `main`
    # (presumably the command-line entry point; defined outside this file).
    ImageClassifier.main()