
import torch
from enum import Enum
import random
from itertools import chain

import gzip
import pandas as pd
from pathlib import Path
import numpy as np
from rich.progress import track

from Bio import SeqIO

from fastcore.foundation import L
from fastcore.dispatch import typedispatch
from fastcore.meta import delegates

from fastai.data.core import TfmdDL, DataLoaders, get_empty_df
from fastai.callback.data import WeightedDL
from fastai.data.block import DataBlock, TransformBlock, CategoryBlock
from fastai.torch_core import display_df
from fastai.data.transforms import ColSplitter, ColReader, RandomSplitter

from .tensor import TensorDNA, dna_seq_to_numpy, dna_seq_to_tensor
from .transforms import RandomSliceBatch, SliceTransform, RowToTensorDNA, PadBatchX, DeterministicSliceBatch, DeformBatch
from .refseq import RefSeqCategory

def fasta_open(fasta_path):
    if fasta_path.suffix == ".gz":
        return gzip.open(fasta_path, "rt")
    return open(fasta_path, "rt")


def fasta_seq_count(fasta_path):
    seq_count = 0
    with fasta_open(fasta_path) as fasta:
        for line in fasta:
            if line.startswith(">"):
                seq_count += 1
    return seq_count
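# A minimal usage sketch (hypothetical path; fasta_open transparently handles
# gzip-compressed files based on the ".gz" suffix, so the argument should be a Path):
#
#   n_seqs = fasta_seq_count(Path("example_sequences.fa.gz"))
#   print(f"{n_seqs} sequences found")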

@delegates()
class StratifiedDL(TfmdDL):
    """
    A DataLoader that draws the same number of items from each group per epoch.

    Each group of indices is shuffled independently and the groups are interleaved,
    so the epoch length is the size of the smallest group multiplied by the number
    of groups.
    """
    def __init__(self, dataset=None, bs=None, groups=None, **kwargs):
        super().__init__(dataset=dataset, bs=bs, **kwargs)
        self.groups = [list(group) for group in groups] if groups else None
        self.min_length = None
        if not self.groups or not self.shuffle:
            return

        for group in self.groups:
            if self.min_length is None:
                self.min_length = len(group)
                continue
            self.min_length = min(self.min_length, len(group))
        self.queues = [self.shuffle_fn(group) for group in self.groups]
        self.n = self.min_length * len(self.queues)

    def get_idxs(self):
        if not self.groups or not self.shuffle:
            return super().get_idxs()

        epoch_indexes = []
        for i, queue in enumerate(self.queues):
            if len(queue) < self.min_length:
                queue += self.shuffle_fn(self.groups[i])

            epoch_indexes.append(queue[: self.min_length])
            self.queues[i] = queue[self.min_length :]

        return list(chain(*zip(*epoch_indexes)))
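# A minimal sketch of using StratifiedDL directly (toy data; in practice the class is
# selected via dataloader_type="STRATIFIED" in create_dataloaders_refseq below):
#
#   items = list(range(6))
#   dl = StratifiedDL(items, bs=2, shuffle=True, groups=[[0, 1, 2, 3], [4, 5]])
#   # Each epoch draws min(len(group)) == 2 indices from every group, interleaved,
#   # so both groups contribute equally even though one is twice as large.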

@typedispatch
def show_batch(x: TensorDNA, y, samples, ctxs=None, max_n=20, trunc_at=150, **kwargs):
    if ctxs is None:
        ctxs = get_empty_df(min(len(samples), max_n))
    if trunc_at is not None:
        # Truncate each displayed sequence to at most `trunc_at` bases
        samples = L((s[0][:trunc_at], *s[1:]) for s in samples)
    ctxs = [(sample[0].show(), str(sample[1])) for sample in samples]
    df = pd.DataFrame(ctxs, columns=["x", "y"])
    display_df(df)
    return ctxs


def get_sequence_as_tensor(row):
    return TensorDNA(row["sequence"])

def create_datablock_refseq(categories, validation_column="validation", validation_prob=0.2, vocab=None) -> DataBlock:
    # Use the validation column to split if one is given, otherwise split randomly
    if validation_column:
        splitter = ColSplitter(validation_column)
    else:
        splitter = RandomSplitter(valid_pct=validation_prob, seed=42)

    return DataBlock(
        blocks=(TransformBlock, CategoryBlock(vocab=vocab)),
        splitter=splitter,
        get_y=ColReader("category"),
        item_tfms=RowToTensorDNA(categories),
    )


def create_datablock(seq_length=None, validation_column="validation", validation_prob=0.2, vocab=None) -> DataBlock:
    # Slice each item to a fixed sequence length if one is given
    if seq_length:
        item_tfms = SliceTransform(seq_length)
    else:
        item_tfms = None

    # Use the validation column to split if one is given, otherwise split randomly
    if validation_column:
        splitter = ColSplitter(validation_column)
    else:
        splitter = RandomSplitter(valid_pct=validation_prob, seed=42)

    return DataBlock(
        blocks=(TransformBlock, CategoryBlock(vocab=vocab)),
        splitter=splitter,
        get_x=get_sequence_as_tensor,
        get_y=ColReader("category"),
        item_tfms=item_tfms,
    )


class DataloaderType(str, Enum):
    PLAIN = "PLAIN"
    WEIGHTED = "WEIGHTED"
    STRATIFIED = "STRATIFIED"

def create_dataloaders_refseq_path(
    dataframe_path: Path,
    base_dir: Path,
    batch_size: int = 64,
    deform_lambda: float = None,
    validation_seq_length: int = 1_000,
    **kwargs,
):
    dataframe_path = Path(dataframe_path)

    print("Training using:\t", dataframe_path)
    if dataframe_path.suffix == ".parquet":
        df = pd.read_parquet(str(dataframe_path), engine="pyarrow")
    else:
        df = pd.read_csv(str(dataframe_path))

    print(f"Dataframe has {len(df)} sequences.")
    dls = create_dataloaders_refseq(
        df,
        batch_size=batch_size,
        base_dir=base_dir,
        deform_lambda=deform_lambda,
        validation_seq_length=validation_seq_length,
        **kwargs,
    )
    return dls
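# A minimal usage sketch (hypothetical paths; base_dir is the directory holding the
# pre-processed RefSeq category files expected by RefSeqCategory):
#
#   dls = create_dataloaders_refseq_path(
#       "refseq_sequences.parquet",
#       base_dir=Path("refseq_data"),
#       batch_size=32,
#       dataloader_type=DataloaderType.WEIGHTED,
#   )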

def create_dataloaders_refseq(
    df: pd.DataFrame,
    base_dir: Path,
    batch_size=64,
    dataloader_type: DataloaderType = DataloaderType.PLAIN,
    verbose: bool = True,
    validation_seq_length: int = 1_000,
    deform_lambda: float = None,
    **kwargs,
) -> DataLoaders:
    categories = [RefSeqCategory(name, base_dir=base_dir) for name in df.category.unique()]

    # Set up batch transforms: random slices for training, fixed-length slices for validation
    before_batch = [
        RandomSliceBatch(only_split_index=0),
        DeterministicSliceBatch(seq_length=validation_seq_length, only_split_index=1),
    ]
    if deform_lambda is not None:
        before_batch.append(DeformBatch(deform_lambda=deform_lambda))

    dataloaders_kwargs = dict(bs=batch_size, drop_last=False, before_batch=before_batch)

    # If no validation column exists, hold out the same number of sequences from each category
    validation_column = "validation"
    random.seed(42)
    if validation_column not in df:
        df[validation_column] = 0
        value_counts = df.category.value_counts()
        validation_per_category = int(0.2 * value_counts.min())

        for name in df.category.unique():
            indexes_for_category = df.index[df.category == name]
            validation_indexes = random.sample(list(indexes_for_category.values), validation_per_category)
            df.loc[validation_indexes, validation_column] = 1

    print("Creating Datablock")
    vocab = df["category"].unique()
    datablock = create_datablock_refseq(categories, validation_column=validation_column, vocab=vocab, **kwargs)

    # Accept either a DataloaderType member or a plain string
    dataloader_type = getattr(dataloader_type, "value", str(dataloader_type)).upper()
    if validation_column in df:
        training_df = df[df[validation_column] == 0].reset_index()

        if dataloader_type == "STRATIFIED":
            print("Creating groups for balancing dataset")
            groups = [training_df.index[training_df["category"] == name] for name in vocab]

            dataloaders_kwargs["dl_type"] = StratifiedDL
            dataloaders_kwargs["dl_kwargs"] = [dict(groups=groups), dict()]
        elif dataloader_type == "WEIGHTED":
            print("Creating weights for balancing dataset")
            weights = np.zeros((len(training_df),))
            value_counts = training_df["category"].value_counts()
            for name in df.category.unique():
                weight = value_counts.max() / value_counts[name]
                print(f"\tWeight for {name}: {weight}")
                weights[training_df["category"] == name] = weight

            dataloaders_kwargs["dl_type"] = WeightedDL
            dataloaders_kwargs["dl_kwargs"] = [dict(wgts=weights), dict()]
        elif dataloader_type == "PLAIN":
            pass
        else:
            raise ValueError(f"Dataloader type {dataloader_type} not understood")

    print("Creating Dataloaders")
    return datablock.dataloaders(df, verbose=verbose, **dataloaders_kwargs)
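# A minimal sketch of calling create_dataloaders_refseq directly (df is assumed to be
# a dataframe of RefSeq sequences with a "category" column plus whatever columns
# RowToTensorDNA expects; the string "stratified" selects the StratifiedDL defined above):
#
#   dls = create_dataloaders_refseq(df, base_dir=Path("refseq_data"), batch_size=16,
#                                   dataloader_type="stratified")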

def create_dataloaders(df: pd.DataFrame, batch_size=64, **kwargs) -> DataLoaders:
    datablock = create_datablock(**kwargs)
    return datablock.dataloaders(df, bs=batch_size, drop_last=False)

def fasta_to_dataframe(
    fasta_path,
    max_seqs=None,
    validation_from_filename=True,
    validation_prob=0.2,
):
    """
    Creates a pandas dataframe from a fasta file.

    If validation_from_filename is True and 'valid' or 'train' appears in the filename,
    every sequence is assigned to the validation or training set accordingly and
    validation_prob is ignored. Otherwise each sequence is assigned to the validation
    set with probability validation_prob.
    """
    fasta_path = Path(fasta_path)
    print(f"Processing:\t{fasta_path}")

    if not fasta_path.exists():
        raise FileNotFoundError(f"Cannot find fasta file {fasta_path}.")

    seq_count = fasta_seq_count(fasta_path)
    print(f"{seq_count} sequences")
    if max_seqs and seq_count >= max_seqs:
        print(f"Limiting to maximum number of sequences: {max_seqs}")
        seq_count = max_seqs

    data = []
    fasta = fasta_open(fasta_path)

    if validation_from_filename:
        if "valid" in str(fasta_path):
            validation = 1
        elif "train" in str(fasta_path):
            validation = 0
        else:
            validation_from_filename = False

    seqs = SeqIO.parse(fasta, "fasta")
    for seq_index, seq in enumerate(track(seqs, total=seq_count, description="Reading fasta file:")):
        if max_seqs and seq_index >= max_seqs:
            break

        if not validation_from_filename:
            validation = int(random.random() < validation_prob)

        seq_as_tensor = dna_seq_to_tensor(seq.seq)
        data.append([seq.id, seq.description, seq_as_tensor, validation])

    fasta.close()

    df = pd.DataFrame(data, columns=["id", "description", "sequence", "validation"])
    df["file"] = str(fasta_path)
    df["category"] = fasta_path.name.split(".")[0]
    return df


def fastas_to_dataframe(fasta_paths, **kwargs):
    dfs = [fasta_to_dataframe(fasta_path, **kwargs) for fasta_path in fasta_paths]
    return pd.concat(dfs)

def create_dataloaders_from_fastas(fasta_paths, batch_size=64, seq_length=None, **kwargs) -> DataLoaders:
    """
    Creates a DataLoaders object from a list of fasta paths.
    """
    df = fastas_to_dataframe(fasta_paths, **kwargs)
    return create_dataloaders(df, batch_size=batch_size, seq_length=seq_length)
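# A minimal usage sketch (hypothetical file names; the 'train'/'valid' substrings in the
# paths drive the validation split, and the category is taken from the first part of each
# filename as described in fasta_to_dataframe):
#
#   dls = create_dataloaders_from_fastas(
#       ["bacteria.train.fa.gz", "bacteria.valid.fa.gz",
#        "virus.train.fa.gz", "virus.valid.fa.gz"],
#       batch_size=32,
#       seq_length=256,
#   )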

class FastaDataloader:
    # NOTE: __iter__ below mirrors fastai's DataLoader.__iter__ and relies on attributes
    # and helpers provided by that class (randomize, before_iter, get_idxs, fake_l,
    # after_batch, after_iter, plus fastai's _loaders and to_device). As written, this
    # class does not define them itself, so it would need to inherit from
    # fastai.data.load.DataLoader (or supply equivalents) before it can be used.
    def __init__(self, fasta_files, device):
        self.fasta_files = list(fasta_files)
        self.device = device

    def __iter__(self):
        self.randomize()
        self.before_iter()
        self.__idxs = self.get_idxs()  # called in context of main process (not workers/subprocesses)
        for b in _loaders[self.fake_l.num_workers == 0](self.fake_l):
            if self.device is not None:
                b = to_device(b, self.device)
            yield self.after_batch(b)
        self.after_iter()
        if hasattr(self, "it"):
            del self.it

class SeqIODataloader:
    """
    Iterates over sequence files (fasta, genbank, tsv, or fastq), splitting each sequence
    into chunks of at most max_length and yielding padded batches of tensors.
    """
    def __init__(self, files, device, batch_size: int = 1, min_length: int = 128, max_length: int = 5_000, max_seqs: int = None, format: str = ""):
        self.files = list(files)
        self.device = device
        self.format = format
        self.chunk_details = []
        self.max_length = max_length
        self.batch_size = batch_size
        self.min_length = min_length
        self.pad = PadBatchX()
        self.count = 0
        self.max_seqs = max_seqs
        seqs = 0
        # Pre-count the chunks so that __len__ can report a total (e.g. for progress bars)
        for file in self.files:
            for record in self.parse(file):
                if len(record.seq) < self.min_length:
                    continue

                if self.max_seqs and seqs >= self.max_seqs:
                    break

                chunks = len(record.seq) // self.max_length + 1
                self.count += chunks
                seqs += 1

    def get_file_format(self, file):
        if self.format:
            return self.format

        file = Path(file)
        suffix = file.suffix.lower()

        if suffix in [".fa", ".fna", ".fasta"]:
            return "fasta"

        if suffix in [".genbank", ".gb", ".gbk"]:
            return "genbank"

        if suffix in [".tab", ".tsv"]:
            return "tsv"

        if suffix in [".fastq", ".fq"]:
            return "fastq"

        raise ValueError(f"Cannot determine file format of {file}.")

    def __len__(self):
        return self.count

    def parse(self, file):
        return SeqIO.parse(file, self.get_file_format(file))

    def iter_records(self):
        for file in self.files:
            for record in self.parse(file):
                yield file, record

    def __iter__(self):
        batch = []
        seqs = 0

        for file in self.files:
            for record in self.parse(file):
                if len(record.seq) < self.min_length:
                    continue

                if self.max_seqs and seqs >= self.max_seqs:
                    break

                seqs += 1
                t = dna_seq_to_tensor(record.seq)
                chunks = len(t) // self.max_length + 1

                for chunk_index, chunk in enumerate(t.chunk(chunks)):
                    self.chunk_details.append((file, record.id, chunk_index))
                    batch.append(chunk)
                    if len(batch) >= self.batch_size:
                        batch = self.pad(batch)
                        yield batch
                        batch = []

        if batch:
            batch = self.pad(batch)
            yield batch
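# A minimal usage sketch (hypothetical paths and model; each yielded batch has been
# padded by PadBatchX, and chunk_details records which file, sequence id and chunk
# index each row of a batch came from):
#
#   dl = SeqIODataloader(["contigs.fa", "reads.fastq"], device="cpu", batch_size=4)
#   for batch in dl:
#       predictions = model(batch)  # 'model' is an assumed callable taking a batch of TensorDNA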