"""
Tools for accessing POD5 data from PyArrow files
"""
import enum
import mmap
from collections import namedtuple
from dataclasses import fields
from io import IOBase
from functools import total_ordering
from pathlib import Path
from typing import (
Collection,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
from uuid import UUID
import lib_pod5 as p5b
import numpy as np
import numpy.typing as npt
import packaging.version
import pyarrow as pa
from pod5.pod5_types import (
Calibration,
EndReason,
EndReasonEnum,
PathOrStr,
Pore,
Read,
RunInfo,
ShiftScalePair,
)
from .api_utils import Pod5ApiException, format_read_ids, pack_read_ids, safe_close
from .signal_tools import vbz_decompress_signal, vbz_decompress_signal_into
ReadRecordV3Columns = namedtuple(
"ReadRecordV3Columns",
[
"read_id",
"read_number",
"start",
"channel",
"well",
"median_before",
"pore_type",
"calibration_offset",
"calibration_scale",
"end_reason",
"end_reason_forced",
"run_info",
"signal",
"num_minknow_events",
"tracked_scaling_scale",
"tracked_scaling_shift",
"predicted_scaling_scale",
"predicted_scaling_shift",
"num_reads_since_mux_change",
"time_since_mux_change",
"num_samples",
],
)
[docs]@total_ordering
class ReadTableVersion(enum.Enum):
"""Version of read table"""
V3: int = 3
def __lt__(self, other) -> bool:
if self.__class__ is other.__class__:
return self.value < other.value
return NotImplemented
def __eq__(self, other) -> bool:
if self.__class__ is other.__class__:
return self.value == other.value
return NotImplemented
Signal = namedtuple("Signal", ["signal", "samples"])
SignalRowInfo = namedtuple(
"SignalRowInfo",
["batch_index", "batch_row_index", "sample_count", "byte_count"],
)
[docs]class ReadRecord:
"""
Represents the data for a single read from a pod5 record.
"""
[docs] def __init__(
self,
reader: "Reader",
batch: "ReadRecordBatch",
row: int,
batch_signal_cache: Optional[List[npt.NDArray[np.int16]]] = None,
selected_batch_index: Optional[int] = None,
):
""" """
self._reader = reader
self._batch = batch
self._row = row
self._batch_signal_cache = batch_signal_cache
self._selected_batch_index = selected_batch_index
@property
def read_id(self) -> UUID:
"""
Get the unique read identifier for the read as a `UUID`.
"""
return UUID(bytes=self._batch.columns.read_id[self._row].as_py())
@property
def read_number(self) -> int:
"""
Get the integer read number of the read.
"""
return self._batch.columns.read_number[self._row].as_py() # type: ignore
@property
def start_sample(self) -> int:
"""
Get the absolute sample which the read started.
"""
return self._batch.columns.start[self._row].as_py() # type: ignore
@property
def num_samples(self) -> int:
"""
Get the number of samples in the reads signal data.
"""
return self._batch.columns.num_samples[self._row].as_py() # type: ignore
@property
def median_before(self) -> float:
"""
Get the median before level (in pico amps) for the read.
"""
return self._batch.columns.median_before[self._row].as_py() # type: ignore
@property
def num_minknow_events(self) -> float:
"""
Find the number of minknow events in the read.
"""
return self._batch.columns.num_minknow_events[self._row].as_py() # type: ignore
@property
def tracked_scaling(self) -> ShiftScalePair:
"""
Find the tracked scaling value in the read.
"""
return ShiftScalePair(
self._batch.columns.tracked_scaling_shift[self._row].as_py(),
self._batch.columns.tracked_scaling_scale[self._row].as_py(),
)
@property
def predicted_scaling(self) -> ShiftScalePair:
"""
Find the predicted scaling value in the read.
"""
return ShiftScalePair(
self._batch.columns.predicted_scaling_shift[self._row].as_py(),
self._batch.columns.predicted_scaling_scale[self._row].as_py(),
)
@property
def num_reads_since_mux_change(self) -> int:
"""
Number of selected reads since the last mux change on this reads channel.
"""
return self._batch.columns.num_reads_since_mux_change[self._row].as_py() # type: ignore
@property
def time_since_mux_change(self) -> int:
"""
Time in seconds since the last mux change on this reads channel.
"""
return self._batch.columns.time_since_mux_change[self._row].as_py() # type: ignore
@property
def pore(self) -> Pore:
"""
Get the pore data associated with the read.
"""
return Pore(
self._batch.columns.channel[self._row].as_py(),
self._batch.columns.well[self._row].as_py(),
self._batch.columns.pore_type[self._row].as_py(),
)
@property
def calibration(self) -> Calibration:
"""
Get the calibration data associated with the read.
"""
return Calibration(
self._batch.columns.calibration_offset[self._row].as_py(),
self._batch.columns.calibration_scale[self._row].as_py(),
)
@property
def calibration_digitisation(self) -> int:
"""
Get the digitisation value used by the sequencer.
Intended to assist workflows ported from legacy file formats.
"""
return self.run_info.adc_max - self.run_info.adc_min + 1
@property
def calibration_range(self) -> float:
"""
Get the calibration range value.
Intended to assist workflows ported from legacy file formats.
"""
return self.calibration.scale * self.calibration_digitisation
@property
def end_reason(self) -> EndReason:
"""
Get the end reason data associated with the read.
"""
return EndReason(
reason=EndReasonEnum[
self._batch.columns.end_reason[self._row].as_py().upper()
],
forced=self._batch.columns.end_reason_forced[self._row].as_py(),
)
@property
def run_info(self) -> RunInfo:
"""
Get the run info data associated with the read.
"""
return self._reader._lookup_run_info(self._batch, self._row)
@property
def end_reason_index(self) -> int:
"""
Get the dictionary index of the end reason data associated with the read.
This property is the same as the EndReason enumeration value.
"""
return self._batch.columns.end_reason[self._row].index.as_py() # type: ignore
@property
def run_info_index(self) -> int:
"""
Get the dictionary index of the run info data associated with the read.
"""
return self._batch.columns.run_info[self._row].index.as_py() # type: ignore
@property
def sample_count(self) -> int:
"""
Get the number of samples in the reads signal data.
"""
return self.num_samples
@property
def byte_count(self) -> int:
"""
Get the number of bytes used to store the reads data.
"""
return sum(r.byte_count for r in self.signal_rows)
@property
def has_cached_signal(self) -> bool:
"""
Get if cached signal is available for this read.
"""
return self._batch_signal_cache is not None
@property
def signal(self) -> npt.NDArray[np.int16]:
"""
Get the full signal for the read.
Returns
-------
numpy.ndarray[int16]
A numpy array of signal data with int16 type.
"""
if self._batch_signal_cache is not None:
if self._selected_batch_index is not None:
return self._batch_signal_cache[self._selected_batch_index]
return self._batch_signal_cache[self._row]
rows = self._batch.columns.signal[self._row]
batch_data = [self._find_signal_row_index(r.as_py()) for r in rows]
sample_counts = []
for batch, _, batch_row_index in batch_data:
sample_counts.append(batch.samples[batch_row_index].as_py())
output = np.empty(dtype=np.int16, shape=(sum(sample_counts),))
current_sample_index = 0
for i, (batch, _, batch_row_index) in enumerate(batch_data):
signal = batch.signal
current_row_count = sample_counts[i]
output_slice = output[
current_sample_index : current_sample_index + current_row_count
]
if self._reader.is_vbz_compressed:
vbz_decompress_signal_into(
memoryview(signal[batch_row_index].as_buffer()), output_slice
)
else:
output_slice[:] = signal.to_numpy()
current_sample_index += current_row_count
return output
@property
def signal_pa(self) -> npt.NDArray[np.float32]:
"""
Get the full signal for the read, calibrated in pico amps.
Returns
-------
numpy.ndarray[float32]
A numpy array of signal data in pico amps with float32 type.
"""
return self.calibrate_signal_array(self.signal)
[docs] def signal_for_chunk(self, index: int) -> npt.NDArray[np.int16]:
"""
Get the signal for a given chunk of the read.
Returns
-------
numpy.ndarray[int16]
A numpy array of signal data with int16 type for the specified chunk.
"""
# signal_rows can be used to find details of the signal chunks.
chunk_abs_row_index = self._batch.columns.signal[self._row][index]
return self._get_signal_for_row(chunk_abs_row_index.as_py())
@property
def signal_rows(self) -> List[SignalRowInfo]:
"""
Get all signal rows for the read
Returns
-------
list[SignalRowInfo]
A list of signal row data (as SignalRowInfo) in the read.
"""
def map_signal_row(sig_row) -> SignalRowInfo:
sig_row = sig_row.as_py()
batch, batch_index, batch_row_index = self._find_signal_row_index(sig_row)
return SignalRowInfo(
batch_index,
batch_row_index,
batch.samples[batch_row_index].as_py(),
len(batch.signal[batch_row_index].as_buffer()),
)
return [map_signal_row(r) for r in self._batch.columns.signal[self._row]]
[docs] def calibrate_signal_array(
self, signal_array_adc: npt.NDArray[np.int16]
) -> npt.NDArray[np.float32]:
"""
Transform an array of int16 signal data from ADC space to pA.
Returns
-------
A numpy array of signal data with float32 type.
"""
offset = np.float32(self.calibration.offset)
scale = np.float32(self.calibration.scale)
return (signal_array_adc + offset) * scale
def _find_signal_row_index(self, signal_row: int) -> Tuple[Signal, int, int]:
"""
Map from a signal_row to a Signal, batch index and row index within that batch.
Returns
-------
A Tuple containing the `Signal` and its `batch_index` and `row_index`
"""
sig_row_count: int = self._reader.signal_batch_row_count
sig_batch_idx: int = signal_row // sig_row_count
sig_batch = self._reader._get_signal_batch(sig_batch_idx)
batch_row_idx: int = signal_row - (sig_batch_idx * sig_row_count)
return sig_batch, sig_batch_idx, batch_row_idx
def _get_signal_for_row(self, signal_row: int) -> npt.NDArray[np.int16]:
"""
Get the signal data for a given absolute signal row index
Returns
-------
A numpy array of signal data with int16 type.
"""
batch, _, batch_row_index = self._find_signal_row_index(signal_row)
signal = batch.signal
if self._reader.is_vbz_compressed:
sample_count = batch.samples[batch_row_index].as_py()
return vbz_decompress_signal(
memoryview(signal[batch_row_index].as_buffer()), sample_count
)
return signal.to_numpy()
[docs] def to_read(self) -> Read:
"""
Create a mutable :py:class:`pod5.pod5_types.Read` from this
:py:class:`ReadRecord` instance.
Returns
-------
:py:class:`pod5.pod5_types.Read`
"""
return Read(
read_id=self.read_id,
pore=self.pore,
calibration=self.calibration,
median_before=self.median_before,
end_reason=self.end_reason,
read_number=self.read_number,
run_info=self.run_info,
start_sample=self.start_sample,
signal=self.signal,
)
[docs]class ReadRecordBatch:
"""
Read data for a batch of reads.
"""
[docs] def __init__(self, reader: "Reader", batch: pa.RecordBatch):
""" """
self._reader: "Reader" = reader
self._batch: pa.RecordBatch = batch
self._signal_cache: Optional[p5b.Pod5SignalCacheBatch] = None
self._selected_batch_rows: Optional[Iterable[int]] = None
self._columns: Optional[ReadRecordV3Columns] = None
@property
def columns(self) -> ReadRecordV3Columns:
"""Return the data from this batch as a ReadRecordColumns instance"""
if self._columns is None:
self._columns = ReadRecordV3Columns(
*[
self._batch.column(name)
for name in self._reader._columns_type._fields
]
)
return self._columns
[docs] def set_cached_signal(self, signal_cache: p5b.Pod5SignalCacheBatch) -> None:
"""Set the signal cache"""
self._signal_cache = signal_cache
[docs] def set_selected_batch_rows(self, selected_batch_rows: Iterable[int]) -> None:
"""Set the selected batch rows"""
self._selected_batch_rows = selected_batch_rows
[docs] def reads(self) -> Generator[ReadRecord, None, None]:
"""
Iterate all reads in this batch.
Yields
------
ReadRecord
ReadRecord instances in the file.
"""
signal_cache = None
if self._signal_cache and self._signal_cache.samples:
signal_cache = self._signal_cache.samples
if self._selected_batch_rows is not None:
for idx, row in enumerate(self._selected_batch_rows):
yield ReadRecord(
self._reader,
self,
row,
batch_signal_cache=signal_cache,
selected_batch_index=idx,
)
else:
for i in range(self.num_reads):
yield ReadRecord(self._reader, self, i, batch_signal_cache=signal_cache)
[docs] def get_read(self, row: int) -> ReadRecord:
"""Get the ReadRecord at row index"""
return ReadRecord(self._reader, self, row)
@property
def num_reads(self) -> int:
"""Return the number of rows in this RecordBatch"""
return int(self._batch.num_rows)
@property
def read_id_column(self):
"""
Get the column of read ids for this batch
"""
if self._selected_batch_rows is not None:
return self.columns.read_id.take(self._selected_batch_rows)
return self.columns.read_id
@property
def read_number_column(self):
"""
Get the column of read numbers for this batch
"""
if self._selected_batch_rows is not None:
return self.columns.read_number.take(self._selected_batch_rows)
return self.columns.read_number
@property
def cached_sample_count_column(self) -> npt.NDArray[np.uint64]:
"""
Get the sample counts from the cached signal data
"""
if not self._signal_cache:
raise RuntimeError("No cached signal data available")
return self._signal_cache.sample_count
@property
def cached_samples_column(self) -> List[npt.NDArray[np.int16]]:
"""
Get the samples column from the cached signal data
"""
if not self._signal_cache:
raise RuntimeError("No cached signal data available")
return self._signal_cache.samples
[docs]class ArrowTableHandle:
"""Class for managing arrow file handles and memory view mapping of tables"""
[docs] def __init__(self, location: p5b.EmbeddedFileData) -> None:
"""
Open a pod5 file at the given `path` and use the location data to load
an arrow table (e.g. signal table)
Parameters
----------
location : lib_pod5.pod5_format_pybind.EmbeddedFileData
Location data for how a pod5 file should be spit in memory to read a table.
This is returned from p5b.Pod5FileReader.get_file_X_location methods
Raises
------
Pod5ApiException
If handle could not be opened
"""
# The location data is passed from the p5b.Pod5FileReader.get_file_X_location
# methods
self._location = location
self._path = Path(self._location.file_path)
# Open the file
self._fh = self._path.open("rb")
# Create a memory view of the file and select the region for the table
try:
self._reader = self._open_with_mmap()
except OSError:
# If we fail fall back to a traditional open.
self._reader = self._open_without_mmap()
def _open_without_mmap(self):
class File(IOBase):
def __init__(self, handle, location):
self._handle = handle
self._location = location
self.seek(0, whence=0)
def seek(self, position, whence=0):
if whence == 0:
position = position + self._location.offset
elif whence == 2:
position = (
self._location.offset + self._location.length
) - position
whence = 0
# The new abs location:
abs_location = self._handle.seek(position, whence)
return abs_location - self._location.offset
def read(self, size=-1):
return self._handle.read(size)
return pa.ipc.open_file(pa.PythonFile(File(self._fh, self._location)))
def _open_with_mmap(self):
_mmap = mmap.mmap(self._fh.fileno(), length=0, access=mmap.ACCESS_READ)
file_view = memoryview(_mmap)
arrow_table_view = file_view[
self._location.offset : self._location.offset + self._location.length
]
# Open the table
try:
return pa.ipc.open_file(pa.BufferReader(arrow_table_view))
except pa.ArrowInvalid as exc:
raise Pod5ApiException(f"Failed to open ArrowTable: {self._path}") from exc
@property
def reader(self) -> pa.ipc.RecordBatchFileReader:
"""Return the pyarrow file reader object"""
if self._reader is not None:
return self._reader
raise RuntimeError(f"Could not open pyarrow reader: {p5b.get_error_string()}")
[docs] def close(self) -> None:
"""
Cleanly close the open file handles and memory views.
"""
self._reader = None
safe_close(self, "_fh")
def __enter__(self) -> "ArrowTableHandle":
return self
def __exit__(self, *exc_details) -> None:
self.close()
def __del__(self) -> None:
self.close()
[docs]class Reader:
"""
The base reader for POD5 data
"""
[docs] def __init__(self, path: PathOrStr):
"""
Open a pod5 filepath for reading
"""
self._path = Path(path).absolute()
self._file_reader: Optional[p5b.Pod5FileReader] = None
self._read_handle: Optional[ArrowTableHandle] = None
self._run_info_handle: Optional[ArrowTableHandle] = None
self._signal_handle: Optional[ArrowTableHandle] = None
(
self._file_reader,
self._read_handle,
self._run_info_handle,
self._signal_handle,
) = self._open_arrow_table_handles(self._path)
schema_metadata = self.read_table.schema.metadata
self._file_identifier = UUID(
schema_metadata[b"MINKNOW:file_identifier"].decode("utf-8")
)
self._writing_software = schema_metadata[b"MINKNOW:software"].decode("utf-8")
writing_version_str = schema_metadata[b"MINKNOW:pod5_version"].decode("utf-8")
writing_version = packaging.version.parse(writing_version_str)
self._columns_type = ReadRecordV3Columns
self._reads_table_version = ReadTableVersion.V3
self._file_version = writing_version
self._file_version_pre_migration = packaging.version.Version(
self._file_reader.get_file_version_pre_migration()
)
# Warning: The cached signal maintains an open file handle. So ensure that
# this dictionary is cleared before closing.
self._cached_signal_batches: Dict[int, Signal] = {}
self._cached_run_infos: Dict[str, RunInfo] = {}
self._is_vbz_compressed: Optional[bool] = None
self._signal_batch_row_count: Optional[int] = None
@staticmethod
def _open_arrow_table_handles(
path: Path,
) -> Tuple[
p5b.Pod5FileReader, ArrowTableHandle, ArrowTableHandle, ArrowTableHandle
]:
"""Open handles to the underlying arrow tables within this pod5 file"""
if not path.is_file():
raise FileNotFoundError(f"Failed to open pod5 file at: {path}")
file_reader = p5b.open_file(str(path))
if not file_reader:
raise Pod5ApiException(
f"Failed to open reader for {path} Reason: {p5b.get_error_string()}"
)
read_handle = ArrowTableHandle(file_reader.get_file_read_table_location())
run_info_handle = ArrowTableHandle(
file_reader.get_file_run_info_table_location()
)
signal_handle = ArrowTableHandle(file_reader.get_file_signal_table_location())
return file_reader, read_handle, run_info_handle, signal_handle
def __del__(self) -> None:
self.close()
def __enter__(self) -> "Reader":
return self
def __exit__(self, *exc_details) -> None:
self.close()
def __iter__(self) -> Generator[ReadRecord, None, None]:
"""Iterate over all reads"""
yield from self.reads()
[docs] def close(self) -> None:
"""Close files handles"""
safe_close(self, "_read_handle")
self._read_handle = None
safe_close(self, "_run_info_handle")
self._run_info_handle = None
safe_close(self, "_signal_handle")
self._signal_handle = None
safe_close(self, "_file_reader")
self._file_reader = None
# Explicitly clear this dictionary to close file handles used in cache
self._cached_signal_batches = {}
@property
def path(self) -> Path:
"""Return the path to this pod5 file"""
return self._path
@property
def inner_file_reader(self) -> p5b.Pod5FileReader:
"""Access the inner c_api Pod5FileReader - use with caution"""
if self._file_reader is None:
raise RuntimeError("Pod5FileReader has been closed!")
return self._file_reader
@property
def read_table(self) -> pa.ipc.RecordBatchFileReader:
"""Access the pod5 read table"""
if self._read_handle is None:
raise RuntimeError("ArrowTableHandle has been closed!")
return self._read_handle.reader
@property
def run_info_table(self) -> pa.ipc.RecordBatchFileReader:
"""Access the pod5 run_info table"""
if self._run_info_handle is None:
raise RuntimeError("ArrowTableHandle has been closed!")
return self._run_info_handle.reader
@property
def signal_table(self) -> pa.ipc.RecordBatchFileReader:
"""Access the pod5 signal table - use with caution"""
if self._signal_handle is None:
raise RuntimeError("ArrowTableHandle has been closed!")
return self._signal_handle.reader
@property
def file_version(self) -> packaging.version.Version:
return self._file_version
@property
def file_version_pre_migration(self) -> packaging.version.Version:
return self._file_version_pre_migration
@property
def writing_software(self) -> str:
return self._writing_software
@property
def file_identifier(self) -> UUID:
return self._file_identifier
@property
def reads_table_version(self) -> ReadTableVersion:
return self._reads_table_version
@property
def is_vbz_compressed(self) -> bool:
"""Return if this file's signal is compressed"""
if self._is_vbz_compressed is None:
self._is_vbz_compressed = self.signal_table.schema.field(
"signal"
).type.equals(pa.large_binary())
return self._is_vbz_compressed
@property
def signal_batch_row_count(self) -> int:
"""Return signal batch row count"""
if self._signal_batch_row_count is None:
if self.signal_table.num_record_batches > 0:
self._signal_batch_row_count = self.signal_table.get_batch(0).num_rows
else:
self._signal_batch_row_count = 0
return self._signal_batch_row_count
@property
def batch_count(self) -> int:
"""
Find the number of read batches available in the file.
"""
return self.read_table.num_record_batches
@property
def num_reads(self) -> int:
"""
Find the number of reads in the file.
"""
return sum(batch.num_reads for batch in self.read_batches())
@property
def read_ids_raw(self) -> pa.ChunkedArray:
"""
Return chunked arrow array of read ids.
To get read ids as string use `Reader.read_ids`
"""
return pa.chunked_array([batch.read_id_column for batch in self.read_batches()])
@property
def read_ids(self) -> List[str]:
"""
Return all read_ids as a list of strings.
For the most performant implementation consider `Reader.read_ids_raw`
"""
def arrow_to_numpy(batch):
# Get the arrow data as a buffer
id_buffer = batch.read_id_column.buffers()[1]
# Pack the arrow buffer into a numpy array of the the right shape
array = np.frombuffer(id_buffer, dtype=np.uint8)
return array.reshape((batch.num_reads, 16))
read_ids = np.concatenate(
[arrow_to_numpy(batch) for batch in self.read_batches()]
)
return format_read_ids(read_ids)
[docs] def get_batch(self, index: int) -> ReadRecordBatch:
"""
Get a read batch in the file.
Returns
-------
:py:class:`ReadRecordBatch`
The requested batch as a ReadRecordBatch.
"""
return ReadRecordBatch(self, self.read_table.get_batch(index))
[docs] def read_batches(
self,
selection: Optional[List[str]] = None,
batch_selection: Optional[Iterable[int]] = None,
missing_ok: bool = False,
preload: Optional[Set[str]] = None,
) -> Generator[ReadRecordBatch, None, None]:
"""
Iterate batches in the file, optionally selecting certain rows.
Parameters
----------
selection : iterable[str]
The read ids to walk in the file.
batch_selection : iterable[int]
The read batches to walk in the file.
missing_ok : bool
If selection contains entries not found in the file, an error will be raised.
preload : set[str]
Columns to preload - "samples" and "sample_count" are valid values
Returns
-------
An iterable of :py:class:`ReadRecordBatch` in the file.
"""
if selection is not None:
assert not batch_selection
yield from self._select_read_batches(
selection, missing_ok=missing_ok, preload=preload
)
elif batch_selection is not None:
assert not selection
yield from self._read_some_batches(batch_selection, preload=preload)
else:
yield from self._reads_batches(preload=preload)
[docs] def reads(
self,
selection: Optional[Iterable[str]] = None,
missing_ok: bool = False,
preload: Optional[Set[str]] = None,
) -> Generator[ReadRecord, None, None]:
"""
Iterate reads in the file, optionally filtering for certain read ids.
Parameters
----------
selection : iterable[str]
The read ids to walk in the file.
missing_ok : bool
If selection contains entries not found in the file, an error will be raised.
preload : set[str]
Columns to preload - "samples" and "sample_count" are valid values
Returns
-------
An iterable of :py:class:`ReadRecord` in the file.
"""
if selection is None:
yield from self._reads(preload=preload)
else:
yield from self._select_reads(
list(selection), missing_ok=missing_ok, preload=preload
)
def _reads(
self, preload: Optional[Set[str]] = None
) -> Generator[ReadRecord, None, None]:
"""Generate all reads"""
for batch in self.read_batches(preload=preload):
for read in batch.reads():
yield read
def _select_reads(
self,
selection: List[str],
missing_ok: bool = False,
preload: Optional[Set[str]] = None,
) -> Generator[ReadRecord, None, None]:
"""Generate selected reads"""
for batch in self._select_read_batches(selection, missing_ok, preload=preload):
for read in batch.reads():
yield read
def _reads_batches(
self, preload: Optional[Set[str]] = None
) -> Generator[ReadRecordBatch, None, None]:
"""Generate the record batches"""
signal_cache = None
if preload:
signal_cache = self.inner_file_reader.batch_get_signal(
"samples" in preload,
"sample_count" in preload,
)
for idx in range(self.read_table.num_record_batches):
batch = self.get_batch(idx)
if signal_cache:
batch.set_cached_signal(signal_cache.release_next_batch())
yield batch
def _read_some_batches(
self,
batch_selection: Iterable[int],
preload: Optional[Set[str]] = None,
) -> Generator[ReadRecordBatch, None, None]:
"""Generate the selected record batches"""
signal_cache = None
if preload:
signal_cache = self.inner_file_reader.batch_get_signal_batches(
"samples" in preload,
"sample_count" in preload,
np.array(batch_selection, dtype=np.uint32),
)
for i in batch_selection:
batch = self.get_batch(i)
if signal_cache:
batch.set_cached_signal(signal_cache.release_next_batch())
yield batch
def _select_read_batches(
self,
selection: List[str],
missing_ok: bool = False,
preload: Optional[Set[str]] = None,
) -> Generator[ReadRecordBatch, None, None]:
"""Generate the selected record batches"""
successful_finds, per_batch_counts, batch_rows = self._plan_traversal(
selection, missing_ok=missing_ok
)
if not missing_ok and successful_finds != len(selection):
raise RuntimeError(
f"Failed to find {len(selection) - successful_finds} requested reads in the file"
)
signal_cache: Optional[p5b.Pod5AsyncSignalLoader] = None
if preload:
signal_cache = self.inner_file_reader.batch_get_signal_selection(
"samples" in preload,
"sample_count" in preload,
per_batch_counts,
batch_rows,
)
current_offset = 0
for batch_idx, batch_count in enumerate(per_batch_counts):
current_batch_rows = batch_rows[
current_offset : current_offset + batch_count
]
current_offset += batch_count
batch = self.get_batch(batch_idx)
batch.set_selected_batch_rows(current_batch_rows)
if signal_cache:
batch.set_cached_signal(signal_cache.release_next_batch())
yield batch
def _plan_traversal(
self,
read_ids: Union[Collection[str], npt.NDArray[np.uint8]],
missing_ok: bool = False,
) -> Tuple[int, npt.NDArray[np.uint32], npt.NDArray[np.uint32]]:
"""
Query the file reader indexes to return the number of read_ids which
were found and the batches and rows which are needed to traverse each
read in the selection.
Parameters
----------
read_ids : Collection or numpy.ndarray of read_id strings
The read ids to find in the file
Returns
-------
successful_find_count: int
The number of reads that were found from the array of read_ids given
per_batch_counts: numpy.array[uint32]
The number of rows from the batch row ids to take to form each RecordBatch
batch_rows: numpy.array[uint32]
All batch row ids
"""
if not isinstance(read_ids, np.ndarray):
read_ids = pack_read_ids(read_ids, invalid_ok=missing_ok)
assert isinstance(read_ids, np.ndarray)
batch_rows = np.empty(dtype="u4", shape=read_ids.shape[0])
per_batch_counts = np.empty(dtype="u4", shape=self.batch_count)
successful_find_count = self.inner_file_reader.plan_traversal(
read_ids,
per_batch_counts,
batch_rows,
)
return successful_find_count, per_batch_counts, batch_rows
def _get_signal_batch(self, batch_id: int) -> Signal:
"""Get the :py:class:`Signal` from the signal_reader batch at batch_id"""
if batch_id in self._cached_signal_batches:
return self._cached_signal_batches[batch_id]
batch = self.signal_table.get_batch(batch_id)
signal_batch = Signal(*[batch.column(name) for name in Signal._fields])
self._cached_signal_batches[batch_id] = signal_batch
return signal_batch
def _lookup_run_info(self, batch: ReadRecordBatch, batch_row_id: int) -> RunInfo:
"""Get the :py:class:`RunInfo` from the batch at batch_row_id"""
acquisition_id = batch.columns.run_info[batch_row_id].as_py()
if acquisition_id in self._cached_run_infos:
return self._cached_run_infos[acquisition_id]
run_info = None
for idx in range(self.run_info_table.num_record_batches):
run_info_batch = self.run_info_table.get_batch(idx)
acquisition_id_col = run_info_batch.column("acquisition_id")
for row in range(run_info_batch.num_rows):
if acquisition_id_col[row].as_py() == acquisition_id:
values = {}
for field in fields(RunInfo):
col = run_info_batch.column(field.name)
values[field.name] = col[row].as_py()
if field.name in ("tracking_id", "context_tags"):
values[field.name] = {k: v for k, v in values[field.name]}
run_info = RunInfo(**values)
break
if not run_info:
raise Exception(
f"Failed to find run info '{acquisition_id}' in run info table"
)
self._cached_run_infos[acquisition_id] = run_info
return run_info