from pathlib import Path
from typing import List
from torch import nn
import pandas as pd
from fastai.data.core import DataLoaders
from torchapp.util import copy_func, call_func, change_typer_to_defaults, add_kwargs
from fastai.learner import Learner, load_learner
from fastai.metrics import accuracy, Precision, Recall, RocAuc, F1Score
import torchapp as ta
from rich.console import Console
from rich.table import Table
from rich.box import SIMPLE
from Bio import SeqIO
from Bio.SeqIO import FastaIO

import time

from fastai.losses import CrossEntropyLossFlat

console = Console()

from . import dataloaders, models, refseq, transforms

class Corgi(ta.TorchApp):
    """
    corgi - Classifier for ORganelle Genomes Inter alia
    """

    def __init__(self):
        super().__init__()
        self.categories = refseq.REFSEQ_CATEGORIES  # This will be overridden by the dataloader
        self.category_counts = self.copy_method(self.category_counts)
        add_kwargs(to_func=self.category_counts, from_funcs=self.dataloaders)
        self.category_counts_cli = self.copy_method(self.category_counts)
        change_typer_to_defaults(self.category_counts)
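
        # Note on the torchapp wiring above (a reading inferred from the utility
        # names, not documented behaviour): `copy_method` + `add_kwargs` appear to
        # graft the keyword arguments of `dataloaders` onto `category_counts` so
        # that the CLI version accepts the same options (csv, base_dir, ...),
        # while `change_typer_to_defaults` replaces the typer `Param` objects on
        # the plain method with ordinary defaults so it stays callable from Python.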

    def dataloaders(
        self,
        csv: Path = ta.Param(help="The CSV which has the sequences to use."),
        base_dir: Path = ta.Param(help="The base directory with the RefSeq HDF5 files."),
        batch_size: int = ta.Param(default=32, help="The batch size."),
        dataloader_type: dataloaders.DataloaderType = ta.Param(
            default=dataloaders.DataloaderType.PLAIN, case_sensitive=False
        ),
        validation_seq_length: int = 1_000,
        deform_lambda: float = ta.Param(default=None, help="The lambda for the deform transform."),
    ) -> DataLoaders:
        """
        Creates a fastai DataLoaders object which Corgi uses in training and prediction.

        Args:
            csv (Path): The CSV which has the sequences to use.
            base_dir (Path): The base directory with the RefSeq HDF5 files.
            batch_size (int): The number of elements to use in a batch for training and prediction. Defaults to 32.
        """
        if csv is None:
            raise ValueError("No CSV given")
        if base_dir is None:
            raise ValueError("No base_dir given")
        dls = dataloaders.create_dataloaders_refseq_path(
            csv,
            base_dir=base_dir,
            batch_size=batch_size,
            dataloader_type=dataloader_type,
            deform_lambda=deform_lambda,
            validation_seq_length=validation_seq_length,
        )
        self.categories = dls.vocab
        return dls
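
    # Usage sketch (illustrative only; the paths below are hypothetical and the
    # CSV/HDF5 layout must match what `dataloaders.create_dataloaders_refseq_path`
    # expects):
    #
    #   app = Corgi()
    #   dls = app.dataloaders(
    #       csv=Path("refseq/sequences.csv"),
    #       base_dir=Path("refseq/"),
    #       batch_size=64,
    #   )
    #   print(dls.vocab)  # the categories discovered from the data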

    def model(
        self,
        embedding_dim: int = ta.Param(
            default=8,
            help="The size of the embeddings for the nucleotides (N, A, G, C, T).",
            tune=True,
            tune_min=4,
            tune_max=32,
            log=True,
        ),
        filters: int = ta.Param(
            default=256,
            help="The number of filters in each of the 1D convolution layers. These are concatenated together.",
        ),
        cnn_layers: int = ta.Param(
            default=6,
            help="The number of 1D convolution layers.",
            tune=True,
            tune_min=2,
            tune_max=6,
        ),
        kernel_size_maxpool: int = ta.Param(
            default=2,
            help="The size of the pooling before going to the LSTM.",
        ),
        lstm_dims: int = ta.Param(default=256, help="The size of the hidden layers in the LSTM in both directions."),
        final_layer_dims: int = ta.Param(
            default=0, help="The size of a dense layer after the LSTM. If this is zero then this layer isn't used."
        ),
        dropout: float = ta.Param(
            default=0.2,
            help="The amount of dropout to use. (Not currently enabled.)",
            tune=True,
            tune_min=0.0,
            tune_max=0.3,
        ),
        final_bias: bool = ta.Param(
            default=True,
            help="Whether or not to use bias in the final layer.",
            tune=True,
        ),
        cnn_only: bool = True,
        kernel_size: int = ta.Param(
            default=3, help="The size of the kernels for the CNN-only classifier.", tune=True, tune_choices=[3, 5, 7, 9]
        ),
        cnn_dims_start: int = ta.Param(
            default=None,
            help="The number of filters in the first CNN layer. If not set then it is derived from the MACC.",
        ),
        factor: float = ta.Param(
            default=2.0,
            help="The factor to multiply the number of filters in the CNN layers each time it is downscaled.",
            tune=True,
            log=True,
            tune_min=0.5,
            tune_max=2.5,
        ),
        penultimate_dims: int = ta.Param(
            default=1024,
            help="The number of dimensions in the penultimate layer of the CNN-only classifier.",
            tune=True,
            log=True,
            tune_min=512,
            tune_max=2048,
        ),
        include_length: bool = False,
        transformer_heads: int = ta.Param(8, help="The number of heads in the transformer."),
        transformer_layers: int = ta.Param(0, help="The number of layers in the transformer. If zero then no transformer is used."),
        macc: int = ta.Param(
            default=10_000_000,
            help="The approximate number of multiply-accumulate operations in the model. Used to set cnn_dims_start if it is not provided explicitly.",
        ),
    ) -> nn.Module:
        """
        Creates a deep learning model for Corgi to use.

        Returns:
            nn.Module: The created model.
        """
        num_classes = len(self.categories)

        # If cnn_dims_start is not given then calculate it from the MACC budget
        if not cnn_dims_start:
            assert macc

            cnn_dims_start = models.calc_cnn_dims_start(
                macc=macc,
                seq_len=1024,  # arbitrary sequence length used for the estimate
                embedding_dim=embedding_dim,
                cnn_layers=cnn_layers,
                kernel_size=kernel_size,
                factor=factor,
                penultimate_dims=penultimate_dims,
                num_classes=num_classes,
            )

        if cnn_only:
            return models.ConvClassifier(
                num_embeddings=5,  # i.e. the size of the vocab, which is N, A, C, G, T
                kernel_size=kernel_size,
                factor=factor,
                cnn_layers=cnn_layers,
                num_classes=num_classes,
                kernel_size_maxpool=kernel_size_maxpool,
                final_bias=final_bias,
                dropout=dropout,
                cnn_dims_start=cnn_dims_start,
                penultimate_dims=penultimate_dims,
                include_length=include_length,
                transformer_layers=transformer_layers,
                transformer_heads=transformer_heads,
            )

        return models.ConvRecurrantClassifier(
            num_classes=num_classes,
            embedding_dim=embedding_dim,
            filters=filters,
            cnn_layers=cnn_layers,
            lstm_dims=lstm_dims,
            final_layer_dims=final_layer_dims,
            dropout=dropout,
            kernel_size_maxpool=kernel_size_maxpool,
            final_bias=final_bias,
        )
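
    # Selection note: with the default `cnn_only=True` the app builds the purely
    # convolutional classifier, sizing its first layer from the `macc` budget via
    # `models.calc_cnn_dims_start`; `cnn_only=False` switches to the CNN+LSTM
    # `ConvRecurrantClassifier`. A hypothetical direct invocation (the categories
    # default to refseq.REFSEQ_CATEGORIES until a dataloader overrides them):
    #
    #   app = Corgi()
    #   net = app.model(cnn_only=False, lstm_dims=512)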

    def metrics(self):
        average = "macro"
        return [
            accuracy,
            F1Score(average=average),
            Precision(average=average),
            Recall(average=average),
            RocAuc(average=average),
        ]
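
    # Macro averaging weights every category equally regardless of its frequency,
    # so rare categories influence the F1/precision/recall scores as much as common ones.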

    def monitor(self):
        return "f1_score"

    # def loss_func(self, label_smoothing: float = ta.Param(0.1, help="The amount of label smoothing.")):
    #     return CrossEntropyLossFlat(label_smoothing=label_smoothing)

    def inference_dataloader(
        self,
        learner,
        file: List[Path] = ta.Param(None, help="FASTA files with sequences to be classified."),
        max_seqs: int = None,
        batch_size: int = 1,
        max_length: int = 5_000,
        min_length: int = 128,
        **kwargs,
    ):
        self.seqio_dataloader = dataloaders.SeqIODataloader(
            files=file,
            device=learner.dls.device,
            batch_size=batch_size,
            max_length=max_length,
            max_seqs=max_seqs,
            min_length=min_length,
        )
        self.categories = learner.dls.vocab
        return self.seqio_dataloader
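
    # Note (inferred from the chunk handling in `output_results` below, not from
    # SeqIODataloader itself): long sequences appear to be split into chunks of at
    # most `max_length` bases, with sequences shorter than `min_length` skipped;
    # the per-chunk predictions are then averaged per accession downstream.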

    def output_results(
        self,
        results,
        output_dir: Path = ta.Param(default=None, help="A directory to write the results to."),
        csv: Path = ta.Param(default=None, help="A path to output the results as a CSV. If not given then a default name is chosen inside the output directory."),
        save_filtered: bool = ta.Param(default=True, help="Whether or not to save the filtered sequences."),
        threshold: float = ta.Param(
            default=None,
            help="The threshold to use for filtering. "
            "If not given, then only the most likely category is used for filtering.",
        ),
        **kwargs,
    ):
        if not output_dir:
            time_string = time.strftime("%Y_%m_%d-%I_%M_%S_%p")
            output_dir = f"corgi-output-{time_string}"

        output_dir = Path(output_dir)

        chunk_details = pd.DataFrame(self.seqio_dataloader.chunk_details, columns=["file", "accession", "chunk"])
        predictions_df = pd.DataFrame(results[0].numpy(), columns=self.categories)
        results_df = pd.concat(
            [chunk_details.drop(columns=['chunk']), predictions_df],
            axis=1,
        )

        # Average the predictions over the chunks of each sequence
        results_df = results_df.groupby(["file", "accession"]).mean().reset_index()

        columns = set(predictions_df.columns)

        results_df['prediction'] = results_df[self.categories].idxmax(axis=1)
        # Sum the averaged per-accession probabilities into broader groupings
        results_df['eukaryotic'] = results_df[list(columns & set(refseq.EUKARYOTIC))].sum(axis=1)
        results_df['prokaryotic'] = results_df[list(columns & set(refseq.PROKARYOTIC))].sum(axis=1)
        results_df['organellar'] = results_df[list(columns & set(refseq.ORGANELLAR))].sum(axis=1)

        if not csv:
            output_dir.mkdir(parents=True, exist_ok=True)
            csv = output_dir / "corgi-output.csv"

        console.print(f"Writing results for {len(results_df)} sequences to: {csv}")
        results_df.to_csv(csv, index=False)

        # Write the sequences to FASTA files, one file per category
        if save_filtered:
            record_to_string = FastaIO.as_fasta

            output_dir.mkdir(parents=True, exist_ok=True)

            file_handles = {}

            for file, record in self.seqio_dataloader.iter_records():
                row = results_df[(results_df.accession == record.id) & (results_df.file == file)]
                if len(row) == 0:
                    categories = ["unclassified"]
                else:
                    # Get the categories to write to
                    if not threshold:
                        # If no threshold then just use the most likely category
                        categories = [row['prediction'].item()]
                    else:
                        # Otherwise use all categories with a probability at or equal to the threshold
                        category_predictions = row.iloc[0][self.categories]
                        categories = category_predictions[category_predictions >= threshold].index.tolist()

                for category in categories:
                    if category not in file_handles:
                        file_path = output_dir / f"{category}.fasta"
                        file_handles[category] = open(file_path, "w")

                    file_handle = file_handles[category]
                    file_handle.write(record_to_string(record))

            for file_handle in file_handles.values():
                file_handle.close()

        # Output a bar chart of the predicted categories
        from termgraph.module import Data, BarChart, Args  # imported lazily; only needed for the chart

        value_counts = results_df['prediction'].value_counts()
        data = Data([[count] for count in value_counts], value_counts.index)
        chart = BarChart(
            data,
            Args(
                space_between=False,
            ),
        )

        chart.draw()
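
    # Threshold semantics, as a worked example: with threshold=0.3 and averaged
    # probabilities {mitochondrion: 0.55, plastid: 0.35, nuclear: 0.10}, the record
    # is written to both mitochondrion.fasta and plastid.fasta; with no threshold
    # it is written only to mitochondrion.fasta (the argmax). The category names
    # here are illustrative, not the actual vocabulary.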

    def category_counts_dataloader(self, dataloader, description):
        from collections import Counter

        counter = Counter()
        for batch in dataloader:
            counter.update(batch[1].cpu().numpy())
        total = sum(counter.values())

        table = Table(title=f"{description}: Categories in epoch", box=SIMPLE)

        table.add_column("Category", justify="right", style="cyan", no_wrap=True)
        table.add_column("Count", justify="center")
        table.add_column("Percentage")

        for category_id, category in enumerate(self.categories):
            count = counter[category_id]
            table.add_row(category, str(count), f"{count/total*100:.1f}%")

        table.add_row("Total", str(total), "")

        console.print(table)

    def category_counts(self, **kwargs):
        dataloaders = call_func(self.dataloaders, **kwargs)
        self.category_counts_dataloader(dataloaders.train, "Training")
        self.category_counts_dataloader(dataloaders.valid, "Validation")

    def pretrained_location(self) -> str:
        return "https://github.com/rbturnbull/corgi/releases/download/v0.3.1-alpha/corgi-0.3.pkl"
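
# Usage sketch (hypothetical; the exact CLI verbs depend on how torchapp registers
# this app's entry points): the weights at pretrained_location() point to a fastai
# learner pickle, which torchapp presumably downloads and loads for inference,
# feeding FASTA files through inference_dataloader() and writing the CSV and the
# per-category FASTA files via output_results().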