@@ -5,7 +5,7 @@
 from enum import Enum
 from functools import lru_cache
 from operator import itemgetter
-from typing import TYPE_CHECKING, Any, NamedTuple
+from typing import TYPE_CHECKING, Any, NamedTuple, cast

 import numpy as np
 import numpy.typing as npt
@@ -37,13 +37,11 @@
 from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid
 from zarr.core.common import (
     ShapeLike,
-    concurrent_map,
     parse_enum,
     parse_named_configuration,
     parse_shapelike,
     product,
 )
-from zarr.core.config import config
 from zarr.core.dtype.npy.int import UInt64
 from zarr.core.indexing import (
     BasicIndexer,
@@ -102,6 +100,13 @@ async def get(
         start, stop = _normalize_byte_range_index(value, byte_range)
         return value[start:stop]

+    async def get_partial_values(
+        self,
+        prototype: BufferPrototype,
+        byte_ranges: Iterable[ByteRequest | None],
+    ) -> list[Buffer | None]:
+        return [await self.get(prototype, br) for br in byte_ranges]
+

 @dataclass(frozen=True)
 class _ShardingByteSetter(_ShardingByteGetter, ByteSetter):
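Note: the new `get_partial_values` fallback simply awaits `get` once per requested range, so getters that only implement `get` keep working unchanged. A minimal usage sketch, assuming `zarr.abc.store.RangeByteRequest` and `zarr.core.buffer.default_buffer_prototype`; the `getter` argument is a hypothetical stand-in for a `_ShardingByteGetter` instance:

```python
from zarr.abc.store import RangeByteRequest
from zarr.core.buffer import default_buffer_prototype

async def read_two_ranges(getter):
    # Request bytes [0, 64) and [128, 256) from the shard in one batched call.
    ranges = [RangeByteRequest(0, 64), RangeByteRequest(128, 256)]
    bufs = await getter.get_partial_values(default_buffer_prototype(), ranges)
    # Each entry is a Buffer, or None when the underlying `get` returned None
    # (e.g. the key does not exist in the store).
    return bufs
```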
@@ -295,14 +300,6 @@ def to_dict_vectorized(
         return result


-@dataclass(frozen=True)
-class _ChunkCoordsByteSlice:
-    """Holds a core.indexing.ChunkProjection.chunk_coords and its byte range in a serialized shard."""
-
-    chunk_coords: tuple[int, ...]
-    byte_slice: slice
-
-
 @dataclass(frozen=True)
 class ShardingCodec(
     ArrayBytesCodec,
@@ -485,26 +482,19 @@ async def _decode_partial_single(
         all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

         # reading bytes of all requested chunks
-        shard_dict_maybe: ShardMapping | None = {}
+        shard_dict_maybe: ShardMapping | None = None
         if self._is_total_shard(all_chunk_coords, chunks_per_shard):
             # read entire shard
             shard_dict_maybe = await self._load_full_shard_maybe(
                 byte_getter, chunk_spec.prototype, chunks_per_shard
             )
         else:
             # read some chunks within the shard
-            max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes")
-            coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes")
-            async_concurrency = config.get("async.concurrency")
-
             shard_dict_maybe = await self._load_partial_shard_maybe(
                 byte_getter,
                 chunk_spec.prototype,
                 chunks_per_shard,
                 all_chunk_coords,
-                max_gap_bytes,
-                coalesce_max_bytes,
-                async_concurrency,
             )

         if shard_dict_maybe is None:
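For reference, `_is_total_shard` gates which branch runs: the full-shard read is taken only when every chunk in the shard's grid was requested. A hypothetical reconstruction of that check, inferred from the call site rather than copied from the implementation:

```python
import itertools
import math

def is_total_shard(
    all_chunk_coords: set[tuple[int, ...]], chunks_per_shard: tuple[int, ...]
) -> bool:
    # True only if every chunk coordinate in the shard grid was requested.
    return len(all_chunk_coords) == math.prod(chunks_per_shard) and all(
        coords in all_chunk_coords
        for coords in itertools.product(*(range(n) for n in chunks_per_shard))
    )
```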
@@ -789,112 +779,34 @@ async def _load_partial_shard_maybe(
         prototype: BufferPrototype,
         chunks_per_shard: tuple[int, ...],
         all_chunk_coords: set[tuple[int, ...]],
-        max_gap_bytes: int,
-        coalesce_max_bytes: int,
-        async_concurrency: int,
     ) -> ShardMapping | None:
         """
         Read chunks from `byte_getter` for the case where the read is less than a full shard.
         Returns a mapping of chunk coordinates to bytes or None.
-
-        Reads are coalesced if there are fewer than `max_gap_bytes` bytes between chunks
-        and the total size of the coalesced read is no more than `coalesce_max_bytes`.
         """
         shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
         if shard_index is None:
             return None  # shard index read failure, the ByteGetter returned None

-        chunks = [
-            _ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
-            for chunk_coords in all_chunk_coords
-            # Drop chunks where index lookup fails
-            # e.g. empty chunks when write_empty_chunks = False
-            if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
-        ]
+        # Build parallel lists of chunk coordinates and byte ranges for non-empty chunks
+        chunk_coords_list: list[tuple[int, ...]] = []
+        byte_ranges: list[RangeByteRequest] = []
+        for chunk_coords in all_chunk_coords:
+            chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
+            if chunk_byte_slice is not None:
+                chunk_coords_list.append(chunk_coords)
+                byte_ranges.append(RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]))

-        groups = self._coalesce_chunks(chunks, max_gap_bytes, coalesce_max_bytes)
+        if not byte_ranges:
+            return {}

-        shard_dict: ShardMutableMapping = {}
-        if len(groups) == 1:
-            # Avoid thread start overhead when there's only one group
-            shard_dict_result = await self._get_group_bytes(groups[0], byte_getter, prototype)
-            # can be None if the ByteGetter returned None when reading chunk data
-            if shard_dict_result is not None:
-                shard_dict.update(shard_dict_result)
-        else:
-            shard_dicts = await concurrent_map(
-                [(group, byte_getter, prototype) for group in groups],
-                self._get_group_bytes,
-                async_concurrency,
-            )
+        # Fetch all chunk byte ranges via get_partial_values
+        buffers = await byte_getter.get_partial_values(prototype, byte_ranges)

-        for shard_dict_result in shard_dicts:
-            if shard_dict_result is not None:
-                shard_dict.update(shard_dict_result)
-
-        return shard_dict
-
-    def _coalesce_chunks(
-        self,
-        chunks: list[_ChunkCoordsByteSlice],
-        max_gap_bytes: int,
-        coalesce_max_bytes: int,
-    ) -> list[list[_ChunkCoordsByteSlice]]:
-        """
-        Combine chunks from a single shard into groups that should be read together
-        in a single request to the store.
-        """
-        sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)
-
-        if len(sorted_chunks) == 0:
-            return []
-
-        groups = []
-        current_group = [sorted_chunks[0]]
-
-        for chunk in sorted_chunks[1:]:
-            gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
-            size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start
-            if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes:
-                current_group.append(chunk)
-            else:
-                groups.append(current_group)
-                current_group = [chunk]
-
-        groups.append(current_group)
-
-        return groups
-
-    async def _get_group_bytes(
-        self,
-        group: list[_ChunkCoordsByteSlice],
-        byte_getter: ByteGetter,
-        prototype: BufferPrototype,
-    ) -> ShardMapping | None:
-        """
-        Reads a possibly coalesced group of one or more chunks from a shard.
-        Returns a mapping of chunk coordinates to bytes.
-        """
-        # _coalesce_chunks ensures that the group is not empty.
-        group_start = group[0].byte_slice.start
-        group_end = group[-1].byte_slice.stop
-
-        # A single call to retrieve the bytes for the entire group.
-        group_bytes = await byte_getter.get(
-            prototype=prototype,
-            byte_range=RangeByteRequest(group_start, group_end),
-        )
-        if group_bytes is None:
-            return None
-
-        # Extract the bytes corresponding to each chunk in group from group_bytes.
-        shard_dict = {}
-        for chunk in group:
-            chunk_slice = slice(
-                chunk.byte_slice.start - group_start,
-                chunk.byte_slice.stop - group_start,
-            )
-            shard_dict[chunk.chunk_coords] = group_bytes[chunk_slice]
+        shard_dict: ShardMutableMapping = {}
+        for chunk_coords, buf in zip(chunk_coords_list, buffers, strict=True):
+            if buf is not None:
+                shard_dict[chunk_coords] = buf

         return shard_dict

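Taken together, the rewritten `_load_partial_shard_maybe` delegates range batching to the store instead of coalescing reads locally. An end-to-end sketch of the new read path with a toy index; `index` and `getter` are hypothetical stand-ins for the shard index and a `ByteGetter` that supports `get_partial_values`:

```python
from zarr.abc.store import RangeByteRequest

async def load_chunks(getter, prototype, index, wanted):
    # Pair each requested chunk with its byte range; skip chunks missing
    # from the index (e.g. empty chunks when write_empty_chunks=False).
    coords_list, ranges = [], []
    for coords in wanted:
        span = index.get(coords)  # (start, stop) byte offsets, or None
        if span is not None:
            coords_list.append(coords)
            ranges.append(RangeByteRequest(span[0], span[1]))
    if not ranges:
        return {}  # every requested chunk was empty; nothing to fetch
    bufs = await getter.get_partial_values(prototype, ranges)
    # strict=True raises if the store returns the wrong number of buffers.
    return {c: b for c, b in zip(coords_list, bufs, strict=True) if b is not None}
```

Dropping `None` buffers mirrors the diff: a failed chunk read leaves that coordinate absent from the shard mapping rather than failing the whole request.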