import torch
from enum import Enum
import random
from itertools import chain

import gzip
import pandas as pd
from pathlib import Path
import numpy as np
from rich.progress import track

from Bio import SeqIO

from fastcore.foundation import L
from fastcore.dispatch import typedispatch
from fastcore.meta import delegates

from fastai.data.core import TfmdDL, DataLoaders, get_empty_df
from fastai.callback.data import WeightedDL
from fastai.data.block import DataBlock, TransformBlock, CategoryBlock
from fastai.torch_core import display_df, to_device
from fastai.data.load import _loaders  # iterator classes used by FastaDataloader.__iter__
from fastai.data.transforms import ColSplitter, ColReader, RandomSplitter

from .tensor import TensorDNA, dna_seq_to_numpy, dna_seq_to_tensor
from .transforms import RandomSliceBatch, SliceTransform, RowToTensorDNA, PadBatchX, DeterministicSliceBatch, DeformBatch
from .refseq import RefSeqCategory


def fasta_open(fasta_path):
    """Opens a fasta file for reading, transparently handling gzip compression."""
    if fasta_path.suffix == ".gz":
        return gzip.open(fasta_path, "rt")
    return open(fasta_path, "rt")


def fasta_seq_count(fasta_path):
    """Counts the sequences in a fasta file by counting its '>' header lines."""
    seq_count = 0
    with fasta_open(fasta_path) as fasta:
        for line in fasta:
            if line.startswith(">"):
                seq_count += 1
    return seq_count


@delegates()
class StratifiedDL(TfmdDL):
    """
    A DataLoader that balances classes by undersampling: each epoch it draws the same number of items
    (the size of the smallest group) from every group and interleaves them.
    """
    def __init__(self, dataset=None, bs=None, groups=None, **kwargs):
        super().__init__(dataset=dataset, bs=bs, **kwargs)
        self.groups = [list(group) for group in groups] if groups else None
        self.min_length = None
        if not self.groups or not self.shuffle:
            return

        for group in self.groups:
            if self.min_length is None:
                self.min_length = len(group)
                continue
            self.min_length = min(self.min_length, len(group))

        self.queues = [self.shuffle_fn(group) for group in self.groups]
        self.n = self.min_length * len(self.queues)

    def get_idxs(self):
        if not self.groups or not self.shuffle:
            return super().get_idxs()

        epoch_indexes = []
        for i, queue in enumerate(self.queues):
            # Top up the queue with a freshly shuffled copy of the group when it runs low
            if len(queue) < self.min_length:
                queue += self.shuffle_fn(self.groups[i])

            epoch_indexes.append(queue[: self.min_length])
            self.queues[i] = queue[self.min_length:]

        # Interleave the groups so every batch mixes the classes evenly
        return list(chain(*zip(*epoch_indexes)))
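

# Illustrative sketch of how StratifiedDL balances classes; the dataset and groups below are
# hypothetical. Each epoch draws min-group-size indexes from every group and interleaves them.
def _example_stratified_dl():
    groups = [[0, 1, 2, 3, 4], [5, 6, 7], [8, 9, 10, 11]]    # three classes of unequal size
    dl = StratifiedDL(dataset=list(range(12)), bs=3, groups=groups, shuffle=True)
    return dl.get_idxs()    # min group size is 3, so one epoch yields 9 interleaved indexes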


@typedispatch
def show_batch(x: TensorDNA, y, samples, ctxs=None, max_n=20, trunc_at=150, **kwargs):
    if ctxs is None:
        ctxs = get_empty_df(min(len(samples), max_n))
    if trunc_at is not None:
        # Truncate each sequence so only the first trunc_at bases are displayed
        samples = L((s[0][:trunc_at], *s[1:]) for s in samples)
    ctxs = [(sample[0].show(), str(sample[1])) for sample in samples]
    df = pd.DataFrame(ctxs, columns=["x", "y"])
    display_df(df)
    return ctxs


def get_sequence_as_tensor(row):
    return TensorDNA(row["sequence"])


def create_datablock_refseq(categories, validation_column="validation", validation_prob=0.2, vocab=None) -> DataBlock:
    # Check if there is a validation column in the dataset otherwise use a random splitter
    if validation_column:
        splitter = ColSplitter(validation_column)
    else:
        splitter = RandomSplitter(valid_pct=validation_prob, seed=42)

    return DataBlock(
        blocks=(TransformBlock, CategoryBlock(vocab=vocab)),
        splitter=splitter,
        get_y=ColReader("category"),
        item_tfms=RowToTensorDNA(categories),
    )


def create_datablock(seq_length=None, validation_column="validation", validation_prob=0.2, vocab=None) -> DataBlock:
    # Check if we need to slice to a specific sequence length
    if seq_length:
        item_tfms = SliceTransform(seq_length)
    else:
        item_tfms = None

    # Check if there is a validation column in the dataset otherwise use a random splitter
    if validation_column:
        splitter = ColSplitter(validation_column)
    else:
        splitter = RandomSplitter(valid_pct=validation_prob, seed=42)

    return DataBlock(
        blocks=(TransformBlock, CategoryBlock(vocab=vocab)),
        splitter=splitter,
        get_x=get_sequence_as_tensor,
        get_y=ColReader("category"),
        item_tfms=item_tfms,
    )


class DataloaderType(str, Enum):
    PLAIN = "PLAIN"
    WEIGHTED = "WEIGHTED"
    STRATIFIED = "STRATIFIED"


def create_dataloaders_refseq_path(
    dataframe_path: Path,
    base_dir: Path,
    batch_size: int = 64,
    deform_lambda: float = None,
    validation_seq_length: int = 1_000,
    **kwargs
):
    dataframe_path = Path(dataframe_path)

    print('Training using:\t', dataframe_path)
    if dataframe_path.suffix == ".parquet":
        df = pd.read_parquet(str(dataframe_path), engine="pyarrow")
    else:
        df = pd.read_csv(str(dataframe_path))

    print(f'Dataframe has {len(df)} sequences.')
    dls = create_dataloaders_refseq(
        df,
        batch_size=batch_size,
        base_dir=base_dir,
        deform_lambda=deform_lambda,
        validation_seq_length=validation_seq_length,
        **kwargs
    )
    return dls
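

# Illustrative usage sketch with hypothetical paths: the dataframe is expected to contain a
# 'category' column and per-sequence information usable by RefSeqCategory, and base_dir is the
# directory holding the preprocessed RefSeq data.
def _example_refseq_dataloaders():
    return create_dataloaders_refseq_path(
        dataframe_path="refseq.parquet",            # hypothetical parquet/CSV of sequences
        base_dir="refseq_data/",                    # hypothetical RefSeq data directory
        batch_size=32,
        validation_seq_length=1_000,
        dataloader_type=DataloaderType.WEIGHTED,    # rebalance rare categories via WeightedDL
    )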


def create_dataloaders_refseq(
    df: pd.DataFrame,
    base_dir: Path,
    batch_size=64,
    dataloader_type: DataloaderType = DataloaderType.PLAIN,
    verbose: bool = True,
    validation_seq_length: int = 1_000,
    deform_lambda: float = None,
    **kwargs,
) -> DataLoaders:
    categories = [RefSeqCategory(name, base_dir=base_dir) for name in df.category.unique()]

    # Set up batch transforms: random slices for training, deterministic slices for validation
    before_batch = [
        RandomSliceBatch(only_split_index=0),
        DeterministicSliceBatch(seq_length=validation_seq_length, only_split_index=1),
    ]
    if deform_lambda is not None:
        before_batch.append(DeformBatch(deform_lambda=deform_lambda))

    dataloaders_kwargs = dict(bs=batch_size, drop_last=False, before_batch=before_batch)

    # If there is no validation column, mark a random 20% of the smallest category size per category as validation
    validation_column = "validation"
    random.seed(42)
    if validation_column not in df:
        df[validation_column] = 0
        value_counts = df.category.value_counts()
        validation_per_category = int(0.2 * value_counts.min())

        for name in df.category.unique():
            indexes_for_category = df.index[df.category == name]
            validation_indexes = random.sample(list(indexes_for_category.values), validation_per_category)
            df.loc[validation_indexes, validation_column] = 1

    print("Creating Datablock")
    vocab = df['category'].unique()
    datablock = create_datablock_refseq(categories, validation_column=validation_column, vocab=vocab, **kwargs)

    # Accept either a DataloaderType member or a plain string such as "weighted"
    dataloader_type = str(dataloader_type).upper().split(".")[-1]
    if validation_column in df:
        training_df = df[df[validation_column] == 0].reset_index()

    if dataloader_type == "STRATIFIED":
        print("Creating groups for balancing dataset")
        groups = [training_df.index[training_df['category'] == name] for name in vocab]

        dataloaders_kwargs['dl_type'] = StratifiedDL
        dataloaders_kwargs['dl_kwargs'] = [dict(groups=groups), dict()]
    elif dataloader_type == "WEIGHTED":
        print("Creating weights for balancing dataset")
        weights = np.zeros((len(training_df),))
        value_counts = training_df['category'].value_counts()
        for name in df.category.unique():
            weight = value_counts.max() / value_counts[name]
            print(f"\tWeight for {name}: {weight}")
            weights[training_df['category'] == name] = weight

        dataloaders_kwargs['dl_type'] = WeightedDL
        dataloaders_kwargs['dl_kwargs'] = [dict(wgts=weights), dict()]
    elif dataloader_type == "PLAIN":
        pass
    else:
        raise ValueError(f"Dataloader type {dataloader_type} not understood.")

    print("Creating Dataloaders")
    return datablock.dataloaders(df, verbose=verbose, **dataloaders_kwargs)


def create_dataloaders(df: pd.DataFrame, batch_size=64, **kwargs) -> DataLoaders:
    datablock = create_datablock(**kwargs)
    return datablock.dataloaders(df, bs=batch_size, drop_last=False)


def fasta_to_dataframe(
    fasta_path,
    max_seqs=None,
    validation_from_filename=True,
    validation_prob=0.2,
):
    """
    Creates a pandas dataframe from a fasta file.

    If validation_from_filename is True then it checks whether 'valid' or 'train' is in the filename;
    otherwise it falls back to using validation_prob.
    If validation_from_filename is True and 'valid' or 'train' is in the filename then validation_prob is ignored.
    """
    fasta_path = Path(fasta_path)
    print(f"Processing:\t{fasta_path}")

    if not fasta_path.exists():
        raise FileNotFoundError(f"Cannot find fasta file {fasta_path}.")

    seq_count = fasta_seq_count(fasta_path)
    print(f"{seq_count} sequences")
    if max_seqs and seq_count >= max_seqs:
        print(f"Limiting to maximum number of sequences: {max_seqs}")
        seq_count = max_seqs

    data = []
    fasta = fasta_open(fasta_path)

    if validation_from_filename:
        if "valid" in str(fasta_path):
            validation = 1
        elif "train" in str(fasta_path):
            validation = 0
        else:
            validation_from_filename = False

    seqs = SeqIO.parse(fasta, "fasta")
    for seq_index, seq in enumerate(track(seqs, total=seq_count, description="Reading fasta file:")):
        if max_seqs and seq_index >= max_seqs:
            break

        if not validation_from_filename:
            validation = int(random.random() < validation_prob)

        seq_tensor = dna_seq_to_tensor(seq.seq)
        data.append([seq.id, seq.description, seq_tensor, validation])

    fasta.close()

    df = pd.DataFrame(data, columns=["id", "description", "sequence", "validation"])
    df["file"] = str(fasta_path)
    df["category"] = fasta_path.name.split(".")[0]
    return df
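

# Illustrative sketch with hypothetical file names: when the filename contains 'train' or 'valid'
# the whole file goes to that split; otherwise each sequence is assigned to validation at random
# with probability validation_prob.
def _example_fasta_dataframes():
    train_df = fasta_to_dataframe("bacteria.train.fa.gz")    # every row gets validation == 0
    random_df = fasta_to_dataframe("archaea.fa", validation_from_filename=False, validation_prob=0.1)
    return pd.concat([train_df, random_df])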


def fastas_to_dataframe(fasta_paths, **kwargs):
    dfs = [fasta_to_dataframe(fasta_path, **kwargs) for fasta_path in fasta_paths]
    return pd.concat(dfs)


def create_dataloaders_from_fastas(fasta_paths, batch_size=64, seq_length=None, **kwargs) -> DataLoaders:
    """
    Creates a DataLoaders object from a list of fasta paths.
    """
    df = fastas_to_dataframe(fasta_paths, **kwargs)
    return create_dataloaders(df, batch_size=batch_size, seq_length=seq_length)
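

# Illustrative sketch with hypothetical fasta paths: each file becomes one category, named from the
# part of its filename before the first '.', and every item is sliced to seq_length bases.
def _example_dataloaders_from_fastas():
    return create_dataloaders_from_fastas(
        ["bacteria.fa.gz", "virus.fa.gz"],    # hypothetical files; categories 'bacteria' and 'virus'
        batch_size=32,
        seq_length=1_000,
    )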


class FastaDataloader:
    # Note: __iter__ mirrors fastai's DataLoader.__iter__ and relies on attributes such as
    # fake_l, randomize, before_iter, get_idxs, after_batch and after_iter being provided
    # by that machinery.
    def __init__(self, fasta_files, device):
        self.fasta_files = list(fasta_files)
        self.device = device

    def __iter__(self):
        self.randomize()
        self.before_iter()
        self.__idxs = self.get_idxs()  # called in context of main process (not workers/subprocesses)
        for b in _loaders[self.fake_l.num_workers == 0](self.fake_l):
            if self.device is not None:
                b = to_device(b, self.device)
            yield self.after_batch(b)
        self.after_iter()
        if hasattr(self, 'it'):
            del self.it


class SeqIODataloader:
    """
    Iterates over sequence files, yielding padded batches of tensor chunks for inference.

    Sequences shorter than min_length are skipped; sequences longer than max_length are split into chunks.
    """
    def __init__(self, files, device, batch_size: int = 1, min_length: int = 128, max_length: int = 5_000, max_seqs: int = None, format: str = ""):
        self.files = list(files)
        self.device = device
        self.format = format
        self.chunk_details = []
        self.max_length = max_length
        self.batch_size = batch_size
        self.min_length = min_length
        self.pad = PadBatchX()
        self.count = 0
        self.max_seqs = max_seqs
        seqs = 0
        # Pre-count the chunks so __len__ can report the total number of items
        for file in self.files:
            for record in self.parse(file):
                if len(record.seq) < self.min_length:
                    continue

                if self.max_seqs and seqs >= self.max_seqs:
                    break

                chunks = len(record.seq) // self.max_length + 1
                self.count += chunks
                seqs += 1

    def get_file_format(self, file):
        if self.format:
            return self.format

        file = Path(file)
        suffix = file.suffix.lower()

        if suffix in [".fa", ".fna", ".fasta"]:
            return "fasta"

        if suffix in [".genbank", ".gb", ".gbk"]:
            return "genbank"

        if suffix in [".tab", ".tsv"]:
            return "tsv"

        if suffix in [".fastq", ".fq"]:
            return "fastq"

        raise ValueError(f"Cannot determine file format of {file}.")

    def __len__(self):
        return self.count

    def parse(self, file):
        return SeqIO.parse(file, self.get_file_format(file))

    def iter_records(self):
        for file in self.files:
            for record in self.parse(file):
                yield file, record

    def __iter__(self):
        batch = []
        seqs = 0

        for file in self.files:
            for record in self.parse(file):
                if len(record.seq) < self.min_length:
                    continue

                if self.max_seqs and seqs >= self.max_seqs:
                    break

                seqs += 1
                t = dna_seq_to_tensor(record.seq)
                chunks = len(t) // self.max_length + 1

                for chunk_index, chunk in enumerate(t.chunk(chunks)):
                    self.chunk_details.append((file, record.id, chunk_index))
                    batch.append(chunk)
                    if len(batch) >= self.batch_size:
                        batch = self.pad(batch)
                        yield batch
                        batch = []

        if batch:
            batch = self.pad(batch)
            yield batch
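

# Illustrative inference sketch with a hypothetical model and file: iterate padded batches of
# sequence chunks; dl.chunk_details records (file, record id, chunk index) in the same order,
# so per-chunk outputs can be matched back to their source records. The exact structure of each
# yielded batch depends on PadBatchX, so the model call below is an assumed interface.
def _example_seqio_inference(model):
    dl = SeqIODataloader(files=["contigs.fa"], device="cpu", batch_size=4)
    results = [model(batch) for batch in dl]    # one output per batch of chunks (assumed interface)
    return results, dl.chunk_details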