Coverage for seqbank/seqbank.py: 100.00%

211 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-02 04:29 +0000

1from functools import cached_property 

2import numpy as np 

3from pathlib import Path 

4from joblib import Parallel, delayed 

5import plotly.express as px 

6import plotly.graph_objs as go 

7from Bio.Seq import Seq 

8from Bio.SeqRecord import SeqRecord 

9from Bio import SeqIO 

10from attrs import define 

11from rich.progress import track 

12import pyfastx 

13from rich.progress import Progress, TimeElapsedColumn, MofNCompleteColumn 

14from datetime import datetime 

15from speedict import Rdict, Options, DBCompressionType, AccessType 

16import atexit 

17 

18from .transform import seq_to_bytes, bytes_to_str 

19from .io import get_file_format, open_path, download_file, seq_count, TemporaryDirectory 

20from .exceptions import SeqBankError 

21from .utils import parse_filter, format_fig 

22 

23 

@define(slots=False)
class SeqBank:
    """Key-value store of biological sequences backed by an ``Rdict`` database.

    Keys are ASCII-encoded accessions and values are the byte representation of
    the sequence as produced by ``seq_to_bytes``. Keys that start with
    ``/seqbank/`` are reserved for internal bookkeeping (e.g. URLs that have
    already been imported).

    Attributes:
        path (Path): Location of the database on disk.
        write (bool): Open the database read-write instead of read-only.
    """

    path: Path
    write: bool = False

    def __attrs_post_init__(self) -> None:
        """Normalises ``path`` and validates that it exists when opened read-only.

        Raises:
            FileNotFoundError: If the SeqBank is not in write mode and nothing exists at ``path``.
        """
        self.path = Path(self.path).expanduser()
        if not self.write and not self.path.exists():
            raise FileNotFoundError(f"Cannot find SeqBank file at path: {self.path}")

    def key(self, accession: str) -> bytes:
        """Generates the database key for a given accession.

        Args:
            accession (str): Accession string used to generate the key.

        Returns:
            bytes: The accession encoded as ASCII bytes.
        """
        return bytes(accession, "ascii")

    def key_url(self, url: str) -> bytes:
        """Generates the database key used to remember a URL.

        Args:
            url (str): The URL for which the key is generated.

        Returns:
            bytes: The ASCII-encoded key, prefixed with '/seqbank/url/'.
        """
        return self.key("/seqbank/url/" + url)

    def close(self) -> None:
        """Closes the SeqBank database connection.

        This is deliberately best-effort: it is registered with ``atexit`` and may run
        when the database was never opened or has already been closed, so any
        exception is ignored.
        """
        try:
            self._db.close()
        except Exception:
            pass

    @cached_property
    def file(self) -> Rdict:
        """Opens and configures the Rdict database for sequence storage.

        The database is opened lazily on first access, in raw mode, without compression
        and tuned for point lookups. ``close`` is registered with ``atexit`` so the
        database is closed when the program exits.

        Returns:
            Rdict: The configured Rdict database object.
        """
        options = Options(raw_mode=True)
        options.set_compression_type(DBCompressionType.none())
        options.set_optimize_filters_for_hits(True)
        options.optimize_for_point_lookup(1024)
        options.set_max_open_files(500)

        self._db = Rdict(
            path=str(self.path),
            options=options,
            access_type=AccessType.read_write() if self.write else AccessType.read_only(),
        )

        atexit.register(self.close)
        return self._db

    def __len__(self) -> int:
        """Counts the entries in the SeqBank.

        Every key in the database is counted, including internal '/seqbank/' keys.
        A progress bar is displayed while iterating.

        Returns:
            int: The number of entries in the SeqBank.
        """
        count = 0
        for _ in track(self.file.keys()):
            count += 1
        return count

    def __getitem__(self, accession: str) -> bytes:
        """Retrieves the raw stored value for a given accession.

        Use ``numpy``, ``string`` or ``record`` for decoded views of the data.

        Args:
            accession (str): The accession key to look up in the SeqBank.

        Returns:
            bytes: The value stored in the database for the given accession.

        Raises:
            SeqBankError: If the accession cannot be read or an error occurs during retrieval.
        """
        try:
            return self.file[self.key(accession)]
        except Exception as err:
            # Chain the original exception so the root cause stays visible in tracebacks.
            raise SeqBankError(f"Failed to read {accession} in SeqBank {self.path}:\n{err}") from err

    def items(self):
        """Yields all key-value pairs from the SeqBank.

        Yields:
            tuple: The key exactly as stored in the database and the value viewed as an
                unsigned 8-bit NumPy array.
        """
        for k, v in self.file.items():
            yield k, np.frombuffer(v, dtype="u1")

    def __contains__(self, accession: str) -> bool:
        """Checks if a given accession exists in the SeqBank.

        Args:
            accession (str): The accession to check for existence.

        Returns:
            bool: True if the accession exists in the SeqBank, otherwise False
                (including when the lookup itself fails).
        """
        try:
            return self.key(accession) in self.file
        except Exception:
            return False

    def delete(self, accession: str) -> None:
        """Deletes a sequence entry from the SeqBank by its accession.

        Does nothing if the accession is not present.

        Args:
            accession (str): The accession of the sequence to delete.

        Returns:
            None
        """
        key = self.key(accession)
        db = self.file
        if key in db:
            del db[key]

    def add(self, seq: str | Seq | SeqRecord | np.ndarray, accession: str) -> None:
        """Adds a sequence to the SeqBank with a given accession.

        The sequence is normalised step by step (SeqRecord -> Seq -> str -> bytes) and
        stored under the accession's key. Values that are already bytes are stored as-is.

        Args:
            seq (str | Seq | SeqRecord | np.ndarray): The sequence to add to the SeqBank.
                NumPy arrays are expected to already hold the encoded bytes (dtype "u1").
            accession (str): The accession key for the sequence to be stored under.

        Returns:
            None
        """
        key = self.key(accession)

        if isinstance(seq, SeqRecord):
            seq = seq.seq
        if isinstance(seq, Seq):
            seq = str(seq)
        if isinstance(seq, str):
            seq = seq_to_bytes(seq)
        if isinstance(seq, np.ndarray):
            # Serialise the array's raw buffer so arrays are stored as bytes
            # like every other input type.
            seq = seq.tobytes()

        self.file[key] = seq

    @staticmethod
    def _pyfastx_reader(path: Path | str, format: str):
        """Returns an index-free pyfastx reader matching ``format`` ("fasta" or "fastq")."""
        reader_cls = pyfastx.Fastq if format == "fastq" else pyfastx.Fasta
        return reader_cls(str(path), build_index=False)

    def _count_records(self, path: Path | str, format: str) -> int:
        """Counts the sequences in ``path`` (pyfastx for FASTA/FASTQ, ``seq_count`` otherwise)."""
        if format in ("fasta", "fastq"):
            return sum(1 for _ in self._pyfastx_reader(path, format))
        return seq_count(path)

    def _iter_records(self, path: Path | str, format: str):
        """Yields ``(accession, sequence)`` pairs from ``path`` in a form accepted by ``add``."""
        if format in ("fasta", "fastq"):
            # FASTA readers yield (name, seq); FASTQ readers also yield the quality
            # string, which is not stored.
            for accession, seq, *_ in self._pyfastx_reader(path, format):
                yield accession, seq
        else:
            with open_path(path) as f:
                for record in SeqIO.parse(f, format):
                    yield record.id, record

    def add_file(
        self,
        path: Path,
        format: str = "",
        progress=None,
        overall_task=None,
        filter: Path | list | set | None = None,
    ) -> None:
        """Adds sequences from a file to the SeqBank.

        FASTA and FASTQ files are read with pyfastx for speed; any other format is parsed
        with BioPython's SeqIO. The file is read once to count records (for the progress
        bar) and once to import them.

        Args:
            path (Path): The path to the sequence file. Strings are accepted and converted.
            format (str, optional): The format of the sequence file (e.g., "fasta", "fastq"). If not provided, it will be auto-detected.
            progress (Progress, optional): A rich progress bar to display the import progress. Defaults to None.
            overall_task (int | None, optional): An optional task ID for tracking the overall progress. Defaults to None.
            filter (Path | list | set | None, optional): A filter for selecting specific accessions. Defaults to None.

        Returns:
            None
        """
        # ``add_files`` documents ``files`` as a list of strings, so normalise here
        # before using ``path.name``.
        path = Path(path)
        filter = parse_filter(filter)
        format = format or get_file_format(path)
        progress = progress or Progress()

        total = self._count_records(path, format)
        task = progress.add_task(f"[magenta]{path.name}", total=total)

        for accession, seq in self._iter_records(path, format):
            # Advance for every record examined so that the bar reaches its total
            # even when the filter skips entries.
            progress.update(task, advance=1)
            if filter and accession not in filter:
                continue
            self.add(seq, accession)

        progress.update(task, visible=False)
        if overall_task is not None:
            progress.update(overall_task, advance=1)

        print("Added", path.name)

    def add_sequence_from_file(
        self,
        accession: str,
        path: Path,
        format: str = "",
    ) -> None:
        """
        Adds a single sequence from a file to the SeqBank.

        The file must not contain more than one sequence; the sequence is stored under
        the accession given as an argument rather than the identifier in the file.

        Args:
            accession (str): The accession key for the sequence to be stored under.
            path (Path): The path to the sequence file.
            format (str, optional): The format of the sequence file (e.g., "fasta", "fastq"). If not provided, it will be auto-detected.

        Returns:
            None

        Raises:
            SeqBankError: If the file contains more than one sequence.
        """
        format = format or get_file_format(path)

        total = self._count_records(path, format)
        if total > 1:
            raise SeqBankError(f"Multiple sequences found in {path}, found {total}")

        for _, seq in self._iter_records(path, format):
            self.add(seq, accession)

    def seen_url(self, url: str) -> bool:
        """Checks if a given URL has been recorded as processed in the SeqBank.

        Args:
            url (str): The URL to check.

        Returns:
            bool: True if the URL has been seen (i.e., its key exists in the SeqBank), otherwise False.
        """
        return self.key_url(url) in self.file

    def save_seen_url(self, url: str) -> None:
        """Saves a URL as 'seen' by storing the current local timestamp under its key.

        Args:
            url (str): The URL to save as seen.

        Returns:
            None
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.file[self.key_url(url)] = bytes(timestamp, "ascii")

    def add_url(
        self,
        url: str,
        progress=None,
        format: str = "",
        force: bool = False,
        overall_task=None,
        tmp_dir: str | Path | None = None,
    ) -> bool:
        """Downloads and adds sequences from a URL to the SeqBank.

        The file is downloaded into a temporary directory, imported with ``add_file`` and
        the URL is then recorded as seen. URLs that were already processed are skipped
        unless `force=True` is provided. Failures are reported on stdout rather than raised.

        Args:
            url (str): The URL to download the sequence file from.
            progress (Progress, optional): A rich progress bar to display progress of the download and file processing. Defaults to None.
            format (str, optional): The format of the sequence file (e.g., "fasta", "fastq"). If not provided, it will be auto-detected. Defaults to "".
            force (bool, optional): Whether to force downloading and processing the URL even if it has been seen before. Defaults to False.
            overall_task (int | None, optional): An optional task ID for tracking overall progress. Defaults to None.
            tmp_dir (str | Path | None, optional): A temporary directory to store the downloaded file. Defaults to None.

        Returns:
            bool: True if the URL was successfully processed and added, False otherwise.
        """
        if not force and self.key_url(url) in self.file:
            return False

        with TemporaryDirectory(prefix=tmp_dir) as tmpdirname:
            local_path = tmpdirname / Path(url).name
            try:
                download_file(url, local_path)
                self.add_file(local_path, format=format, progress=progress, overall_task=overall_task)
                self.save_seen_url(url)
            except Exception as err:
                print(f"Failed to add URL: {url}: {err}")
                return False

        return True

    def get_accessions(self) -> set[str]:
        """Retrieves all accessions stored in the SeqBank.

        Keys in the internal '/seqbank/' namespace are excluded.

        Returns:
            set[str]: A set of all accessions present in the SeqBank.
        """
        accessions = set()

        for key in self.file.keys():
            accession = key.decode("ascii")
            if not accession.startswith("/seqbank/"):
                accessions.add(accession)

        return accessions

    def missing(self, accessions: list[str] | set[str]) -> set[str]:
        """Finds accessions that are not present in the SeqBank.

        A progress bar is displayed while the accessions are checked.

        Args:
            accessions (list[str] | set[str]): A list or set of accessions to check for presence in the SeqBank.

        Returns:
            set[str]: A set of accessions that are missing from the SeqBank.
        """
        return {accession for accession in track(accessions) if accession not in self}

    def add_urls(
        self,
        urls: list[str],
        max: int = 0,
        format: str = "",
        force: bool = False,
        workers: int = -1,
        tmp_dir: str | Path | None = None,
    ) -> None:
        """Downloads and adds sequences from a list of URLs to the SeqBank.

        URLs that have already been processed are skipped unless `force=True` is specified.
        The number of URLs processed can be limited with `max`, and the downloads are run
        in parallel threads controlled by `workers`.

        Args:
            urls (list[str]): A list of URLs to download and process.
            max (int, optional): Maximum number of URLs to process. If set to 0, all URLs will be processed. Defaults to 0.
            format (str, optional): The format of the sequence files (e.g., "fasta", "fastq"). If not provided, it will be auto-detected. Defaults to "".
            force (bool, optional): Whether to force re-processing of URLs even if they were processed before. Defaults to False.
            workers (int, optional): Number of workers to use for parallel processing. If set to -1, all available CPU cores will be used. Defaults to -1.
            tmp_dir (str | Path | None, optional): A temporary directory to store downloaded files. Defaults to None.

        Returns:
            None
        """
        urls_to_add = []
        for url in urls:
            # Honour ``force``: without this check previously-seen URLs were dropped
            # here and could never be re-processed.
            if force or not self.seen_url(url):
                urls_to_add.append(url)

            # truncate URLs list to `max` if requested
            if max and len(urls_to_add) >= max:
                break

        with Progress(*Progress.get_default_columns(), TimeElapsedColumn(), MofNCompleteColumn()) as progress:
            parallel = Parallel(n_jobs=workers, prefer="threads")
            add_url = delayed(self.add_url)
            overall_task = progress.add_task("[bold red]Adding URLs", total=len(urls_to_add))
            parallel(
                add_url(url, progress=progress, format=format, force=force, overall_task=overall_task, tmp_dir=tmp_dir)
                for url in urls_to_add
            )

    def ls(self) -> None:
        """
        Lists all keys in the SeqBank.

        Iterates through the keys in the SeqBank and prints each one, decoded from bytes to ASCII.

        Returns:
            None
        """
        for k in self.file.keys():
            print(k.decode("ascii"))

    def add_files(
        self,
        files: list[str],
        format: str = "",
        workers: int = 1,
        filter: Path | list[str] | set[str] | None = None,
    ) -> None:
        """
        Adds sequences from multiple files to the SeqBank.

        Each file is imported with ``add_file``, optionally in parallel threads and with
        an optional filter restricting which accessions are stored.

        Args:
            files (list[str]): A list of file paths to process.
            format (str, optional): The format of the sequence files (e.g., "fasta", "fastq"). If not provided, it will be auto-detected. Defaults to "".
            workers (int, optional): Number of workers to use for parallel processing. Defaults to 1.
            filter (Path | list[str] | set[str] | None, optional): A filter for selecting specific accessions. Defaults to None.

        Returns:
            None
        """
        # Parse the filter once here rather than once per file.
        filter = parse_filter(filter)

        with Progress(*Progress.get_default_columns(), TimeElapsedColumn(), MofNCompleteColumn()) as progress:
            parallel = Parallel(n_jobs=workers, prefer="threads")
            add_file = delayed(self.add_file)
            overall_task = progress.add_task("[bold red]Adding files", total=len(files))
            parallel(
                add_file(file, progress=progress, format=format, overall_task=overall_task, filter=filter)
                for file in files
            )

    def copy(self, other: "SeqBank") -> None:
        """
        Copies all entries from the current SeqBank to another SeqBank instance.

        Every key-value pair is copied verbatim, including internal '/seqbank/' keys.
        The `other` SeqBank must be writable.

        Args:
            other (SeqBank): The target SeqBank instance where entries will be copied.

        Returns:
            None
        """
        for k, v in track(self.file.items()):
            other.file[k] = v

    def numpy(self, accession: str) -> np.ndarray:
        """
        Retrieves the sequence data for a given accession as an unsigned char NumPy array.

        Args:
            accession (str): The accession key for which the sequence data is retrieved.

        Returns:
            np.ndarray: The stored bytes for the accession viewed as an array of dtype "u1".
        """
        return np.frombuffer(self[accession], dtype="u1")

    def string(self, accession: str) -> str:
        """
        Retrieves the sequence data for a given accession and returns it as a string.

        Args:
            accession (str): The accession key for which the sequence data is retrieved.

        Returns:
            str: The stored bytes decoded with ``bytes_to_str``.
        """
        return bytes_to_str(self[accession])

    def record(self, accession: str) -> SeqRecord:
        """
        Retrieves the sequence data for a given accession as a BioPython SeqRecord object.

        Args:
            accession (str): The accession key for which the sequence data is retrieved.

        Returns:
            SeqRecord: A SeqRecord containing the sequence, with the given accession as its ID and an empty description.
        """
        return SeqRecord(
            Seq(self.string(accession)),
            id=accession,
            description="",
        )

    def export(self, output: Path | str, format: str = "", accessions: list[str] | str | Path | None = None) -> None:
        """
        Exports the data from the SeqBank to a file using BioPython's SeqIO.

        Args:
            output (Path | str): The path or filename where the data should be exported.
            format (str, optional): The file format for exporting. If not specified, it will be inferred from the file extension.
            accessions (list[str] | str | Path | None, optional): A list of accessions to export. If a file path or string is provided, it is read as a text file with one accession per line. If None or empty, all accessions in the SeqBank are exported.

        Returns:
            None
        """
        accessions = accessions or self.get_accessions()

        # Read list of accessions if given file
        if isinstance(accessions, (str, Path)):
            lines = Path(accessions).read_text().splitlines()
            # Ignore blank lines and surrounding whitespace so they do not trigger
            # failed lookups.
            accessions = [line.strip() for line in lines if line.strip()]

        format = format or get_file_format(output)
        with open(output, "w") as f:
            # Write all records in a single call so SeqIO sees the export as one file.
            SeqIO.write((self.record(accession) for accession in accessions), f, format)

    def lengths_dict(self) -> dict[str, int]:
        """
        Returns a dictionary mapping each accession to the length of its sequence.

        Returns:
            dict[str, int]: A dictionary mapping each accession to the length of its stored value.
        """
        return {accession: len(self[accession]) for accession in self.get_accessions()}

    def histogram(self, nbins: int = 30, max: int = 0, min: int = 0) -> go.Figure:
        """
        Creates a histogram of the lengths of all sequences and returns the Plotly figure object.

        Args:
            nbins (int): The number of bins for the histogram. Default is 30.
            max (int): The maximum length of the sequence to include in the histogram. 0 (the default) disables the upper bound.
            min (int): The minimum length of the sequence to include in the histogram. Default is 0.

        Returns:
            go.Figure: A Plotly figure object representing the histogram of sequence lengths.
        """
        lengths = np.array(list(self.lengths_dict().values()))

        if max:
            lengths = lengths[lengths <= max]
        if min:
            lengths = lengths[lengths >= min]

        fig = px.histogram(lengths, nbins=nbins, title="Histogram of Sequence Lengths")

        # There is a single series, so the legend carries no information.
        fig.update_layout(xaxis_title="Sequence Length", yaxis_title="Count", showlegend=False)

        format_fig(fig)

        return fig