Coverage for seqbank/seqbank.py: 100.00%
211 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-02 04:29 +0000
1from functools import cached_property
2import numpy as np
3from pathlib import Path
4from joblib import Parallel, delayed
5import plotly.express as px
6import plotly.graph_objs as go
7from Bio.Seq import Seq
8from Bio.SeqRecord import SeqRecord
9from Bio import SeqIO
10from attrs import define
11from rich.progress import track
12import pyfastx
13from rich.progress import Progress, TimeElapsedColumn, MofNCompleteColumn
14from datetime import datetime
15from speedict import Rdict, Options, DBCompressionType, AccessType
16import atexit
18from .transform import seq_to_bytes, bytes_to_str
19from .io import get_file_format, open_path, download_file, seq_count, TemporaryDirectory
20from .exceptions import SeqBankError
21from .utils import parse_filter, format_fig
@define(slots=False)
class SeqBank:
    """Key-value store of sequences backed by an ``Rdict`` database.

    Keys are ASCII-encoded accessions and values are the byte-encoded sequences
    produced by ``seq_to_bytes``. Keys starting with ``/seqbank/`` are reserved for
    internal bookkeeping (the record of URLs that have already been imported).

    Attributes:
        path (Path): Location of the database on disk.
        write (bool): Open the database read-write when True, read-only otherwise.
    """

    path: Path
    write: bool = False

    def __attrs_post_init__(self) -> None:
        """Normalizes ``path`` and verifies it exists when opened read-only.

        Raises:
            FileNotFoundError: If the SeqBank file is not found at the given path when write mode is disabled.
        """
        self.path = Path(self.path).expanduser()
        if not self.write and not self.path.exists():
            raise FileNotFoundError(f"Cannot find SeqBank file at path: {self.path}")

    def key(self, accession: str) -> bytes:
        """Generates the database key for a given accession.

        Args:
            accession (str): Accession string used to generate the key.

        Returns:
            bytes: The accession encoded as ASCII bytes.
        """
        return bytes(accession, "ascii")

    def key_url(self, url: str) -> bytes:
        """Generates the database key used to remember a given URL.

        Args:
            url (str): The URL for which the key is generated.

        Returns:
            bytes: The ASCII-encoded key, prefixed with '/seqbank/url/'.
        """
        return self.key("/seqbank/url/" + url)

    def close(self) -> None:
        """Closes the SeqBank database connection if it has been opened.

        This is registered with ``atexit`` so it is deliberately best-effort: it does nothing
        when the database was never opened and it ignores errors raised while closing.
        """
        db = getattr(self, "_db", None)
        if db is None:
            # ``file`` was never accessed, so there is nothing to close.
            return
        try:
            db.close()
        except Exception:
            # Runs during interpreter shutdown; a failure to close must not mask the real exit status.
            pass

    @cached_property
    def file(self) -> Rdict:
        """Opens and configures the Rdict database on first access.

        Configures options for the database, such as compression type, optimization, and maximum open files.
        Registers the close method to be executed upon program exit to ensure the database is closed.

        Returns:
            Rdict: The configured Rdict database object.
        """
        options = Options(raw_mode=True)
        options.set_compression_type(DBCompressionType.none())
        options.set_optimize_filters_for_hits(True)
        options.optimize_for_point_lookup(1024)
        options.set_max_open_files(500)

        access_type = AccessType.read_write() if self.write else AccessType.read_only()
        self._db = Rdict(path=str(self.path), options=options, access_type=access_type)

        atexit.register(self.close)
        return self._db

    def __len__(self) -> int:
        """Counts the keys stored in the SeqBank.

        Iterates over every key in the database (including internal '/seqbank/' keys) and counts them.

        Returns:
            int: The number of entries in the SeqBank.
        """
        return sum(1 for _ in track(self.file.keys()))

    def __getitem__(self, accession: str) -> bytes:
        """Retrieves the raw sequence data stored under a given accession.

        Args:
            accession (str): The accession key to look up in the SeqBank.

        Returns:
            bytes: The value stored in the database for the accession. Use ``numpy()`` or
                ``string()`` for a decoded representation.

        Raises:
            SeqBankError: If the accession cannot be read or an error occurs during retrieval.
        """
        try:
            return self.file[self.key(accession)]
        except Exception as err:
            # Chain the cause so the underlying lookup/encoding error stays visible in tracebacks.
            raise SeqBankError(f"Failed to read {accession} in SeqBank {self.path}:\n{err}") from err

    def items(self):
        """Yields all key-value pairs from the SeqBank.

        Yields:
            tuple: The key as stored in the database and the value viewed as an unsigned 8-bit NumPy array.
        """
        for k, v in self.file.items():
            yield k, np.frombuffer(v, dtype="u1")

    def __contains__(self, accession: str) -> bool:
        """Checks if a given accession exists in the SeqBank.

        Args:
            accession (str): The accession to check for existence.

        Returns:
            bool: True if the accession exists in the SeqBank, otherwise False. Any error raised
                while building the key or querying the database is reported as False.
        """
        try:
            return self.key(accession) in self.file
        except Exception:
            return False

    def delete(self, accession: str) -> None:
        """Deletes a sequence entry from the SeqBank by its accession.

        Does nothing if the accession is not present.

        Args:
            accession (str): The accession of the sequence to delete.

        Returns:
            None
        """
        key = self.key(accession)
        db = self.file
        if key in db:
            del db[key]

    def add(self, seq: str | Seq | SeqRecord | np.ndarray, accession: str) -> None:
        """Adds a sequence to the SeqBank under a given accession.

        A SeqRecord is reduced to its Seq, a Seq to its string, and a string is converted with
        ``seq_to_bytes``. Any other value (e.g. a NumPy array) is stored as given.

        Args:
            seq (str | Seq | SeqRecord | np.ndarray): The sequence to add to the SeqBank.
            accession (str): The accession key for the sequence to be stored under.

        Returns:
            None
        """
        key = self.key(accession)

        if isinstance(seq, SeqRecord):
            seq = seq.seq
        if isinstance(seq, Seq):
            seq = str(seq)
        if isinstance(seq, str):
            seq = seq_to_bytes(seq)

        self.file[key] = seq

    @staticmethod
    def _fastx_reader(path: Path | str, format: str):
        """Returns an un-indexed pyfastx reader for ``path`` matching ``format``.

        FASTQ files need ``pyfastx.Fastq``; everything else handled here is read with ``pyfastx.Fasta``.
        The first two fields of every yielded tuple are the name and the sequence.
        """
        reader = pyfastx.Fastq if format == "fastq" else pyfastx.Fasta
        return reader(str(path), build_index=False)

    def add_file(
        self,
        path: Path,
        format: str = "",
        progress=None,
        overall_task=None,
        filter: Path | list | set | None = None,
    ) -> None:
        """Adds sequences from a file to the SeqBank.

        FASTA and FASTQ files are read with pyfastx; other formats are parsed with BioPython's SeqIO.
        When a filter is given, only accessions contained in it are added.

        Args:
            path (Path): The path to the sequence file. A plain string is also accepted.
            format (str, optional): The format of the sequence file (e.g., "fasta", "fastq"). If not provided, it will be auto-detected.
            progress (Progress, optional): A rich progress bar to display the import progress. Defaults to None.
            overall_task (int | None, optional): An optional task ID for tracking the overall progress. Defaults to None.
            filter (Path | list | set | None, optional): A filter for selecting specific accessions. Defaults to None.

        Returns:
            None
        """
        # ``add_files`` documents its input as a list of strings, but ``.name`` below needs a Path.
        path = Path(path)
        filter = parse_filter(filter)
        format = format or get_file_format(path)
        progress = progress or Progress()

        # If fasta or fastq use pyfastx for speed
        if format in ("fasta", "fastq"):
            # First pass only counts records so the progress bar has a total.
            total = sum(1 for _ in self._fastx_reader(path, format))
            task = progress.add_task(f"[magenta]{path.name}", total=total)

            # ``*_`` absorbs the quality string that FASTQ records carry.
            for accession, seq, *_ in self._fastx_reader(path, format):
                if filter and accession not in filter:
                    continue
                self.add(seq, accession)
                progress.update(task, advance=1)
        else:
            total = seq_count(path)

            with open_path(path) as f:
                task = progress.add_task(f"[magenta]{path.name}", total=total)
                for record in SeqIO.parse(f, format):
                    if filter and record.id not in filter:
                        continue

                    self.add(record, record.id)
                    progress.update(task, advance=1)

        progress.update(task, visible=False)
        if overall_task is not None:
            progress.update(overall_task, advance=1)

        print("Added", path.name)

    def add_sequence_from_file(
        self,
        accession: str,
        path: Path,
        format: str = "",
    ) -> None:
        """Adds a single sequence from a file to the SeqBank under the given accession.

        The identifier inside the file is ignored. A file containing no sequences adds nothing.

        Args:
            accession (str): The accession key for the sequence to be stored under.
            path (Path): The path to the sequence file.
            format (str, optional): The format of the sequence file (e.g., "fasta", "fastq"). If not provided, it will be auto-detected.

        Raises:
            SeqBankError: If the file contains more than one sequence.

        Returns:
            None
        """
        format = format or get_file_format(path)

        def check_total(total: int, path: Path) -> None:
            if total > 1:
                raise SeqBankError(f"Multiple sequences found in {path}, found {total}")

        # If fasta or fastq use pyfastx for speed
        if format in ("fasta", "fastq"):
            total = sum(1 for _ in self._fastx_reader(path, format))
            check_total(total, path)

            for _, seq, *_ in self._fastx_reader(path, format):
                self.add(seq, accession)
        else:
            total = seq_count(path)
            check_total(total, path)

            with open_path(path) as f:
                for record in SeqIO.parse(f, format):
                    self.add(record, accession)

    def seen_url(self, url: str) -> bool:
        """Checks if a given URL has been recorded as processed in the SeqBank.

        Args:
            url (str): The URL to check.

        Returns:
            bool: True if the URL key exists in the SeqBank, otherwise False.
        """
        return self.key_url(url) in self.file

    def save_seen_url(self, url: str) -> None:
        """Saves a URL as 'seen' by storing the current local timestamp under its URL key.

        Args:
            url (str): The URL to save as seen.

        Returns:
            None
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.file[self.key_url(url)] = bytes(timestamp, "ascii")

    def add_url(
        self,
        url: str,
        progress=None,
        format: str = "",
        force: bool = False,
        overall_task=None,
        tmp_dir: str | Path | None = None,
    ) -> bool:
        """Downloads and adds sequences from a URL to the SeqBank.

        The file is downloaded into a temporary directory, imported with ``add_file`` and the URL is
        then recorded as seen. A URL that was already seen is skipped unless `force=True`.

        Args:
            url (str): The URL to download the sequence file from.
            progress (Progress, optional): A rich progress bar to display progress of the download and file processing. Defaults to None.
            format (str, optional): The format of the sequence file (e.g., "fasta", "fastq"). If not provided, it will be auto-detected. Defaults to "".
            force (bool, optional): Whether to force downloading and processing the URL even if it has been seen before. Defaults to False.
            overall_task (int | None, optional): An optional task ID for tracking overall progress. Defaults to None.
            tmp_dir (str | Path | None, optional): Passed as the prefix of the temporary directory used for the download. Defaults to None.

        Returns:
            bool: True if the URL was successfully processed and added, False if it was skipped or failed.
        """
        if not force and self.seen_url(url):
            return False

        with TemporaryDirectory(prefix=tmp_dir) as tmpdirname:
            local_path = tmpdirname / Path(url).name
            try:
                download_file(url, local_path)
                self.add_file(local_path, format=format, progress=progress, overall_task=overall_task)
                self.save_seen_url(url)
            except Exception as err:
                # One failing URL is reported and skipped so a batch import can continue.
                print(f"Failed to add URL: {url}: {err}")
                return False

        return True

    def get_accessions(self) -> set[str]:
        """Retrieves all accessions stored in the SeqBank.

        Keys in the internal '/seqbank/' namespace are excluded.

        Returns:
            set[str]: A set of all accessions present in the SeqBank.
        """
        decoded = (key.decode("ascii") for key in self.file.keys())
        return {accession for accession in decoded if not accession.startswith("/seqbank/")}

    def missing(self, accessions: list[str] | set[str]) -> set[str]:
        """Finds accessions that are not present in the SeqBank.

        Args:
            accessions (list[str] | set[str]): A list or set of accessions to check for presence in the SeqBank.

        Returns:
            set[str]: A set of accessions that are missing from the SeqBank.
        """
        return {accession for accession in track(accessions) if accession not in self}

    def add_urls(
        self,
        urls: list[str],
        max: int = 0,
        format: str = "",
        force: bool = False,
        workers: int = -1,
        tmp_dir: str | Path | None = None,
    ) -> None:
        """Downloads and adds sequences from a list of URLs to the SeqBank.

        URLs that have already been processed are skipped unless `force=True` is specified, and the
        number of URLs processed can be limited with `max`. Downloads run on a thread pool.

        Args:
            urls (list[str]): A list of URLs to download and process.
            max (int, optional): Maximum number of URLs to process. If set to 0, all URLs will be processed. Defaults to 0.
            format (str, optional): The format of the sequence files (e.g., "fasta", "fastq"). If not provided, it will be auto-detected. Defaults to "".
            force (bool, optional): Whether to force re-processing of URLs even if they were processed before. Defaults to False.
            workers (int, optional): Number of joblib workers (``n_jobs``) to use. Defaults to -1.
            tmp_dir (str | Path | None, optional): A temporary directory to store downloaded files. Defaults to None.

        Returns:
            None
        """
        urls_to_add = []
        for url in urls:
            # Previously seen URLs are kept only when the caller explicitly forces a re-import.
            if force or not self.seen_url(url):
                urls_to_add.append(url)

            # truncate URLs list to `max` if requested
            if max and len(urls_to_add) >= max:
                break

        with Progress(*Progress.get_default_columns(), TimeElapsedColumn(), MofNCompleteColumn()) as progress:
            parallel = Parallel(n_jobs=workers, prefer="threads")
            add_url = delayed(self.add_url)
            overall_task = progress.add_task("[bold red]Adding URLs", total=len(urls_to_add))
            parallel(
                add_url(url, progress=progress, format=format, force=force, overall_task=overall_task, tmp_dir=tmp_dir)
                for url in urls_to_add
            )

    def ls(self) -> None:
        """Prints every key in the SeqBank, decoded from bytes to ASCII, one per line.

        Returns:
            None
        """
        for k in self.file.keys():
            print(k.decode("ascii"))

    def add_files(
        self,
        files: list[str],
        format: str = "",
        workers: int = 1,
        filter: Path | list[str] | set[str] | None = None,
    ) -> None:
        """Adds sequences from multiple files to the SeqBank.

        Each file is imported with ``add_file`` on a thread pool, sharing one progress display.

        Args:
            files (list[str]): A list of file paths to process.
            format (str, optional): The format of the sequence files (e.g., "fasta", "fastq"). If not provided, it will be auto-detected. Defaults to "".
            workers (int, optional): Number of workers to use for parallel processing. Defaults to 1.
            filter (Path | list[str] | set[str] | None, optional): A filter for selecting specific accessions. Defaults to None.

        Returns:
            None
        """
        # Parse once here so a filter file is not re-read for every input file.
        filter = parse_filter(filter)

        with Progress(*Progress.get_default_columns(), TimeElapsedColumn(), MofNCompleteColumn()) as progress:
            parallel = Parallel(n_jobs=workers, prefer="threads")
            add_file = delayed(self.add_file)
            overall_task = progress.add_task("[bold red]Adding files", total=len(files))
            parallel(
                add_file(file, progress=progress, format=format, overall_task=overall_task, filter=filter)
                for file in files
            )

    def copy(self, other: "SeqBank") -> None:
        """Copies all entries from the current SeqBank to another SeqBank instance.

        Every key-value pair (including internal keys) is written to ``other``, which must be writable.

        Args:
            other (SeqBank): The target SeqBank instance where entries will be copied.

        Returns:
            None
        """
        target = other.file
        for k, v in track(self.file.items()):
            target[k] = v

    def numpy(self, accession: str) -> np.ndarray:
        """Retrieves the sequence data for a given accession as an unsigned char NumPy array.

        Args:
            accession (str): The accession key for which the sequence data is retrieved.

        Returns:
            np.ndarray: The stored bytes viewed with dtype "u1".
        """
        return np.frombuffer(self[accession], dtype="u1")

    def string(self, accession: str) -> str:
        """Retrieves the sequence data for a given accession and returns it as a string.

        Args:
            accession (str): The accession key for which the sequence data is retrieved.

        Returns:
            str: The stored data decoded with ``bytes_to_str``.
        """
        return bytes_to_str(self[accession])

    def record(self, accession: str) -> SeqRecord:
        """Retrieves the sequence data for a given accession as a BioPython SeqRecord object.

        Args:
            accession (str): The accession key for which the sequence data is retrieved.

        Returns:
            SeqRecord: A SeqRecord containing the sequence, with the given accession as its ID and an empty description.
        """
        return SeqRecord(
            Seq(self.string(accession)),
            id=accession,
            description="",
        )

    def export(self, output: Path | str, format: str = "", accessions: list[str] | str | Path | None = None) -> None:
        """Exports the data from the SeqBank to a file using BioPython's SeqIO.

        Args:
            output (Path | str): The path or filename where the data should be exported.
            format (str, optional): The file format for exporting. If not specified, it will be inferred from the file extension.
            accessions (list[str] | str | Path | None, optional): The accessions to export. A string or Path is treated as
                a file listing one accession per line. If None or empty, all accessions in the SeqBank are exported.

        Returns:
            None
        """
        accessions = accessions or self.get_accessions()

        # Read list of accessions if given file
        if isinstance(accessions, (str, Path)):
            lines = Path(accessions).read_text().splitlines()
            # Skip blank lines and stray whitespace (e.g. CRLF endings) so they are not looked up as accessions.
            accessions = [line.strip() for line in lines if line.strip()]

        format = format or get_file_format(output)
        with open(output, "w") as f:
            for accession in accessions:
                SeqIO.write(self.record(accession), f, format)

    def lengths_dict(self) -> dict[str, int]:
        """Maps every accession to the length of its stored sequence.

        Returns:
            dict[str, int]: A dictionary mapping each accession to the length of its corresponding sequence.
        """
        return {accession: len(self[accession]) for accession in self.get_accessions()}

    def histogram(self, nbins: int = 30, max: int = 0, min: int = 0) -> go.Figure:
        """Creates a histogram of the lengths of all sequences and returns the Plotly figure object.

        Args:
            nbins (int): The number of bins for the histogram. Default is 30.
            max (int): The maximum length of the sequence to include in the histogram. 0 (the default) disables the limit.
            min (int): The minimum length of the sequence to include in the histogram. Default is 0.

        Returns:
            go.Figure: A Plotly figure object representing the histogram of sequence lengths.
        """
        lengths = np.array(list(self.lengths_dict().values()))

        # ``max``/``min`` shadow the builtins here; the names are part of the public signature.
        if max:
            lengths = lengths[lengths <= max]
        if min:
            lengths = lengths[lengths >= min]

        fig = px.histogram(lengths, nbins=nbins, title="Histogram of Sequence Lengths")
        fig.update_layout(xaxis_title="Sequence Length", yaxis_title="Count", showlegend=False)

        format_fig(fig)

        return fig