from pathlib import Path
from typing import List
from torch import nn
import pandas as pd
from fastai.data.core import DataLoaders
from torchapp.util import copy_func, call_func, change_typer_to_defaults, add_kwargs
from fastai.learner import Learner, load_learner
from fastai.metrics import accuracy, Precision, Recall, RocAuc, F1Score
import torchapp as ta
from rich.console import Console
from rich.table import Table
from rich.box import SIMPLE
from Bio import SeqIO
from Bio.SeqIO import FastaIO

import time

from fastai.losses import CrossEntropyLossFlat

console = Console()

from . import dataloaders, models, refseq, transforms


class Corgi(ta.TorchApp):
    """
    corgi - Classifier for ORganelle Genomes Inter alia
    """

    def __init__(self):
        super().__init__()
        self.categories = refseq.REFSEQ_CATEGORIES  # This will be overridden by the dataloader
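        # `category_counts` is copied so that the keyword arguments of
        # `dataloaders` can be grafted onto it with `add_kwargs`. The `_cli`
        # copy keeps the typer parameters for the command line, while the
        # plain method has them converted to ordinary defaults for direct calls.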

        self.category_counts = self.copy_method(self.category_counts)
        add_kwargs(to_func=self.category_counts, from_funcs=self.dataloaders)
        self.category_counts_cli = self.copy_method(self.category_counts)
        change_typer_to_defaults(self.category_counts)

    def dataloaders(
        self,
        csv: Path = ta.Param(help="The CSV which has the sequences to use."),
        base_dir: Path = ta.Param(help="The base directory with the RefSeq HDF5 files."),
        batch_size: int = ta.Param(default=32, help="The batch size."),
        dataloader_type: dataloaders.DataloaderType = ta.Param(
            default=dataloaders.DataloaderType.PLAIN, case_sensitive=False
        ),
        validation_seq_length: int = 1_000,
        deform_lambda: float = ta.Param(default=None, help="The lambda for the deform transform."),
    ) -> DataLoaders:
        """
        Creates a fastai DataLoaders object which Corgi uses in training and prediction.

        Args:
            csv (Path): The CSV which has the sequences to use.
            base_dir (Path): The base directory with the RefSeq HDF5 files.
            batch_size (int): The number of elements in a batch for training and prediction. Defaults to 32.
        """
        if csv is None:
            raise ValueError("No CSV given")
        if base_dir is None:
            raise ValueError("No base_dir given")
        dls = dataloaders.create_dataloaders_refseq_path(
            csv,
            base_dir=base_dir,
            batch_size=batch_size,
            dataloader_type=dataloader_type,
            deform_lambda=deform_lambda,
            validation_seq_length=validation_seq_length,
        )
        self.categories = dls.vocab
        return dls
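
    # Example of building the dataloaders directly in Python (hypothetical
    # paths; normally torchapp exposes these parameters on the command line):
    #
    #   app = Corgi()
    #   dls = app.dataloaders(csv=Path("sequences.csv"), base_dir=Path("refseq_hdf5/"), batch_size=64)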

    def model(
        self,
        embedding_dim: int = ta.Param(
            default=8,
            help="The size of the embeddings for the nucleotides (N, A, G, C, T).",
            tune=True,
            tune_min=4,
            tune_max=32,
            log=True,
        ),
        filters: int = ta.Param(
            default=256,
            help="The number of filters in each of the 1D convolution layers. These are concatenated together.",
        ),
        cnn_layers: int = ta.Param(
            default=6,
            help="The number of 1D convolution layers.",
            tune=True,
            tune_min=2,
            tune_max=6,
        ),
        kernel_size_maxpool: int = ta.Param(
            default=2,
            help="The size of the pooling before going to the LSTM.",
        ),
        lstm_dims: int = ta.Param(default=256, help="The size of the hidden layers in the LSTM in both directions."),
        final_layer_dims: int = ta.Param(
            default=0, help="The size of a dense layer after the LSTM. If this is zero then this layer isn't used."
        ),
        dropout: float = ta.Param(
            default=0.2,
            help="The amount of dropout to use. (Not currently enabled.)",
            tune=True,
            tune_min=0.0,
            tune_max=0.3,
        ),
        final_bias: bool = ta.Param(
            default=True,
            help="Whether or not to use bias in the final layer.",
            tune=True,
        ),
        cnn_only: bool = True,
        kernel_size: int = ta.Param(
            default=3, help="The size of the kernels for the CNN-only classifier.", tune=True, tune_choices=[3, 5, 7, 9]
        ),
        cnn_dims_start: int = ta.Param(
            default=None,
            help="The number of filters in the first CNN layer. If not set then it is derived from the MACC.",
        ),
        factor: float = ta.Param(
            default=2.0,
            help="The factor to multiply the number of filters in the CNN layers each time the sequence is downscaled.",
            tune=True,
            log=True,
            tune_min=0.5,
            tune_max=2.5,
        ),
        penultimate_dims: int = ta.Param(
            default=1024,
            help="The number of filters in the penultimate CNN layer.",
            tune=True,
            log=True,
            tune_min=512,
            tune_max=2048,
        ),
        include_length: bool = False,
        transformer_heads: int = ta.Param(8, help="The number of heads in the transformer."),
        transformer_layers: int = ta.Param(0, help="The number of layers in the transformer. If zero then no transformer is used."),
        macc: int = ta.Param(
            default=10_000_000,
            help="The approximate number of multiply-accumulate (MAC) operations in the model. Used to set cnn_dims_start if it is not provided explicitly.",
        ),
    ) -> nn.Module:
        """
        Creates the deep learning model for Corgi to use.

        Returns:
            nn.Module: The created model.
        """
        num_classes = len(self.categories)

        # If cnn_dims_start is not given then calculate it from the MACC.
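        # `calc_cnn_dims_start` works backwards from the MACC budget: given the
        # embedding size, kernel size, number of layers, and the growth factor
        # per layer, it chooses the number of filters in the first CNN layer so
        # that the model's total multiply-accumulate count is approximately `macc`.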

        if not cnn_dims_start:
            assert macc

            cnn_dims_start = models.calc_cnn_dims_start(
                macc=macc,
                seq_len=1024,  # arbitrary number
                embedding_dim=embedding_dim,
                cnn_layers=cnn_layers,
                kernel_size=kernel_size,
                factor=factor,
                penultimate_dims=penultimate_dims,
                num_classes=num_classes,
            )
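
        # `cnn_only` selects a purely convolutional classifier; otherwise the
        # convolutional layers feed a bidirectional LSTM
        # (the ConvRecurrantClassifier below).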

        if cnn_only:
            return models.ConvClassifier(
                num_embeddings=5,  # i.e. the size of the vocab which is N, A, C, G, T
                kernel_size=kernel_size,
                factor=factor,
                cnn_layers=cnn_layers,
                num_classes=num_classes,
                kernel_size_maxpool=kernel_size_maxpool,
                final_bias=final_bias,
                dropout=dropout,
                cnn_dims_start=cnn_dims_start,
                penultimate_dims=penultimate_dims,
                include_length=include_length,
                transformer_layers=transformer_layers,
                transformer_heads=transformer_heads,
            )

        return models.ConvRecurrantClassifier(
            num_classes=num_classes,
            embedding_dim=embedding_dim,
            filters=filters,
            cnn_layers=cnn_layers,
            lstm_dims=lstm_dims,
            final_layer_dims=final_layer_dims,
            dropout=dropout,
            kernel_size_maxpool=kernel_size_maxpool,
            final_bias=final_bias,
        )

    def metrics(self):
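        # Macro averaging weights every category equally regardless of its
        # frequency, which suits imbalanced numbers of sequences per category.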

        average = "macro"
        return [
            accuracy,
            F1Score(average=average),
            Precision(average=average),
            Recall(average=average),
            RocAuc(average=average),
        ]

    def monitor(self):
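        # fastai names the F1Score metric above "f1_score", so this is the
        # quantity that training monitors.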

        return "f1_score"

    # def loss_func(self, label_smoothing: float = ta.Param(0.1, help="The amount of label smoothing.")):
    #     return CrossEntropyLossFlat(label_smoothing=label_smoothing)

    def inference_dataloader(
        self,
        learner,
        file: List[Path] = ta.Param(None, help="FASTA files with sequences to be classified."),
        max_seqs: int = None,
        batch_size: int = 1,
        max_length: int = 5_000,
        min_length: int = 128,
        **kwargs,
    ):
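        # The SeqIODataloader batches sequences from the FASTA files, breaking
        # long sequences into chunks; the per-chunk predictions are averaged
        # back together in `output_results`.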

        self.seqio_dataloader = dataloaders.SeqIODataloader(
            files=file,
            device=learner.dls.device,
            batch_size=batch_size,
            max_length=max_length,
            max_seqs=max_seqs,
            min_length=min_length,
        )
        self.categories = learner.dls.vocab
        return self.seqio_dataloader

    def output_results(
        self,
        results,
        output_dir: Path = ta.Param(default=None, help="A directory to write the output files."),
        csv: Path = ta.Param(default=None, help="A path to output the results as a CSV. If not given then a default name is chosen inside the output directory."),
        save_filtered: bool = ta.Param(default=True, help="Whether or not to save the filtered sequences."),
        threshold: float = ta.Param(
            default=None,
            help="The threshold to use for filtering. "
            "If not given, then only the most likely category is used for filtering.",
        ),
        **kwargs,
    ):
        if not output_dir:
            time_string = time.strftime("%Y_%m_%d-%I_%M_%S_%p")
            output_dir = f"corgi-output-{time_string}"

        output_dir = Path(output_dir)

        chunk_details = pd.DataFrame(self.seqio_dataloader.chunk_details, columns=["file", "accession", "chunk"])
        predictions_df = pd.DataFrame(results[0].numpy(), columns=self.categories)
        results_df = pd.concat(
            [chunk_details.drop(columns=['chunk']), predictions_df],
            axis=1,
        )

        # Average the predictions over the chunks of each sequence
        results_df = results_df.groupby(["file", "accession"]).mean().reset_index()

        columns = set(predictions_df.columns)
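        # Report the single most probable category, plus summed probabilities
        # for the broader groupings defined in `refseq`.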

        results_df['prediction'] = results_df[self.categories].idxmax(axis=1)
        results_df['eukaryotic'] = results_df[list(columns & set(refseq.EUKARYOTIC))].sum(axis=1)
        results_df['prokaryotic'] = results_df[list(columns & set(refseq.PROKARYOTIC))].sum(axis=1)
        results_df['organellar'] = results_df[list(columns & set(refseq.ORGANELLAR))].sum(axis=1)

        if not csv:
            output_dir.mkdir(parents=True, exist_ok=True)
            csv = output_dir / "corgi-output.csv"

        console.print(f"Writing results for {len(results_df)} sequences to: {csv}")
        results_df.to_csv(csv, index=False)

        # Write all the sequences to FASTA files
        if save_filtered:
            record_to_string = FastaIO.as_fasta

            output_dir.mkdir(parents=True, exist_ok=True)

            file_handles = {}

            for file, record in self.seqio_dataloader.iter_records():
                row = results_df[(results_df.accession == record.id) & (results_df.file == file)]
                if len(row) == 0:
                    categories = ["unclassified"]
                else:
                    # Get the categories to write to
                    if not threshold:
                        # If no threshold then just use the most likely category
                        categories = [row['prediction'].item()]
                    else:
                        # Otherwise use all the categories at or above the threshold
                        category_predictions = row.iloc[0][self.categories]
                        categories = category_predictions[category_predictions >= threshold].index.tolist()

                for category in categories:
                    if category not in file_handles:
                        file_path = output_dir / f"{category}.fasta"
                        file_handles[category] = open(file_path, "w")

                    file_handle = file_handles[category]
                    file_handle.write(record_to_string(record))

            for file_handle in file_handles.values():
                file_handle.close()

        # Output bar chart
        from termgraph.module import Data, BarChart, Args

        value_counts = results_df['prediction'].value_counts()
        data = Data([[count] for count in value_counts], value_counts.index)
        chart = BarChart(
            data,
            Args(
                space_between=False,
            ),
        )

        chart.draw()

    def category_counts_dataloader(self, dataloader, description):
        from collections import Counter

        counter = Counter()
        for batch in dataloader:
            counter.update(batch[1].cpu().numpy())
        total = sum(counter.values())

        table = Table(title=f"{description}: Categories in epoch", box=SIMPLE)

        table.add_column("Category", justify="right", style="cyan", no_wrap=True)
        table.add_column("Count", justify="center")
        table.add_column("Percentage")

        for category_id, category in enumerate(self.categories):
            count = counter[category_id]
            table.add_row(category, str(count), f"{count/total*100:.1f}%")

        table.add_row("Total", str(total), "")

        console.print(table)

    def category_counts(self, **kwargs):
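        # Build dataloaders with the same keyword arguments as `dataloaders`
        # (grafted on in `__init__`) and print a table of the category counts
        # for one epoch of the training and validation sets.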

        dataloaders = call_func(self.dataloaders, **kwargs)
        self.category_counts_dataloader(dataloaders.train, "Training")
        self.category_counts_dataloader(dataloaders.valid, "Validation")

    def pretrained_location(self) -> str:
        return "https://github.com/rbturnbull/corgi/releases/download/v0.3.1-alpha/corgi-0.3.pkl"
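

# Example command-line usage (hypothetical: torchapp generates the CLI from the
# method parameters above, so the exact subcommand and flag names depend on the
# installed torchapp version):
#
#   corgi train --csv sequences.csv --base-dir refseq_hdf5/ --batch-size 32
#   corgi --file contigs.fasta --output-dir results/ --threshold 0.5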