Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from typing import List
2from pathlib import Path
3import torch
4from fastai.data.block import DataBlock, CategoryBlock
5from fastai.data.transforms import ColReader, RandomSplitter, DisplayedTransform, ColSplitter, get_image_files
6from fastai.metrics import accuracy
7from fastai.vision.data import ImageBlock
8from fastai.vision.augment import Resize, ResizeMethod
9import pandas as pd
10import torchapp as ta
11from fastai.vision.augment import aug_transforms
13from torchapp.vision import VisionApp
14from rich.console import Console
# Module-level rich console shared by the app for printing predictions and result tables.
console = Console()
class PathColReader(DisplayedTransform):
    """
    Reads a file path from one column of a DataFrame row.

    Relative paths are joined onto ``base_dir``; absolute paths are returned unchanged.

    Args:
        column_name (str): The name of the column holding the path.
        base_dir (Path, optional): The directory that relative paths are resolved against.
            Strings are accepted and converted to ``Path``. If None, relative paths are
            returned as-is instead of failing.
    """

    def __init__(self, column_name: str, base_dir: Path = None):
        # Let the fastcore Transform base set up its own state before adding ours.
        super().__init__()
        self.column_name = column_name
        self.base_dir = Path(base_dir) if base_dir is not None else None

    def __call__(self, row, **kwargs):
        """
        Returns ``Path(row[column_name])``, prefixed with ``base_dir`` when the path is relative.

        Extra keyword arguments are accepted and ignored so the transform can be called with
        whatever keywords its caller passes (presumably fastai's ``split_idx``; TODO confirm).
        """
        path = Path(row[self.column_name])
        if path.is_absolute() or self.base_dir is None:
            return path
        return self.base_dir / path
class ImageClassifier(VisionApp):
    """
    A TorchApp for classifying images.

    For training, it expects a CSV with image paths and categories.
    """

    def dataloaders(
        self,
        csv: Path = ta.Param(default=None, help="A CSV with image paths and categories."),
        image_column: str = ta.Param(default="image", help="The name of the column with the image paths."),
        category_column: str = ta.Param(
            default="category", help="The name of the column with the category of the image."
        ),
        base_dir: Path = ta.Param(default=None, help="The base directory for images with relative paths. If not given, then it is relative to the csv directory."),
        validation_column: str = ta.Param(
            default="validation",
            help="The column in the dataset to use for validation. "
            "If the column is not in the dataset, then a validation set will be chosen randomly according to `validation_proportion`.",
        ),
        validation_value: str = ta.Param(
            default=None,
            help="If set, then the value in the `validation_column` must equal this string for the item to be in the validation set. "
        ),
        validation_proportion: float = ta.Param(
            default=0.2,
            help="The proportion of the dataset to keep for validation. Used if `validation_column` is not in the dataset.",
        ),
        batch_size: int = ta.Param(default=16, help="The number of items to use in each batch."),
        width: int = ta.Param(default=224, help="The width to resize all the images to."),
        height: int = ta.Param(default=224, help="The height to resize all the images to."),
        resize_method: str = ta.Param(default="squish", help="The method to resize images."),
        max_lighting:float=0.0,
        max_rotate:float=0.0,
        max_warp:float=0.0,
        max_zoom:float=1.0,
        do_flip:bool=False,
        p_affine:float=0.75,
        p_lighting:float=0.75,
    ):
        """
        Builds training/validation dataloaders from a CSV of image paths and categories.

        Args:
            csv: The CSV to read with ``pandas.read_csv``.
            image_column: Column holding image paths (relative paths are joined onto ``base_dir``).
            category_column: Column holding each image's category label.
            base_dir: Directory for relative image paths; defaults to the CSV's parent directory.
            validation_column: Column used to split off the validation set. If it is absent
                from the CSV, a random split of ``validation_proportion`` is used instead.
            validation_value: If given, a row is in the validation set when its
                ``validation_column`` value, compared as a string, equals this value.
            validation_proportion: Fraction of rows held out when splitting randomly.
            batch_size: Items per batch.
            width, height: Size every image is resized to, using ``resize_method``.
            max_lighting, max_rotate, max_warp, max_zoom, do_flip, p_affine, p_lighting:
                Forwarded unchanged to fastai's ``aug_transforms`` as batch augmentations.

        Returns:
            The result of ``DataBlock.dataloaders(df, bs=batch_size)``.

        Raises:
            ValueError: If no CSV is given.
        """
        if csv is None:
            raise ValueError("A CSV with image paths and categories must be given via the `csv` parameter.")

        df = pd.read_csv(csv)

        # Relative image paths default to being relative to the CSV's own directory.
        base_dir = base_dir or Path(csv).parent

        # Choose the training/validation splitter. The column must actually exist before it can be
        # used; otherwise fall back to a random split as the `validation_column` help text promises.
        has_validation_column = bool(validation_column) and validation_column in df
        if has_validation_column and validation_value is not None:
            # Derive a boolean column that is True where the (stringified) value matches.
            flag_column = f"{validation_column} is {validation_value}"
            df[flag_column] = df[validation_column].astype(str) == str(validation_value)
            validation_column = flag_column

        if has_validation_column:
            splitter = ColSplitter(validation_column)
        else:
            splitter = RandomSplitter(validation_proportion)

        batch_transforms = aug_transforms(
            p_lighting=p_lighting,
            p_affine=p_affine,
            max_rotate=max_rotate,
            do_flip=do_flip,
            max_lighting=max_lighting,
            max_zoom=max_zoom,
            max_warp=max_warp,
            pad_mode='zeros',
        )

        datablock = DataBlock(
            blocks=[ImageBlock, CategoryBlock],
            get_x=PathColReader(column_name=image_column, base_dir=base_dir),
            get_y=ColReader(category_column),
            splitter=splitter,
            item_tfms=Resize((height, width), method=resize_method),
            batch_tfms=batch_transforms,
        )

        return datablock.dataloaders(df, bs=batch_size)

    def metrics(self):
        """Returns the metrics reported during training: classification accuracy."""
        return [accuracy]

    def monitor(self):
        """Returns the name of the metric to monitor (``"accuracy"``)."""
        return "accuracy"

    def _collect_items(self, items) -> List[Path]:
        """
        Normalises ``items`` into a list of paths, expanding each directory into its image files.

        ``items`` may be None, a single path (``str``/``Path``), or an iterable of paths.

        Raises:
            ValueError: If ``items`` is not iterable or contains something that is not path-like.
        """
        if items is None:
            return []
        # Treat a single path like a one-element list, so a lone directory is expanded too.
        if isinstance(items, (Path, str)):
            items = [items]

        collected = []
        try:
            for item in items:
                item = Path(item)
                # If the item is a directory then get all images in that directory
                if item.is_dir():
                    collected.extend(get_image_files(item))
                else:
                    collected.append(item)
        except TypeError as err:
            raise ValueError("Cannot interpret list of items.") from err
        return collected

    def inference_dataloader(
        self,
        learner,
        items:List[Path] = None,
        csv: Path = ta.Param(default=None, help="A CSV with image paths."),
        image_column: str = ta.Param(default="image", help="The name of the column with the image paths."),
        base_dir: Path = ta.Param(default="./", help="The base directory for images with relative paths."),
        **kwargs
    ):
        """
        Builds a test dataloader for the images given directly and/or listed in a CSV.

        The collected paths are stored on ``self.items`` so that ``output_results`` can pair
        each prediction with its image.

        Args:
            learner: The trained learner; its ``dls.test_dl`` creates the dataloader.
            items: An image path, a directory (expanded to its image files), or a list of these.
            csv: Optional CSV whose ``image_column`` lists further image paths.
            image_column: Column in ``csv`` holding the image paths.
            base_dir: Directory that relative paths are resolved against (if truthy).
            **kwargs: Passed through to ``learner.dls.test_dl``.

        Returns:
            The dataloader returned by ``learner.dls.test_dl``.

        Raises:
            ValueError: If ``items`` cannot be interpreted or no items were found at all.
        """
        self.items = self._collect_items(items)

        # Read CSV if available
        if csv is not None:
            df = pd.read_csv(csv)
            self.items.extend(Path(path) for path in df[image_column])

        if not self.items:
            raise ValueError("No items found.")

        # Set relative to base dir
        if base_dir:
            base_dir = Path(base_dir)
            self.items = [item if item.is_absolute() else base_dir / item for item in self.items]

        return learner.dls.test_dl(self.items, **kwargs)

    def output_results(
        self,
        results,
        output_csv:Path = ta.Param(None, help="Path to write predictions in CSV format"),
        verbose:bool = True,
        **kwargs
    ):
        """
        Converts raw model outputs into a DataFrame of predictions and per-class probabilities.

        Args:
            results: Inference output; ``results[0]`` is iterated in step with ``self.items``
                (presumably the per-item score tensors — confirm against the caller).
            output_csv: If given, the resulting DataFrame is written here with ``DataFrame.to_csv``.
            verbose: If True, print each prediction and the final table to the console.

        Returns:
            A DataFrame with columns ``path``, ``prediction`` and one probability column per class
            in the vocabulary.
        """
        # NOTE(review): assumes self.learner_obj is set by the torchapp base class before this runs.
        vocab = self.learner_obj.dls.vocab
        rows = []
        for item, scores in zip(self.items, results[0]):
            # NOTE(review): if `results` already hold activated probabilities, this softmax is
            # applied a second time; argmax is unaffected but the probability columns would be.
            probabilities = torch.softmax(torch.as_tensor(scores), dim=-1)
            # Index the vocab with a plain int rather than a 0-d tensor.
            prediction = vocab[int(torch.argmax(probabilities))]
            if verbose:
                console.print(f"'{item}': '{prediction}'")
            rows.append([item, prediction] + probabilities.tolist())

        df = pd.DataFrame(rows, columns=["path", "prediction"] + list(vocab))
        if output_csv:
            df.to_csv(output_csv)

        if verbose:
            console.print(df)

        return df
if __name__ == "__main__":
    # When run as a script, hand control to the app's torchapp entry point
    # (presumably its command-line interface).
    ImageClassifier.main()