Skip to content

Commit 56cf3cf

Browse files
committed
use get_partial_values(), remove explicit coalescing and concurrent_map
1 parent 7039de9 commit 56cf3cf

9 files changed

Lines changed: 230 additions & 820 deletions

File tree

changes/3004.feature.md

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
Optimizes reading multiple chunks from a shard. Reads of nearby chunks within
2-
the same shard are coalesced to reduce the number of calls to the store.
3-
After any coalescing, the resulting byte ranges are read in parallel.
4-
5-
Coalescing respects two config options. Reads are coalesced if there are fewer
6-
than `sharding.read.coalesce_max_gap_bytes` bytes between chunks and the total
7-
size of the coalesced read is no more than `sharding.read.coalesce_max_bytes`.
1+
Optimizes reading multiple chunks from a shard.
2+
Serial calls to `.get()` in the sharding codec have been replaced with
3+
a single call to `.get_partial_values()` which stores may optimize by making
4+
concurrent requests and/or coalescing nearby requests to the same shard.

docs/user-guide/config.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ Configuration options include the following:
3333
- Async and threading options, e.g. `async.concurrency` and `threading.max_workers`
3434
- Selections of implementations of codecs, codec pipelines and buffers
3535
- Enabling GPU support with `zarr.config.enable_gpu()`. See GPU support for more.
36-
- Control request merging when reading multiple chunks from the same shard with `sharding.read.coalesce_max_gap_bytes` and `sharding.read.coalesce_max_bytes`
3736

3837
For selecting custom implementations of codecs, pipelines, buffers and ndbuffers,
3938
first register the implementations in the registry and then select them in the config.

src/zarr/abc/store.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,19 @@
1616

1717
from zarr.core.buffer import Buffer, BufferPrototype
1818

19-
__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]
20-
21-
22-
@dataclass
19+
__all__ = [
20+
"ByteGetter",
21+
"ByteSetter",
22+
"Store",
23+
"SupportsDeleteSync",
24+
"SupportsGetSync",
25+
"SupportsSetSync",
26+
"SupportsSyncStore",
27+
"set_or_delete",
28+
]
29+
30+
31+
@dataclass(frozen=True, slots=True)
2332
class RangeByteRequest:
2433
"""Request a specific byte range"""
2534

@@ -29,15 +38,15 @@ class RangeByteRequest:
2938
"""The end of the byte range request (exclusive)."""
3039

3140

32-
@dataclass
41+
@dataclass(frozen=True, slots=True)
3342
class OffsetByteRequest:
3443
"""Request all bytes starting from a given byte offset"""
3544

3645
offset: int
3746
"""The byte offset for the offset range request."""
3847

3948

40-
@dataclass
49+
@dataclass(frozen=True, slots=True)
4150
class SuffixByteRequest:
4251
"""Request up to the last `n` bytes"""
4352

@@ -686,20 +695,57 @@ async def get(
686695
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
687696
) -> Buffer | None: ...
688697

698+
async def get_partial_values(
699+
self,
700+
prototype: BufferPrototype,
701+
byte_ranges: Iterable[ByteRequest | None],
702+
) -> list[Buffer | None]: ...
703+
689704

690705
@runtime_checkable
691706
class ByteSetter(Protocol):
692707
async def get(
693708
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
694709
) -> Buffer | None: ...
695710

711+
async def get_partial_values(
712+
self,
713+
prototype: BufferPrototype,
714+
byte_ranges: Iterable[ByteRequest | None],
715+
) -> list[Buffer | None]: ...
716+
696717
async def set(self, value: Buffer) -> None: ...
697718

698719
async def delete(self) -> None: ...
699720

700721
async def set_if_not_exists(self, default: Buffer) -> None: ...
701722

702723

724+
@runtime_checkable
725+
class SupportsGetSync(Protocol):
726+
def get_sync(
727+
self,
728+
key: str,
729+
*,
730+
prototype: BufferPrototype | None = None,
731+
byte_range: ByteRequest | None = None,
732+
) -> Buffer | None: ...
733+
734+
735+
@runtime_checkable
736+
class SupportsSetSync(Protocol):
737+
def set_sync(self, key: str, value: Buffer) -> None: ...
738+
739+
740+
@runtime_checkable
741+
class SupportsDeleteSync(Protocol):
742+
def delete_sync(self, key: str) -> None: ...
743+
744+
745+
@runtime_checkable
746+
class SupportsSyncStore(SupportsGetSync, SupportsSetSync, SupportsDeleteSync, Protocol): ...
747+
748+
703749
async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None:
704750
"""Set or delete a value in a byte setter
705751

src/zarr/codecs/sharding.py

Lines changed: 25 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from enum import Enum
66
from functools import lru_cache
77
from operator import itemgetter
8-
from typing import TYPE_CHECKING, Any, NamedTuple
8+
from typing import TYPE_CHECKING, Any, NamedTuple, cast
99

1010
import numpy as np
1111
import numpy.typing as npt
@@ -37,13 +37,11 @@
3737
from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid
3838
from zarr.core.common import (
3939
ShapeLike,
40-
concurrent_map,
4140
parse_enum,
4241
parse_named_configuration,
4342
parse_shapelike,
4443
product,
4544
)
46-
from zarr.core.config import config
4745
from zarr.core.dtype.npy.int import UInt64
4846
from zarr.core.indexing import (
4947
BasicIndexer,
@@ -102,6 +100,13 @@ async def get(
102100
start, stop = _normalize_byte_range_index(value, byte_range)
103101
return value[start:stop]
104102

103+
async def get_partial_values(
104+
self,
105+
prototype: BufferPrototype,
106+
byte_ranges: Iterable[ByteRequest | None],
107+
) -> list[Buffer | None]:
108+
return [await self.get(prototype, br) for br in byte_ranges]
109+
105110

106111
@dataclass(frozen=True)
107112
class _ShardingByteSetter(_ShardingByteGetter, ByteSetter):
@@ -295,14 +300,6 @@ def to_dict_vectorized(
295300
return result
296301

297302

298-
@dataclass(frozen=True)
299-
class _ChunkCoordsByteSlice:
300-
"""Holds a core.indexing.ChunkProjection.chunk_coords and its byte range in a serialized shard."""
301-
302-
chunk_coords: tuple[int, ...]
303-
byte_slice: slice
304-
305-
306303
@dataclass(frozen=True)
307304
class ShardingCodec(
308305
ArrayBytesCodec,
@@ -485,26 +482,19 @@ async def _decode_partial_single(
485482
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}
486483

487484
# reading bytes of all requested chunks
488-
shard_dict_maybe: ShardMapping | None = {}
485+
shard_dict_maybe: ShardMapping | None = None
489486
if self._is_total_shard(all_chunk_coords, chunks_per_shard):
490487
# read entire shard
491488
shard_dict_maybe = await self._load_full_shard_maybe(
492489
byte_getter, chunk_spec.prototype, chunks_per_shard
493490
)
494491
else:
495492
# read some chunks within the shard
496-
max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes")
497-
coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes")
498-
async_concurrency = config.get("async.concurrency")
499-
500493
shard_dict_maybe = await self._load_partial_shard_maybe(
501494
byte_getter,
502495
chunk_spec.prototype,
503496
chunks_per_shard,
504497
all_chunk_coords,
505-
max_gap_bytes,
506-
coalesce_max_bytes,
507-
async_concurrency,
508498
)
509499

510500
if shard_dict_maybe is None:
@@ -789,112 +779,34 @@ async def _load_partial_shard_maybe(
789779
prototype: BufferPrototype,
790780
chunks_per_shard: tuple[int, ...],
791781
all_chunk_coords: set[tuple[int, ...]],
792-
max_gap_bytes: int,
793-
coalesce_max_bytes: int,
794-
async_concurrency: int,
795782
) -> ShardMapping | None:
796783
"""
797784
Read chunks from `byte_getter` for the case where the read is less than a full shard.
798785
Returns a mapping of chunk coordinates to bytes or None.
799-
800-
Reads are coalesced if there are fewer than `max_gap_bytes` bytes between chunks
801-
and the total size of the coalesced read is no more than `coalesce_max_bytes`.
802786
"""
803787
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
804788
if shard_index is None:
805789
return None # shard index read failure, the ByteGetter returned None
806790

807-
chunks = [
808-
_ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
809-
for chunk_coords in all_chunk_coords
810-
# Drop chunks where index lookup fails
811-
# e.g. empty chunks when write_empty_chunks = False
812-
if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
813-
]
791+
# Build parallel lists of chunk coordinates and byte ranges for non-empty chunks
792+
chunk_coords_list: list[tuple[int, ...]] = []
793+
byte_ranges: list[RangeByteRequest] = []
794+
for chunk_coords in all_chunk_coords:
795+
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
796+
if chunk_byte_slice is not None:
797+
chunk_coords_list.append(chunk_coords)
798+
byte_ranges.append(RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]))
814799

815-
groups = self._coalesce_chunks(chunks, max_gap_bytes, coalesce_max_bytes)
800+
if not byte_ranges:
801+
return {}
816802

817-
shard_dict: ShardMutableMapping = {}
818-
if len(groups) == 1:
819-
# Avoid thread start overhead when there's only one group
820-
shard_dict_result = await self._get_group_bytes(groups[0], byte_getter, prototype)
821-
# can be None if the ByteGetter returned None when reading chunk data
822-
if shard_dict_result is not None:
823-
shard_dict.update(shard_dict_result)
824-
else:
825-
shard_dicts = await concurrent_map(
826-
[(group, byte_getter, prototype) for group in groups],
827-
self._get_group_bytes,
828-
async_concurrency,
829-
)
803+
# Fetch all chunk byte ranges via get_partial_values
804+
buffers = await byte_getter.get_partial_values(prototype, byte_ranges)
830805

831-
for shard_dict_result in shard_dicts:
832-
if shard_dict_result is not None:
833-
shard_dict.update(shard_dict_result)
834-
835-
return shard_dict
836-
837-
def _coalesce_chunks(
838-
self,
839-
chunks: list[_ChunkCoordsByteSlice],
840-
max_gap_bytes: int,
841-
coalesce_max_bytes: int,
842-
) -> list[list[_ChunkCoordsByteSlice]]:
843-
"""
844-
Combine chunks from a single shard into groups that should be read together
845-
in a single request to the store.
846-
"""
847-
sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)
848-
849-
if len(sorted_chunks) == 0:
850-
return []
851-
852-
groups = []
853-
current_group = [sorted_chunks[0]]
854-
855-
for chunk in sorted_chunks[1:]:
856-
gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
857-
size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start
858-
if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes:
859-
current_group.append(chunk)
860-
else:
861-
groups.append(current_group)
862-
current_group = [chunk]
863-
864-
groups.append(current_group)
865-
866-
return groups
867-
868-
async def _get_group_bytes(
869-
self,
870-
group: list[_ChunkCoordsByteSlice],
871-
byte_getter: ByteGetter,
872-
prototype: BufferPrototype,
873-
) -> ShardMapping | None:
874-
"""
875-
Reads a possibly coalesced group of one or more chunks from a shard.
876-
Returns a mapping of chunk coordinates to bytes.
877-
"""
878-
# _coalesce_chunks ensures that the group is not empty.
879-
group_start = group[0].byte_slice.start
880-
group_end = group[-1].byte_slice.stop
881-
882-
# A single call to retrieve the bytes for the entire group.
883-
group_bytes = await byte_getter.get(
884-
prototype=prototype,
885-
byte_range=RangeByteRequest(group_start, group_end),
886-
)
887-
if group_bytes is None:
888-
return None
889-
890-
# Extract the bytes corresponding to each chunk in group from group_bytes.
891-
shard_dict = {}
892-
for chunk in group:
893-
chunk_slice = slice(
894-
chunk.byte_slice.start - group_start,
895-
chunk.byte_slice.stop - group_start,
896-
)
897-
shard_dict[chunk.chunk_coords] = group_bytes[chunk_slice]
806+
shard_dict: ShardMutableMapping = {}
807+
for chunk_coords, buf in zip(chunk_coords_list, buffers, strict=True):
808+
if buf is not None:
809+
shard_dict[chunk_coords] = buf
898810

899811
return shard_dict
900812

src/zarr/core/config.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,6 @@ def enable_gpu(self) -> ConfigSet:
100100
},
101101
"async": {"concurrency": 10, "timeout": None},
102102
"threading": {"max_workers": None},
103-
"sharding": {
104-
"read": {
105-
"coalesce_max_bytes": 100 * 2**20, # 100MiB
106-
"coalesce_max_gap_bytes": 2**20, # 1MiB
107-
}
108-
},
109103
"json_indent": 2,
110104
"codec_pipeline": {
111105
"path": "zarr.core.codec_pipeline.BatchedCodecPipeline",

0 commit comments

Comments
 (0)