From 714d3cd43e6f22dde650b017dc5ea813b2ee6459 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Sat, 23 May 2026 17:37:59 -0500 Subject: [PATCH] feat(waterdata): Auto-chunk OGC requests over the URL byte limit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The OGC `waterdata` getters previously failed with HTTP 414 when the request URL exceeded the server's ~8 KB byte limit. A common pattern — pulling a long site list from `get_monitoring_locations` and feeding it into `get_daily` — was the main offender: sites_df, _ = get_monitoring_locations(state_name="Ohio") df, md = get_daily( monitoring_location_id=sites_df["monitoring_location_id"].tolist(), parameter_code="00060", time="P7D", ) Introduces a joint chunker that models every multi-value list parameter and the cql-text `filter` (split on top-level `OR`) as a chunkable axis. Greedy halving splits the biggest chunk across all axes until each sub-request URL fits; the chunker fans out under the hood and returns one combined DataFrame. Callers see no API change. Mid-stream 429 / 5xx surface as `ChunkInterrupted` subclasses (`QuotaExhausted` / `ServiceInterrupted`) carrying the partial result plus a `.call` resumable handle — `exc.call.resume()` continues only the still-pending sub-requests. Pre-emptive `RequestExceedsQuota` catches plans that won't fit the remaining rate-limit window; `API_USGS_LIMIT=0` bypasses the check. Behavior changes for paginated / chunked calls: - `BaseMetadata.url` still reflects the user's original query. - `BaseMetadata.header` now carries the LAST page's headers so `x-ratelimit-remaining` is current (was: first page's). - `BaseMetadata.query_time` is now cumulative wall-clock across pages (was: first page's elapsed). Mirrors R `dataRetrieval`'s [#870](https://github.com/DOI-USGS/dataRetrieval/pull/870), generalized from one filter axis to N joint axes. Co-authored-by: Claude Opus 4.7 --- NEWS.md | 4 +- dataretrieval/waterdata/api.py | 27 +- dataretrieval/waterdata/chunking.py | 1246 ++++++++++++++++++++++++++ dataretrieval/waterdata/filters.py | 194 +--- dataretrieval/waterdata/utils.py | 667 ++++++++++---- tests/waterdata_chunking_test.py | 1271 +++++++++++++++++++++++++++ tests/waterdata_filters_test.py | 261 ++---- tests/waterdata_test.py | 79 +- tests/waterdata_utils_test.py | 301 ++++++- 9 files changed, 3374 insertions(+), 676 deletions(-) create mode 100644 dataretrieval/waterdata/chunking.py create mode 100644 tests/waterdata_chunking_test.py diff --git a/NEWS.md b/NEWS.md index 7761e29b..7519c62f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,5 @@ +**05/17/2026:** The OGC `waterdata` getters (`get_daily`, `get_continuous`, `get_field_measurements`, and the rest of the multi-value-capable functions) now transparently chunk requests whose URLs would otherwise exceed the server's ~8 KB byte limit. + **05/16/2026:** Fixed silent truncation in the paginated `waterdata` request loops (`_walk_pages` and `get_stats_data`). Mid-pagination failures (HTTP 429, 5xx, network error) were previously swallowed — pagination would quietly stop and the function would return whatever rows it had collected, leaving callers with truncated DataFrames they had no way to detect. The loops now status-check every page like the initial request and raise `RuntimeError` on any failure, with the upstream exception chained as `__cause__` and a short menu of recovery actions (wait and retry, reduce the request, or obtain an API token) in the message. **Behavior change**: callers that previously consumed partial DataFrames on transient upstream blips will now see an exception; retry the call (possibly with a smaller `limit` or narrower query). **05/07/2026:** Bumped the declared minimum Python version from **3.8** to **3.9** (`pyproject.toml`'s `requires-python` and the ruff target). This brings the manifest in line with what was already being tested — CI's matrix has long covered only 3.9, 3.13, and 3.14, the `waterdata` test module already skipped itself on Python < 3.10, and several modules already use 3.9-only stdlib (e.g. `zoneinfo`). Users on 3.8 will no longer be able to install the package; please upgrade. @@ -36,4 +38,4 @@ **03/01/2024:** USGS data availability and format have changed on Water Quality Portal (WQP). Since March 2024, data obtained from WQP legacy profiles will not include new USGS data or recent updates to existing data. All USGS data (up to and beyond March 2024) are available using the new WQP beta services. You can access the beta services by setting `legacy=False` in the functions in the `wqp` module. -To view the status of changes in data availability and code functionality, visit: https://doi-usgs.github.io/dataRetrieval/articles/Status.html \ No newline at end of file +To view the status of changes in data availability and code functionality, visit: https://doi-usgs.github.io/dataRetrieval/articles/Status.html diff --git a/dataretrieval/waterdata/api.py b/dataretrieval/waterdata/api.py index ad268194..106501dd 100644 --- a/dataretrieval/waterdata/api.py +++ b/dataretrieval/waterdata/api.py @@ -113,7 +113,7 @@ def get_daily( data are released on the condition that neither the USGS nor the United States Government may be held liable for any damages resulting from its use. This field reflects the approval status of each record, and is either - "Approved", meaining processing review has been completed and the data is + "Approved", meaning processing review has been completed and the data is approved for publication, or "Provisional" and subject to revision. For more information about provisional data, go to: https://waterdata.usgs.gov/provisional-data-statement/. @@ -230,6 +230,21 @@ def get_daily( ... parameter_code="00060", ... last_modified="P7D", ... ) + + >>> # Chain queries: pull all stream sites in a state, then their + >>> # daily discharge for the last week. The site list can be hundreds + >>> # of values long — the request is transparently chunked across + >>> # multiple sub-requests so the URL stays under the server's byte + >>> # limit. Combined output looks like a single query. + >>> sites_df, _ = dataretrieval.waterdata.get_monitoring_locations( + ... state_name="Ohio", + ... site_type="Stream", + ... ) + >>> df, md = dataretrieval.waterdata.get_daily( + ... monitoring_location_id=sites_df["monitoring_location_id"].tolist(), + ... parameter_code="00060", + ... time="P7D", + ... ) """ service = "daily" output_id = "daily_id" @@ -259,7 +274,7 @@ def get_continuous( convert_type: bool = True, ) -> tuple[pd.DataFrame, BaseMetadata]: """ - Continuous data provide instantanous water conditions. + Continuous data provide instantaneous water conditions. This is an early version of the continuous endpoint that is feature-complete and is being made available for limited use. Geometries are not included @@ -320,7 +335,7 @@ def get_continuous( data are released on the condition that neither the USGS nor the United States Government may be held liable for any damages resulting from its use. This field reflects the approval status of each record, and is either - "Approved", meaining processing review has been completed and the data is + "Approved", meaning processing review has been completed and the data is approved for publication, or "Provisional" and subject to revision. For more information about provisional data, go to: https://waterdata.usgs.gov/provisional-data-statement/. @@ -1254,7 +1269,7 @@ def get_latest_continuous( data are released on the condition that neither the USGS nor the United States Government may be held liable for any damages resulting from its use. This field reflects the approval status of each record, and is either - "Approved", meaining processing review has been completed and the data is + "Approved", meaning processing review has been completed and the data is approved for publication, or "Provisional" and subject to revision. For more information about provisional data, go to: https://waterdata.usgs.gov/provisional-data-statement/. @@ -1451,7 +1466,7 @@ def get_latest_daily( data are released on the condition that neither the USGS nor the United States Government may be held liable for any damages resulting from its use. This field reflects the approval status of each record, and is either - "Approved", meaining processing review has been completed and the data is + "Approved", meaning processing review has been completed and the data is approved for publication, or "Provisional" and subject to revision. For more information about provisional data, go to: https://waterdata.usgs.gov/provisional-data-statement/. @@ -1633,7 +1648,7 @@ def get_field_measurements( data are released on the condition that neither the USGS nor the United States Government may be held liable for any damages resulting from its use. This field reflects the approval status of each record, and is either - "Approved", meaining processing review has been completed and the data is + "Approved", meaning processing review has been completed and the data is approved for publication, or "Provisional" and subject to revision. For more information about provisional data, go to: https://waterdata.usgs.gov/provisional-data-statement/. diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py new file mode 100644 index 00000000..c6eb9945 --- /dev/null +++ b/dataretrieval/waterdata/chunking.py @@ -0,0 +1,1246 @@ +"""Joint URL-byte chunking for the Water Data OGC getters. + +A Water Data query has several chunkable axes: every multi-value list +parameter (sites, parameter codes, …) plus the cql-text ``filter``, +which splits along its top-level OR clauses. Any of them can fan the +URL past the server's ~8 KB byte limit. ``ChunkPlan`` picks a fan-out +for each axis that minimizes total sub-requests under the URL budget; +``ChunkedCall`` iterates the joint cartesian product so every +sub-request URL fits. Requests that already fit get a trivial +single-step plan — ``ChunkedCall`` has one code path either way. + +Quota: after the first sub-request ``ChunkedCall`` reads +``x-ratelimit-remaining``; if the rest of the plan won't fit, it +raises ``RequestExceedsQuota`` before burning more budget. Set +``API_USGS_LIMIT=0`` to skip this pre-emptive check and attempt the +full plan anyway. + +Interruption: any mid-stream transient failure (429, 5xx) surfaces +as a ``ChunkInterrupted`` subclass — ``QuotaExhausted`` for 429, +``ServiceInterrupted`` for 5xx. The exception carries ``.call``, a +``ChunkedCall`` handle that owns the already-completed sub-request +state. Call ``.call.resume()`` once the underlying condition clears +to resume; only the still-pending sub-requests are re-issued. +``Retry-After`` (when the server sets it) is surfaced on the +exception as ``.retry_after``. + +Dedup: list-axis chunks don't overlap; filter-axis chunks can, so +``_combine_chunk_frames`` dedupes by feature ``id``. ``properties``, +``bbox``, date intervals, ``limit``, ``skip_geometry``, and +``filter``/``filter_lang`` themselves are never sliced as list axes +(the filter is partitioned along its top-level OR axis instead). +""" + +from __future__ import annotations + +import copy +import functools +import itertools +import math +import os +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass +from typing import Any, ClassVar +from urllib.parse import quote_plus + +import pandas as pd +import requests +from requests.structures import CaseInsensitiveDict + +from .filters import ( + _check_numeric_filter_pitfall, + _is_chunkable, + _split_top_level_or, +) + +# Empirically the API replies HTTP 414 above ~8200 bytes of full URL — +# matches nginx's default ``large_client_header_buffers`` of 8 KB. 8000 +# leaves ~200 bytes for request-line framing and proxy variance. +_WATERDATA_URL_BYTE_LIMIT = 8000 + +# Default rule: any list-shaped kwarg with >1 element is chunked across +# sub-requests — each chunk becomes a comma-joined sub-list in the URL. +# The OGC getters expose ~90 such list-shaped params (IDs, codes, +# statuses, ...), all chunkable, so it's shorter to enumerate the +# exceptions than to maintain an allowlist that grows with the API. +# Exceptions, by reason: +# - response shape: ``properties`` defines the columns; sharding +# would yield different schemas per chunk. +# - structured: ``bbox`` is a fixed 4-element coord tuple. +# - intervals: date/time ranges are not enumerable sets. +# - handled elsewhere: ``filter`` becomes its own axis in +# ``_extract_axes`` (joiner ``" OR "``); +# comma-joining CQL clauses would emit +# malformed expressions. +# - scalar by contract: ``limit``, ``skip_geometry``, ``filter_lang`` +# — a list value would be a type-erasure smuggle. +_NEVER_CHUNK = frozenset( + { + "properties", + "bbox", + "datetime", + "last_modified", + "begin", + "begin_utc", + "end", + "end_utc", + "time", + "filter", + "filter_lang", + "limit", + "skip_geometry", + } +) + +# Response header USGS uses to advertise remaining hourly quota. +_QUOTA_HEADER = "x-ratelimit-remaining" + +# Session shared across all sub-requests of a single chunked call so +# paginated-loop helpers downstream (``_walk_pages``) reuse one +# connection pool across the whole fan-out. ``None`` when not inside a +# chunked call — paginated helpers fall back to their own short-lived +# session in that case. +_chunked_session: ContextVar[requests.Session | None] = ContextVar( + "_chunked_session", default=None +) + + +@contextmanager +def _publish_session(session: requests.Session) -> Iterator[None]: + """ + Make ``session`` visible to :func:`get_active_session` for the + duration of the ``with`` block via the ``_chunked_session`` + ContextVar. Wraps the set/reset token dance so callers don't have to. + """ + token = _chunked_session.set(session) + try: + yield + finally: + _chunked_session.reset(token) + + +def get_active_session() -> requests.Session | None: + """ + Return the chunker's currently-published session, or ``None``. + + Public accessor for the ``_chunked_session`` ContextVar so + sibling modules (notably :func:`dataretrieval.waterdata.utils._session`) + don't have to reach into the private ContextVar directly. + + Returns + ------- + requests.Session or None + The session published by :func:`_publish_session` if currently + inside a :class:`ChunkedCall` ``resume`` block; ``None`` otherwise. + """ + return _chunked_session.get() + + +# Separators the two axis kinds use to join their atoms back into +# URL text. List axes comma-join values (``site=USGS-A,USGS-B``); the +# filter axis OR-joins clauses (``filter=a='1' OR a='2'``). +_LIST_SEP = "," +_OR_SEP = " OR " + +_FetchOnce = Callable[[dict[str, Any]], tuple[pd.DataFrame, requests.Response]] + + +class _RetryableTransportError(RuntimeError): + """ + Base for typed HTTP transport failures the chunker recognizes as + transient. + + Raised by :func:`dataretrieval.waterdata.utils._raise_for_non_200` + and walked by :func:`_classify_chunk_error`. One subclass per + recoverable HTTP status family (429 → :class:`RateLimited`, + 5xx → :class:`ServiceUnavailable`); ``ChunkedCall`` wraps them as + resumable :class:`ChunkInterrupted` subclasses. + + Parameters + ---------- + message : str + Human-readable error message. + retry_after : float, optional + Seconds to wait before retrying, parsed from the + ``Retry-After`` response header. + + Attributes + ---------- + retry_after : float or None + Seconds to wait before retrying, parsed from the + ``Retry-After`` response header. ``None`` when the header was + absent or unparseable. + """ + + def __init__(self, message: str, *, retry_after: float | None = None) -> None: + super().__init__(message) + self.retry_after = retry_after + + +class RateLimited(_RetryableTransportError): + """ + A USGS Water Data API request was rejected with HTTP 429. + + Exposed as a typed exception so callers (notably the multi-value + chunker) can detect rate-limit failures via ``isinstance`` instead + of string-matching error messages. + """ + + +class ServiceUnavailable(_RetryableTransportError): + """ + A USGS Water Data API request was rejected with HTTP 5xx. + + Surfaced as a typed exception (parallel to :class:`RateLimited`) + so ``ChunkedCall`` can treat transient server failures as + resumable interruptions rather than fatal programmer errors. + """ + + +class RequestTooLarge(ValueError): + """ + No chunking plan fits the URL byte limit. + + Raised when even the smallest reducible plan (every list axis at + singleton chunks and the filter at one clause per sub-request) + still exceeds the server's byte limit. Shrink the input lists, + simplify the filter, or split the call manually. + """ + + +class RequestExceedsQuota(ValueError): + """ + Remaining rate-limit window can't cover the rest of the chunked plan. + + Raised after a sub-request when ``x-ratelimit-remaining`` in the + response shows the rest of the plan can't fit in the current per-key + rate-limit window. The chunks completed so far have already been + issued and consumed quota; ``ChunkedCall`` stops here rather than + burn more quota on a call that will fail mid-way. The completed + work is preserved on ``.call`` (the originating ``ChunkedCall``) + so callers can recover its ``partial_frame`` / ``partial_response`` + and, once the rate-limit window resets, call ``.call.resume()`` + to continue. + + Attributes + ---------- + planned_chunks : int + Total sub-requests the joint plan would issue. + available : int + Sub-requests this caller can still issue in the current window + (``x-ratelimit-remaining`` + chunks already completed). + deficit : int + ``planned_chunks - available`` — how far over budget the call + would run if it continued. + call : ChunkedCall or None + The originating call handle. ``None`` on hand-constructed + exceptions (test fixtures); otherwise the live handle whose + ``partial_frame`` / ``partial_response`` expose the work + completed before the check fired and whose ``resume()`` can be + called once the rate-limit window rolls over. + """ + + def __init__( + self, + *, + planned_chunks: int, + available: int, + deficit: int, + call: ChunkedCall | None = None, + ) -> None: + super().__init__( + f"Request would issue {planned_chunks} sub-requests but only " + f"{available} fit in the current rate-limit window (short by " + f"{deficit}). Wait for the window to reset, request a higher " + f"per-key quota, narrow the query, or set " + f"API_USGS_LIMIT=0 to bypass this check and risk a " + f"mid-stream 429 (recoverable via QuotaExhausted.resume())." + ) + self.planned_chunks = planned_chunks + self.available = available + self.deficit = deficit + self.call = call + + +class ChunkInterrupted(RuntimeError): + """ + Base class for mid-stream chunk failures whose completed work is + preserved and resumable. + + A ``ChunkInterrupted`` subclass means: a sub-request failed, but + ``ChunkedCall`` still owns whatever completed successfully before + the failure. Call ``self.call.resume()`` to pick up where the + failure stopped you — only still-pending sub-requests are + re-issued. + + Subclasses describe *why* ``ChunkedCall`` stopped so callers can + pick a retry policy: :class:`QuotaExhausted` for 429 (wait for the + rate-limit window), :class:`ServiceInterrupted` for 5xx (wait for + the upstream to recover). The ``.call`` handle is the same object + across every interruption of a single chunked call — frames + accumulate across retries. + + Attributes + ---------- + call : ChunkedCall or None + Resumable handle into the ``ChunkedCall`` that raised this + exception. ``None`` only on hand-constructed exceptions (test + fixtures), where ``.call``-derived accessors degrade to + empty/``None``. + retry_after : float or None + Seconds the server suggested waiting (``Retry-After`` header). + ``None`` when the server gave no hint. + completed_chunks : int + Number of sub-requests successfully completed before the failure. + total_chunks : int + Total sub-requests in the plan. + partial_frame : pandas.DataFrame + Combined frame of work completed by the moment this exception + was raised. Snapshot at raise time — does NOT advance on a + later ``call.resume()`` (use ``exc.call.partial_frame`` for + the live view). + partial_response : requests.Response or None + Aggregated response covering the completed sub-requests at + raise time; ``None`` if nothing had completed yet. Same + snapshot semantics as ``partial_frame``. + + Examples + -------- + Retry on any transient interruption, honoring the server's + ``Retry-After`` hint when present and falling back to a fixed wait + otherwise. Each new interruption keeps the already-completed work + intact — only the still-pending sub-requests are re-issued. + + .. code-block:: python + + import time + from dataretrieval.waterdata import get_daily + from dataretrieval.waterdata.chunking import ChunkInterrupted + + try: + df, md = get_daily(monitoring_location_id=long_list_of_sites) + except ChunkInterrupted as exc: + while True: + time.sleep(exc.retry_after or 5 * 60) + try: + df, md = exc.call.resume() + break + except ChunkInterrupted as next_exc: + exc = next_exc + """ + + # Subclasses override with a ``str.format`` template; the format + # call sees ``completed_chunks`` and ``total_chunks`` as kwargs. + _MESSAGE_TEMPLATE: ClassVar[str] = ( + "Chunked request interrupted after {completed_chunks}/" + "{total_chunks} sub-requests; call .call.resume() to continue." + ) + + def __init__( + self, + *, + completed_chunks: int, + total_chunks: int, + call: ChunkedCall | None = None, + retry_after: float | None = None, + ) -> None: + super().__init__( + self._MESSAGE_TEMPLATE.format( + completed_chunks=completed_chunks, total_chunks=total_chunks + ) + ) + self.completed_chunks = completed_chunks + self.total_chunks = total_chunks + self.call = call + self.retry_after = retry_after + # Snapshot partial state at raise time so the exception's view + # stays stable across later ``call.resume()`` advances; the + # live view lives on ``call.partial_frame``/``.partial_response``. + # ``partial_frame`` gets a defensive ``.copy()`` because + # ``_combine_chunk_frames`` may return a chunk frame verbatim + # in the single-completed-chunk fast path; ``partial_response`` + # already comes via ``copy.copy`` from ``_combine_chunk_responses``. + if call is None: + self.partial_frame: pd.DataFrame = pd.DataFrame() + self.partial_response: requests.Response | None = None + else: + self.partial_frame = call.partial_frame.copy() + self.partial_response = call.partial_response + + +class QuotaExhausted(ChunkInterrupted): + """ + A sub-request returned HTTP 429 — the per-key rate-limit window + is exhausted. Subclass of :class:`ChunkInterrupted`. + + For a chunked call (``total_chunks > 1``) reached past chunk 0, + the post-first-chunk :class:`RequestExceedsQuota` check normally + short-circuits before burning quota on a plan that won't fit; + arrival here typically means a concurrent caller drained the + window faster than predicted. ``partial_frame`` holds what + completed first. + + For a single-shot call (``total_chunks == 1``) or a 429 on the + very first chunk, ``partial_frame`` is empty and + ``partial_response`` is ``None``; the original ``RateLimited`` is + on ``__cause__``. + """ + + _MESSAGE_TEMPLATE = ( + "HTTP 429 after {completed_chunks}/{total_chunks} sub-requests; " + "catch QuotaExhausted (or ChunkInterrupted) to access " + ".partial_frame or .call.resume() once the rate-limit " + "window has rolled over." + ) + + +class ServiceInterrupted(ChunkInterrupted): + """ + A sub-request returned HTTP 5xx — the upstream service failed + transiently. Subclass of :class:`ChunkInterrupted`. + + The completed sub-requests are preserved on ``.call``; once the + upstream recovers, ``.call.resume()`` resumes only the + still-pending work. + """ + + _MESSAGE_TEMPLATE = ( + "Service error after {completed_chunks}/{total_chunks} " + "sub-requests; catch ServiceInterrupted (or ChunkInterrupted) " + "and call .call.resume() once the upstream service recovers." + ) + + +def _request_bytes(req: requests.PreparedRequest) -> int: + """ + Total bytes of a prepared request: URL + body. + + GET routes have ``body=None`` and reduce to URL length. POST routes + (CQL2 JSON body) need body bytes — the URL stays short regardless + of payload, so URL-only sizing would underestimate the request and + skip chunking when it's needed. + + Parameters + ---------- + req : requests.PreparedRequest + The prepared request to size. + + Returns + ------- + int + ``len(req.url) + len(req.body)`` where ``req.body`` is treated + as 0 bytes when ``None`` and UTF-8 encoded when ``str``. + + Raises + ------ + TypeError + If ``req.body`` is not ``None``, ``bytes``/``bytearray``, or + ``str``. Size-based planning needs a deterministic byte count, + so generators and file-like streams are rejected up front + rather than silently treated as zero bytes. + """ + body = req.body + if body is None: + body_len = 0 + elif isinstance(body, (bytes, bytearray)): + body_len = len(body) + elif isinstance(body, str): + body_len = len(body.encode("utf-8")) + else: + raise TypeError( + f"multi_value_chunked cannot size a request body of type " + f"{type(body).__name__!r}; pass str, bytes, or None." + ) + return len(req.url) + body_len + + +@dataclass(frozen=True) +class _Axis: + """ + A single chunkable axis of one user-level request — a list of + atomic units and the separator that joins them in the URL. + + Both multi-value list parameters (``sites=[...]``, joiner ``","``) + and the cql-text ``filter`` (split on top-level ``OR``, joiner + ``" OR "``) fit this shape, so a single greedy halving loop in + ``ChunkPlan._plan`` handles both — no need for two separate + algorithms. + + Attributes + ---------- + arg_key : str + The args-dict key this axis substitutes back into when a + sub-request is rendered. + atoms : tuple of str + The smallest indivisible units along this axis (one site, one + OR-clause, …). A "chunk" is a contiguous slice of ``atoms``. + joiner : str + Separator placed between atoms when they are joined back into + URL text — ``","`` for list axes, ``" OR "`` for the filter + axis. + """ + + arg_key: str + atoms: tuple[str, ...] + joiner: str + + def chunk_bytes(self, chunk: list[str]) -> int: + """ + URL-encoded bytes a chunk contributes when substituted. + + ``quote_plus`` is faithful to what the real URL builder + produces, so values containing characters that expand under URL + encoding (``%``, ``+``, ``/``, ``&``, …) can't be mis-ranked. + + Parameters + ---------- + chunk : list of str + A contiguous slice of ``self.atoms``. + + Returns + ------- + int + Length of ``quote_plus(self.joiner.join(chunk))``. + """ + return len(quote_plus(self.joiner.join(map(str, chunk)))) + + def render(self, chunk: list[str]) -> Any: + """ + Convert a chunk into the form the URL builder expects. + + List axes yield a fresh list of atoms (``build_request`` will + comma-join); the filter axis yields a pre-joined string (CQL + doesn't take a list). + + Parameters + ---------- + chunk : list of str + A contiguous slice of ``self.atoms``. + + Returns + ------- + list of str or str + ``list(chunk)`` for list axes, ``self.joiner.join(chunk)`` + for the filter axis. + """ + return list(chunk) if self.joiner == _LIST_SEP else self.joiner.join(chunk) + + +def _extract_axes(args: dict[str, Any]) -> list[_Axis]: + """ + Build the chunkable-axis set from a request's args. + + Multi-value list params with more than one element each become an + axis. The cql-text filter (when chunkable and split into more than + one top-level OR-clause) becomes one too. Anything in + ``_NEVER_CHUNK`` is excluded except ``filter`` itself, which is + handled separately so its atoms are clauses not characters. + + Parameters + ---------- + args : dict[str, Any] + The user-level request kwargs (the same dict that would be + passed to ``build_request``). + + Returns + ------- + list[_Axis] + Zero or more axes in insertion order: list axes first (one + per eligible kwarg, in ``args`` order), then the filter axis + if present. + """ + axes: list[_Axis] = [] + for key, value in args.items(): + if key in _NEVER_CHUNK: + continue + if isinstance(value, (list, tuple)) and len(value) > 1: + axes.append(_Axis(arg_key=key, atoms=tuple(value), joiner=_LIST_SEP)) + + filter_expr = args.get("filter") + if _is_chunkable(filter_expr, args.get("filter_lang")): + _check_numeric_filter_pitfall(filter_expr) + clauses = _split_top_level_or(filter_expr) + if len(clauses) >= 2: + axes.append(_Axis(arg_key="filter", atoms=tuple(clauses), joiner=_OR_SEP)) + return axes + + +class ChunkPlan: + """ + Strategy for issuing one user-level request as a sequence of + sub-requests whose URLs each fit ``url_limit``. + + Constructing a plan *is* planning: + ``ChunkPlan(args, build_request, url_limit)`` extracts the + chunkable axes, runs greedy halving on the biggest chunk across + all axes, and stores the result. + + Passthrough requests (no chunkable axes, or already fitting) are + represented as a trivial plan with empty ``axes`` / ``chunks`` and + ``total == 1``; :meth:`iter_sub_args` yields the original args + unchanged so the ``ChunkedCall`` loop is the same shape either + way. + + Parameters + ---------- + args : dict[str, Any] + The user-level request kwargs. + build_request : Callable[..., requests.PreparedRequest] + Factory that turns a kwargs dict into a sized prepared + request, e.g. ``_construct_api_requests``. + url_limit : int + Byte budget for the prepared request (URL + body). + + Attributes + ---------- + args : dict + The original user-level args this plan was built for. Bound to + the plan so :meth:`iter_sub_args` is self-contained. + axes : list[_Axis] + The chunkable axes of ``args``: each multi-value list + parameter, plus the cql-text filter (if any) split on top-level + OR. Empty in the passthrough case. + chunks : dict[str, list[list[str]]] + Per-axis partition: ``chunks[axis.arg_key]`` is the list of + atom-sublists this axis is split into. Empty in passthrough. + canonical_url : str or None + URL of the full original request, used to overwrite the first + chunk's ``response.url`` so ``BaseMetadata`` reflects the + user's full query. ``None`` on the nothing-to-chunk passthrough + path — ``fetch_once``'s response already carries the canonical + URL there, so ``ChunkedCall`` skips the override to avoid an + extra ``build_request`` call on the hot path. + + Raises + ------ + RequestTooLarge + If the request needs chunking but even the singleton plan + doesn't fit ``url_limit``. + """ + + def __init__( + self, + args: dict[str, Any], + build_request: Callable[..., requests.PreparedRequest], + url_limit: int, + ) -> None: + self.args = args + self.axes: list[_Axis] = [] + self.chunks: dict[str, list[list[str]]] = {} + self.canonical_url: str | None = None + + axes = _extract_axes(args) + # No chunkable axes → skip ``build_request`` entirely; the + # common Water Data call shape shouldn't pay for an unused + # request prep on the passthrough hot path. + if not axes: + return + + initial_request = build_request(**args) + self.canonical_url = initial_request.url + if _request_bytes(initial_request) <= url_limit: + return + + self.axes = axes + self.chunks = {axis.arg_key: [list(axis.atoms)] for axis in axes} + self._plan(build_request, url_limit) + + def _plan( + self, + build_request: Callable[..., requests.PreparedRequest], + url_limit: int, + ) -> None: + """ + Greedy-halve the biggest chunk across all axes until the + worst-case sub-request URL fits ``url_limit``. Mutates + ``self.chunks`` in place; treats list axes and the filter axis + uniformly — each is just a list of atoms joined by its axis's + separator. + + Raises + ------ + RequestTooLarge + If even the singleton plan (every axis at one atom per + chunk) still exceeds ``url_limit``. + """ + while True: + worst = self._worst_case_args() + if _request_bytes(build_request(**worst)) <= url_limit: + return + + biggest_axis: _Axis | None = None + biggest_idx = -1 + biggest_size = -1 + for axis in self.axes: + for idx, chunk in enumerate(self.chunks[axis.arg_key]): + if len(chunk) <= 1: + continue + size = axis.chunk_bytes(chunk) + if size > biggest_size: + biggest_axis, biggest_idx, biggest_size = axis, idx, size + + if biggest_axis is None: + raise RequestTooLarge( + f"Request exceeds {url_limit} bytes (URL + body) at the " + f"smallest reducible plan (every axis at one atom per " + f"sub-request). Reduce input sizes, shorten or simplify " + f"the filter, or split the call manually." + ) + axis_chunks = self.chunks[biggest_axis.arg_key] + chunk = axis_chunks[biggest_idx] + mid = len(chunk) // 2 + axis_chunks[biggest_idx : biggest_idx + 1] = [chunk[:mid], chunk[mid:]] + + def _worst_case_args(self) -> dict[str, Any]: + """ + Args dict representing the largest sub-request the current + ``self.chunks`` partition will issue — each axis's longest + (by URL-encoded bytes) chunk rendered back in. + """ + out = dict(self.args) + for axis in self.axes: + worst = max(self.chunks[axis.arg_key], key=axis.chunk_bytes) + out[axis.arg_key] = axis.render(worst) + return out + + @property + def total(self) -> int: + """ + Total sub-request count: product of per-axis chunk counts. + + Returns + ------- + int + ``1`` for the passthrough plan, otherwise the cartesian + product of ``len(chunks[ax.arg_key])`` across all axes. + """ + return math.prod((len(self.chunks[ax.arg_key]) for ax in self.axes), start=1) + + def iter_sub_args(self) -> Iterator[dict[str, Any]]: + """ + Yield substituted args for each sub-request, in deterministic + order — cartesian product over axes in extraction order. + + The same plan yields the same sub-args sequence on every + invocation, so resume is well-defined. + + Yields + ------ + dict[str, Any] + A copy of ``self.args`` with each axis's current chunk + substituted under its ``arg_key``. + """ + if not self.axes: + yield dict(self.args) + return + chunk_lists = [self.chunks[ax.arg_key] for ax in self.axes] + for combo in itertools.product(*chunk_lists): + sub_args = dict(self.args) + for axis, chunk in zip(self.axes, combo): + sub_args[axis.arg_key] = axis.render(chunk) + yield sub_args + + def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, requests.Response]: + """ + Run the plan and return the combined ``(frame, response)``. + + Thin wrapper around ``ChunkedCall(self, fetch_once).resume()``; + see :class:`ChunkedCall` for the per-sub-request semantics. + + Parameters + ---------- + fetch_once : Callable + Function that issues a single sub-request, given the + substituted args dict, and returns ``(frame, response)``. + + Returns + ------- + df : pandas.DataFrame + Combined data from every successful sub-request. + response : requests.Response + Aggregated response (canonical URL, last page's headers, + cumulative elapsed time). + + Raises + ------ + ChunkInterrupted + On a mid-stream transient failure + (:class:`QuotaExhausted` for 429, + :class:`ServiceInterrupted` for 5xx). The resumable handle + is on ``exc.call``. + RequestExceedsQuota + When the rate-limit window can't cover the remaining plan. + """ + return ChunkedCall(self, fetch_once).resume() + + +def _quota_check_disabled() -> bool: + """ + Check whether the pre-emptive quota check is disabled. + + Read at call time (not import time) so test patches via + ``monkeypatch.setenv`` take effect. + + Returns + ------- + bool + ``True`` when the environment variable ``API_USGS_LIMIT`` is + set to ``"0"`` (stripped), bypassing the post-first-chunk + :class:`RequestExceedsQuota` check. + """ + return os.environ.get("API_USGS_LIMIT", "").strip() == "0" + + +def _read_remaining(response: requests.Response) -> int | None: + """ + Parse the ``x-ratelimit-remaining`` header from a response. + + Parameters + ---------- + response : requests.Response + A response that may or may not carry the quota header. + + Returns + ------- + int or None + The parsed integer, or ``None`` when the header is missing or + unparseable. ``ChunkedCall`` treats ``None`` as "no quota + signal" and skips the post-first-chunk plan check. + """ + raw = response.headers.get(_QUOTA_HEADER) + if raw is None: + return None + try: + return int(raw) + except (TypeError, ValueError): + return None + + +def _classify_chunk_error( + exc: BaseException, +) -> tuple[type[ChunkInterrupted], float | None] | None: + """ + Classify a fetch error as a known transient (resumable) failure. + + Walks the ``__cause__`` chain of ``exc`` looking for a known typed + transport failure. Returns the matching ``ChunkInterrupted`` + subclass and any ``Retry-After`` hint, or ``None`` if the error is + not a recognized transient — in which case ``ChunkedCall`` + re-raises rather than wrapping (programmer errors and unknown + failures shouldn't masquerade as resumable). + + Parameters + ---------- + exc : BaseException + The exception raised by a sub-request. + + Returns + ------- + tuple[type[ChunkInterrupted], float or None] or None + ``(interrupted_class, retry_after)`` for recognized transient + failures; ``None`` otherwise. + + Notes + ----- + ``_walk_pages`` re-wraps mid-pagination failures as + ``RuntimeError`` with the typed transport exception linked as + ``__cause__``, so this function must walk the chain rather than + just ``isinstance`` the top-level exception. + + Bare ``requests.exceptions.RequestException`` (ConnectionError, + Timeout, SSLError, …) is also treated as a transient transport + failure and wrapped as :class:`ServiceInterrupted` — these don't + inherit from ``RuntimeError`` and would otherwise escape the + chunker's catch with no resumable handle. + """ + cur: BaseException | None = exc + while cur is not None: + if isinstance(cur, RateLimited): + return QuotaExhausted, cur.retry_after + if isinstance(cur, ServiceUnavailable): + return ServiceInterrupted, cur.retry_after + if isinstance(cur, requests.exceptions.RequestException): + return ServiceInterrupted, None + cur = cur.__cause__ + return None + + +def _combine_chunk_frames(frames: list[pd.DataFrame]) -> pd.DataFrame: + """ + Concatenate per-chunk frames, dropping empties and deduping by ``id``. + + Parameters + ---------- + frames : list[pandas.DataFrame] + One frame per completed sub-request. + + Returns + ------- + pandas.DataFrame + The concatenated, deduplicated result. Empty when every input + frame is empty. + + Notes + ----- + ``_get_resp_data`` returns a plain ``pd.DataFrame()`` on empty + responses; concatenating it with real ``GeoDataFrame``s downgrades + the result to plain ``DataFrame`` and strips geometry/CRS, so + empties are dropped first. Dedup on the pre-rename feature ``id`` + keeps overlapping user OR-clauses from producing duplicate rows + across chunks. + + Dedup is restricted to rows whose ``id`` is non-null. ``pandas`` + treats NaN==NaN as a duplicate for ``drop_duplicates``, so a + blanket call would collapse every id-less row into a single one — + silent data loss if any chunk emits features without an + ``id`` field. + """ + non_empty = [f for f in frames if not f.empty] + if not non_empty: + # Preserve the frame type (GeoDataFrame vs DataFrame) of the + # input even when every chunk is empty — ``_get_resp_data`` + # returns ``gpd.GeoDataFrame()`` on empty geopd responses, and + # returning a plain ``pd.DataFrame()`` here would downgrade + # the type a downstream ``pd.concat([result, geo_page])`` to a + # plain DataFrame and strip geometry/CRS. + return frames[0] if frames else pd.DataFrame() + if len(non_empty) == 1: + # Single-completed-chunk fast path. Return a copy so callers + # who treat ``ChunkedCall.partial_frame`` as a fresh result + # (the property docstring says "live; recomputed per access") + # don't accidentally mutate ``_chunks[0][0]`` in place. + return non_empty[0].copy() + combined = pd.concat(non_empty, ignore_index=True) + if "id" in combined.columns: + has_id = combined["id"].notna() + if has_id.all(): + combined = combined.drop_duplicates(subset="id", ignore_index=True) + elif has_id.any(): + # Mixed: dedupe only the id-bearing rows; preserve id-less + # rows verbatim (their order relative to id-bearing rows + # may shift, which is acceptable — dedup can't be id-keyed + # for rows without an id). + id_rows = combined[has_id].drop_duplicates(subset="id") + no_id_rows = combined[~has_id] + combined = pd.concat([id_rows, no_id_rows], ignore_index=True) + return combined + + +def _combine_chunk_responses( + responses: list[requests.Response], canonical_url: str | None +) -> requests.Response: + """ + Fold per-sub-request responses into a single aggregated response. + + Returns a shallow copy of ``responses[0]`` with ``.headers`` set to + the last response's (so ``x-ratelimit-remaining`` reflects current + state), ``.elapsed`` set to total wall-clock across every response, + and ``.url`` set to the canonical original-query URL so + ``BaseMetadata`` reflects the user's full request rather than the + first chunk. + + Parameters + ---------- + responses : list[requests.Response] + One response per completed sub-request, in execution order. + canonical_url : str or None + URL of the unchunked original request. ``None`` skips the URL + override — used by the trivial-passthrough path where + ``fetch_once`` already returns a response whose ``.url`` is + the original-query URL. + + Returns + ------- + requests.Response + A shallow copy of the first response with aggregated + ``headers``, ``elapsed``, and ``url``. The function is + idempotent (the input responses' ``headers`` / ``elapsed`` / + ``url`` are never mutated), so it's safe to call repeatedly + via :attr:`ChunkedCall.partial_response` during error + inspection or resume retries. ``headers`` on the returned + object is a fresh ``CaseInsensitiveDict``, so mutations there + don't back-propagate into any chunk's underlying response. + Note that other ``Response`` fields (``_content``, ``raw``, + ``cookies``, ``request``) are still aliased to the first + chunk by the shallow copy — callers that mutate those will + affect the underlying chunk response. + """ + # ``copy.copy`` lets repeated calls re-sum elapsed from scratch + # rather than re-mutating ``responses[0]`` in place. The headers + # dict is then rewrapped in a fresh ``CaseInsensitiveDict`` so the + # aggregate's headers don't share identity with — or leak mutations + # back into — any underlying response on ``ChunkedCall._chunks``. + head = copy.copy(responses[0]) + if len(responses) > 1: + head.headers = CaseInsensitiveDict(responses[-1].headers) + head.elapsed = sum( + (r.elapsed for r in responses[1:]), start=responses[0].elapsed + ) + else: + head.headers = CaseInsensitiveDict(responses[0].headers) + if canonical_url is not None: + head.url = canonical_url + return head + + +class ChunkedCall: + """ + Stateful handle for a chunked call. + + Holds the in-flight state (per-sub-request frames and responses) + and exposes a single :meth:`resume` entry point that drives the + call from wherever it is to completion — used both for the first + invocation (from :meth:`ChunkPlan.execute`) and for subsequent + retries after a :class:`ChunkInterrupted`. + + A ``ChunkedCall`` is created internally when a :class:`ChunkPlan` + executes; callers reach it via :attr:`ChunkInterrupted.call` on + the exception raised by a mid-stream failure. + + :meth:`resume` is idempotent: it skips sub-requests already + completed (``self.completed_chunks`` is the cursor) and re-issues + only the still-pending ones. The sub-request + ordering matches :meth:`ChunkPlan.iter_sub_args`, which is + deterministic, so each call picks up exactly where the previous + one stopped. + + Parameters + ---------- + plan : ChunkPlan + The chunking plan to execute. + fetch_once : Callable + Function that issues a single sub-request, given the + substituted args dict, and returns ``(frame, response)``. + + Attributes + ---------- + plan : ChunkPlan + The plan being driven (read-only after construction). + fetch_once : Callable + The per-sub-request fetch function. + completed_chunks : int + Number of sub-requests successfully completed so far. + total_chunks : int + Total sub-requests in ``plan`` (``== plan.total``). + partial_frame : pandas.DataFrame + Combined frame of completed sub-requests (live; recomputed per + access). + partial_response : requests.Response or None + Aggregated response with canonical URL restored, or ``None`` + when nothing has completed yet (live; recomputed per access). + """ + + def __init__(self, plan: ChunkPlan, fetch_once: _FetchOnce) -> None: + self.plan = plan + self.fetch_once = fetch_once + # One entry per completed sub-request, in execution order. + # A single list keeps the (frame, response) pair atomic so the + # ``len(_chunks)`` cursor can't ever drift between two parallel + # lists. + self._chunks: list[tuple[pd.DataFrame, requests.Response]] = [] + + @property + def completed_chunks(self) -> int: + return len(self._chunks) + + @property + def total_chunks(self) -> int: + return self.plan.total + + @property + def partial_frame(self) -> pd.DataFrame: + """ + Concatenated, deduplicated frame of sub-requests that have + completed so far. + + Live — recomputed on each access so it reflects current state + across resume attempts. + + Returns + ------- + pandas.DataFrame + Combined frame of completed sub-requests, or an empty + ``DataFrame`` when nothing has completed. + """ + if not self._chunks: + return pd.DataFrame() + return _combine_chunk_frames([frame for frame, _ in self._chunks]) + + @property + def partial_response(self) -> requests.Response | None: + """ + Aggregated response with the canonical URL restored to the + user's full original query. + + Live — recomputed on each access. + + Returns + ------- + requests.Response or None + Aggregated response when at least one sub-request has + completed, ``None`` otherwise. + """ + if not self._chunks: + return None + return _combine_chunk_responses( + [resp for _, resp in self._chunks], self.plan.canonical_url + ) + + def resume(self) -> tuple[pd.DataFrame, requests.Response]: + """ + Drive the chunked call to completion. + + Opens one ``requests.Session`` for the run and publishes it on + the ``_chunked_session`` ``ContextVar`` so paginated-loop + helpers downstream (``_walk_pages``) reuse the same connection + pool across every sub-request instead of handshaking fresh on + each. The session is closed when ``resume`` returns or raises; + a follow-up ``resume`` call (after a ``ChunkInterrupted``) + opens a new one. + + Idempotent: starts from chunk 0 on the first call, then from + the cursor (``self.completed_chunks``) on every subsequent + call. Re-issues only sub-requests that haven't already + completed. + + Returns + ------- + df : pandas.DataFrame + Combined data from every successful sub-request. + response : requests.Response + Aggregated response (canonical URL, last page's headers, + cumulative elapsed time). + + Raises + ------ + ChunkInterrupted + On a mid-stream transient failure + (:class:`QuotaExhausted` for 429, + :class:`ServiceInterrupted` for 5xx). The resumable handle + is on ``exc.call`` — wait for the underlying condition to + clear and call ``exc.call.resume()`` again. + RequestExceedsQuota + When the rate-limit window can't cover the remaining plan + (checked after the first sub-request). + """ + with requests.Session() as session, _publish_session(session): + completed = len(self._chunks) + for i, sub_args in enumerate(self.plan.iter_sub_args()): + if i < completed: + continue + self._issue(sub_args) + frames = [frame for frame, _ in self._chunks] + responses = [resp for _, resp in self._chunks] + return ( + _combine_chunk_frames(frames), + _combine_chunk_responses(responses, self.plan.canonical_url), + ) + + def _issue(self, sub_args: dict[str, Any]) -> None: + # Catch both ``RuntimeError`` (the layer's typed contract: + # ``RateLimited`` / ``ServiceUnavailable`` / mid-pagination + # wrapper) and ``requests.exceptions.RequestException`` + # (transport-level failures like ConnectionError / Timeout / + # SSLError that bubble up unmodified from + # ``sess.send(initial_req)`` and don't inherit from + # RuntimeError). Both routes go through ``_classify_chunk_error`` + # so transient failures become resumable ``ChunkInterrupted`` + # subclasses; unknown failures re-raise to preserve their type. + try: + chunk = self.fetch_once(sub_args) + except (RuntimeError, requests.exceptions.RequestException) as exc: + classification = _classify_chunk_error(exc) + if classification is None: + raise + interrupted_class, retry_after = classification + raise interrupted_class( + completed_chunks=len(self._chunks), + total_chunks=self.plan.total, + call=self, + retry_after=retry_after, + ) from exc + self._chunks.append(chunk) + if len(self._chunks) < self.plan.total: + self._check_quota_remaining() + + def _check_quota_remaining(self) -> None: + if _quota_check_disabled(): + return + _, last_response = self._chunks[-1] + remaining = _read_remaining(last_response) + completed = len(self._chunks) + pending = self.plan.total - completed + if remaining is None or remaining >= pending: + return + raise RequestExceedsQuota( + planned_chunks=self.plan.total, + available=remaining + completed, + deficit=pending - remaining, + call=self, + ) + + +def multi_value_chunked( + *, + build_request: Callable[..., requests.PreparedRequest], + url_limit: int | None = None, +) -> Callable[[_FetchOnce], _FetchOnce]: + """ + Decorate a fetch function to transparently chunk over-budget requests. + + Splits multi-value list params and cql-text filters across + sub-requests so each fits the URL byte limit. Builds a + :class:`ChunkPlan` and runs it: passthrough requests are a trivial + single-step plan, so the decorated function has one code path + either way. + + Parameters + ---------- + build_request : Callable[..., requests.PreparedRequest] + Factory that turns a kwargs dict into a sized prepared + request, e.g. ``_construct_api_requests``. Called during + planning to measure each candidate plan. + url_limit : int, optional + Byte budget for the prepared request (URL + body). When + ``None`` (default), the module-level + ``_WATERDATA_URL_BYTE_LIMIT`` is resolved at call time so test + patches via ``monkeypatch.setattr`` take effect. + + Returns + ------- + Callable + A decorator that wraps a ``fetch_once(args) -> (df, response)`` + callable into one that accepts the same shape but executes the + underlying plan transparently. + + Raises + ------ + RequestTooLarge + If no plan can fit ``url_limit``. + RequestExceedsQuota + After the first sub-request, if the remaining plan can't fit + the current rate-limit window. + ChunkInterrupted + On a mid-execution 429 (:class:`QuotaExhausted`) or 5xx + (:class:`ServiceInterrupted`). See :class:`ChunkedCall` for + the resume semantics. + + See Also + -------- + ChunkPlan : Planning shape (axes, partitioning, passthrough). + ChunkedCall : Per-sub-request execution and resume semantics. + """ + + def decorator(fetch_once: _FetchOnce) -> _FetchOnce: + @functools.wraps(fetch_once) + def wrapper( + args: dict[str, Any], + ) -> tuple[pd.DataFrame, requests.Response]: + limit = _WATERDATA_URL_BYTE_LIMIT if url_limit is None else url_limit + return ChunkPlan(args, build_request, limit).execute(fetch_once) + + return wrapper + + return decorator diff --git a/dataretrieval/waterdata/filters.py b/dataretrieval/waterdata/filters.py index 4c136b82..5e1c0a67 100644 --- a/dataretrieval/waterdata/filters.py +++ b/dataretrieval/waterdata/filters.py @@ -1,47 +1,27 @@ """CQL ``filter`` support for the Water Data OGC getters. -Two names are public to the rest of the package: +Public: - ``FILTER_LANG``: the type alias used for the ``filter_lang`` kwarg. -- ``chunked``: the decorator ``utils.py`` applies to its single-request - fetch function. It runs the lexicographic-comparison pitfall guard, - splits long cql-text filters at top-level ``OR`` so each sub-request - fits under the server's URL byte limit, and concatenates the results. -Other CQL shapes (``AND``, ``NOT``, ``LIKE``, spatial/temporal predicates, -function calls) are forwarded verbatim — only top-level ``OR`` chunks -losslessly into independent sub-queries whose result sets can be union'd. +Internal helpers used by ``chunking.multi_value_chunked``'s joint +planner: ``_split_top_level_or`` (clause partitioning), +``_is_chunkable`` (filter-language gate), and +``_check_numeric_filter_pitfall`` (the lexicographic-comparison guard). + +Other CQL shapes (``AND``, ``NOT``, ``LIKE``, spatial/temporal +predicates, function calls) are forwarded verbatim — only top-level +``OR`` chunks losslessly into independent sub-queries whose result sets +can be union'd. """ from __future__ import annotations -import functools import re -from collections.abc import Callable -from typing import Any, Literal, TypeVar -from urllib.parse import quote_plus - -import pandas as pd -import requests +from typing import Any, Literal FILTER_LANG = Literal["cql-text", "cql-json"] -# Conservative fallback budget when ``_chunk_cql_or`` is called without -# an explicit ``max_len``. The ``chunked`` decorator computes a tighter -# per-request budget from ``_WATERDATA_URL_BYTE_LIMIT``. -_CQL_FILTER_CHUNK_LEN = 5000 - -# Empirically the API replies HTTP 414 above ~8200 bytes of full URL — -# matches nginx's default ``large_client_header_buffers`` of 8 KB. 8000 -# leaves ~200 bytes for request-line framing and proxy variance. -_WATERDATA_URL_BYTE_LIMIT = 8000 - -# Conservative over-estimate of URL bytes used by everything *except* -# the filter value. Used only by the fast path in -# ``_effective_filter_budget`` to skip the probe when the encoded filter -# clearly already fits. -_NON_FILTER_URL_HEADROOM = 1000 - _NUM = r"-?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?" _IDENT = r"[A-Za-z_]\w*" @@ -120,69 +100,6 @@ def _split_top_level_or(expr: str) -> list[str]: return [p for p in parts if p] -def _chunk_cql_or(expr: str, max_len: int = _CQL_FILTER_CHUNK_LEN) -> list[str]: - """Split ``expr`` into OR-chunks each under ``max_len`` characters. - - Only top-level ``OR`` chains can be recombined losslessly as a disjunction - of independent sub-queries. Returns ``[expr]`` unchanged when the whole - expression already fits, when there is no top-level ``OR``, or when any - single clause exceeds ``max_len`` (sending it as-is and surfacing the - server's 414 is clearer than silently dropping data). - """ - if len(expr) <= max_len: - return [expr] - parts = _split_top_level_or(expr) - if len(parts) < 2 or any(len(p) > max_len for p in parts): - return [expr] - - chunks = [] - current: list[str] = [] - current_len = 0 - for part in parts: - join_cost = len(" OR ") if current else 0 - if current and current_len + join_cost + len(part) > max_len: - chunks.append(" OR ".join(current)) - current = [part] - current_len = len(part) - else: - current.append(part) - current_len += join_cost + len(part) - if current: - chunks.append(" OR ".join(current)) - return chunks - - -def _effective_filter_budget( - args: dict[str, Any], - filter_expr: str, - build_request: Callable[..., Any], -) -> int: - """Raw-CQL byte budget that, after URL-encoding, fits the URL byte limit. - - The server caps total URL length, not raw CQL length. We probe the - non-filter URL bytes by building the request with a 1-byte placeholder - filter, subtract from the URL limit to get the bytes available for the - encoded filter, then convert back to raw CQL bytes via the *maximum* - per-clause encoding ratio (a chunk could contain only the heavier-encoding - clauses, so budgeting by the average ratio could overflow). - """ - # Fast path: encoded filter clearly fits with room for any plausible - # non-filter URL. Skips the PreparedRequest build and splitter scan. - encoded_len = len(quote_plus(filter_expr)) - if encoded_len + _NON_FILTER_URL_HEADROOM <= _WATERDATA_URL_BYTE_LIMIT: - return len(filter_expr) + 1 - - probe = build_request(**{**args, "filter": "x"}) - available_url_bytes = _WATERDATA_URL_BYTE_LIMIT - (len(probe.url) - 1) - if available_url_bytes <= 0: - # Non-filter URL already over the limit. Pass through unchanged so - # the caller sees one 414 instead of N parallel sub-request failures. - return len(filter_expr) + 1 - parts = _split_top_level_or(filter_expr) or [filter_expr] - encoding_ratio = max(len(quote_plus(p)) / len(p) for p in parts) - return max(100, int(available_url_bytes / encoding_ratio)) - - def _check_numeric_filter_pitfall(filter_expr: str) -> None: """Raise if the filter pairs a field with an unquoted numeric literal. @@ -243,92 +160,3 @@ def _is_chunkable(filter_expr: Any, filter_lang: Any) -> bool: and bool(filter_expr) and filter_lang in {None, "cql-text"} ) - - -def _combine_chunk_frames(frames: list[pd.DataFrame]) -> pd.DataFrame: - """Concatenate per-chunk frames, dropping empties and deduping by ``id``. - - ``_get_resp_data`` returns a plain ``pd.DataFrame()`` on empty responses; - concat'ing it with real GeoDataFrames downgrades the result to plain - DataFrame and strips geometry/CRS, so empties are dropped first. Dedup - on the pre-rename feature ``id`` keeps overlapping user OR-clauses from - producing duplicate rows across chunks. - """ - non_empty = [f for f in frames if not f.empty] - if not non_empty: - return pd.DataFrame() - if len(non_empty) == 1: - return non_empty[0] - combined = pd.concat(non_empty, ignore_index=True) - if "id" in combined.columns: - combined = combined.drop_duplicates(subset="id", ignore_index=True) - return combined - - -def _combine_chunk_responses( - responses: list[requests.Response], -) -> requests.Response: - """Return one response: first chunk's URL/headers + summed ``elapsed``. - - Mutates the first response in place (only ``elapsed``); downstream only - reads ``elapsed`` (in ``BaseMetadata.query_time``), URL, and headers. - """ - head = responses[0] - if len(responses) > 1: - head.elapsed = sum((r.elapsed for r in responses[1:]), start=head.elapsed) - return head - - -_FetchOnce = TypeVar( - "_FetchOnce", - bound=Callable[[dict[str, Any]], tuple[pd.DataFrame, requests.Response]], -) - - -def chunked(*, build_request: Callable[..., Any]) -> Callable[[_FetchOnce], _FetchOnce]: - """Decorator that adds CQL-filter chunking to a single-request fetch. - - The wrapped function has signature ``(args: dict) -> (frame, response)`` - and represents one HTTP round-trip. The decorator inspects ``args``: - - - No chunkable filter: pass through unchanged. - - Chunkable cql-text filter: run the lexicographic-pitfall guard, split - into URL-length-safe sub-expressions, call the wrapped function once - per chunk, concatenate frames (drop empties, dedup by feature ``id``), - and return an aggregated response (first chunk's URL/headers, summed - ``elapsed``). - - Either way the return shape matches the undecorated function's, so the - caller wraps the response in ``BaseMetadata`` the same way in both paths. - - ``build_request`` is injected so the decorator can probe URL length - without importing any specific HTTP builder; it receives the same kwargs - the wrapped function's ``args`` would and returns a prepared-request-like - object with a ``.url`` attribute. - """ - - def decorator(fetch_once: _FetchOnce) -> _FetchOnce: - @functools.wraps(fetch_once) - def wrapper( - args: dict[str, Any], - ) -> tuple[pd.DataFrame, requests.Response]: - filter_expr = args.get("filter") - if not _is_chunkable(filter_expr, args.get("filter_lang")): - return fetch_once(args) - - _check_numeric_filter_pitfall(filter_expr) - budget = _effective_filter_budget(args, filter_expr, build_request) - chunks = _chunk_cql_or(filter_expr, max_len=budget) - - frames: list[pd.DataFrame] = [] - responses: list[requests.Response] = [] - for chunk in chunks: - frame, response = fetch_once({**args, "filter": chunk}) - frames.append(frame) - responses.append(response) - - return _combine_chunk_frames(frames), _combine_chunk_responses(responses) - - return wrapper # type: ignore[return-value] - - return decorator diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index 9245bb92..58d4673d 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -1,20 +1,29 @@ from __future__ import annotations +import copy import json import logging import os import re -from collections.abc import Iterable, Mapping -from datetime import datetime -from typing import Any, get_args +from collections.abc import Callable, Iterable, Iterator, Mapping +from contextlib import contextmanager +from datetime import datetime, timedelta +from typing import Any, TypeVar, get_args from zoneinfo import ZoneInfo import pandas as pd import requests +from requests.structures import CaseInsensitiveDict from dataretrieval import __version__ from dataretrieval.utils import BaseMetadata -from dataretrieval.waterdata import filters +from dataretrieval.waterdata import chunking +from dataretrieval.waterdata.chunking import ( + _QUOTA_HEADER, + RateLimited, + ServiceUnavailable, + get_active_session, +) from dataretrieval.waterdata.types import ( PROFILE_LOOKUP, PROFILES, @@ -410,33 +419,104 @@ def _error_body(resp: requests.Response): ) +def _parse_retry_after(value: str | None) -> float | None: + """ + Parse a USGS ``Retry-After`` header into seconds. + + Parameters + ---------- + value : str or None + The raw header value, or ``None`` if absent. + + Returns + ------- + float or None + Non-negative delta-seconds, clamped at zero. ``None`` when the + header is absent or unparseable; ``ChunkedCall`` treats + ``None`` as "fall back to my own retry policy". + + Notes + ----- + USGS sends ``Retry-After`` as integer delta-seconds (empirically + verified — e.g. ``Retry-After: 2619``). The HTTP spec also allows + HTTP-date form, but USGS doesn't use it, so this function doesn't + bother parsing it. + """ + if not value: + return None + try: + return max(0.0, float(value.strip())) + except ValueError: + return None + + def _raise_for_non_200(resp: requests.Response) -> None: - """Raise ``RuntimeError(_error_body(resp))`` if ``resp`` is not 200. + """ + Raise a typed exception for any non-200 response. + + Routes through :func:`_error_body` (USGS-API-aware: handles + 429/403 specially, extracts ``code``/``description`` from JSON + error bodies) rather than ``Response.raise_for_status``, which + raises ``HTTPError`` with a generic message. - Routes through ``_error_body`` (USGS-API-aware: handles 429/403 - specially, extracts ``code``/``description`` from JSON error bodies) - rather than ``Response.raise_for_status``, which raises - ``HTTPError`` with a generic message. + Parameters + ---------- + resp : requests.Response + The HTTP response to inspect. + + Raises + ------ + RateLimited + On HTTP 429 — typed so ``ChunkedCall`` can wrap as a resumable + :class:`~dataretrieval.waterdata.chunking.QuotaExhausted`. + ServiceUnavailable + On HTTP 5xx — typed so ``ChunkedCall`` can wrap as a resumable + :class:`~dataretrieval.waterdata.chunking.ServiceInterrupted`. + RuntimeError + On any other non-200 (4xx other than 429) — these are + programmer errors that retry won't fix. """ - if resp.status_code != 200: - raise RuntimeError(_error_body(resp)) + status = resp.status_code + if status == 200: + return + body = _error_body(resp) + retry_after = _parse_retry_after(resp.headers.get("Retry-After")) + if status == 429: + raise RateLimited(body, retry_after=retry_after) + if 500 <= status < 600: + raise ServiceUnavailable(body, retry_after=retry_after) + raise RuntimeError(body) def _paginated_failure_message(pages_collected: int, cause: BaseException) -> str: - """User-facing message for a mid-pagination failure. + """ + Build a user-facing message for a mid-pagination failure. The API exposes no resume cursor, so the caller's only recovery is to retry the whole call — the message lists the practical knobs, tailored to whether the failure was rate-limit (429) or something else. + + Parameters + ---------- + pages_collected : int + Number of pages successfully fetched before the failure. + cause : BaseException + The underlying exception that interrupted pagination. + + Returns + ------- + str + A message suitable for the ``RuntimeError`` that ``_walk_pages`` + and ``get_stats_data`` raise from the original exception. """ cause_str = str(cause).removesuffix(".") # Some ``requests`` exceptions (e.g. ``Timeout()`` with no args) - # stringify to empty; fall back to the class name so the wrapper - # message is always informative. + # stringify to empty; fall back to the class name so the + # returned message is always informative. if not cause_str.strip(): cause_str = type(cause).__name__ - if cause_str.startswith("429"): + if isinstance(cause, RateLimited): action = "wait for the rate-limit window to reset and retry" else: action = "retry the request (possibly after a short backoff)" @@ -554,7 +634,9 @@ def _construct_api_requests( return request.prepare() -def _next_req_url(resp: requests.Response) -> str | None: +def _next_req_url( + resp: requests.Response, *, body: dict[str, Any] | None = None +) -> str | None: """ Extracts the URL for the next page of results from an HTTP response from a water data endpoint. @@ -563,6 +645,10 @@ def _next_req_url(resp: requests.Response) -> str | None: ---------- resp : requests.Response The HTTP response object containing JSON data and headers. + body : dict, optional + Pre-parsed JSON body for ``resp``. When provided, skips the + ``resp.json()`` call — useful when the caller has already + decoded the body for its own use (avoids a second parse pass). Returns ------- @@ -578,14 +664,15 @@ def _next_req_url(resp: requests.Response) -> str | None: "rel" and "href" keys. - Checks for the "next" relation in the "links" to determine the next URL. """ - body = resp.json() + if body is None: + body = resp.json() if not body.get("numberReturned"): return None header_info = resp.headers if os.getenv("API_USGS_PAT", ""): logger.info( "Remaining requests this hour: %s", - header_info.get("x-ratelimit-remaining", ""), + header_info.get(_QUOTA_HEADER, ""), ) for link in body.get("links", []): if link.get("rel") == "next": @@ -595,7 +682,12 @@ def _next_req_url(resp: requests.Response) -> str | None: return None -def _get_resp_data(resp: requests.Response, geopd: bool) -> pd.DataFrame: +def _get_resp_data( + resp: requests.Response, + geopd: bool, + *, + body: dict[str, Any] | None = None, +) -> pd.DataFrame: """ Extracts and normalizes data from an HTTP response containing GeoJSON features. @@ -607,33 +699,66 @@ def _get_resp_data(resp: requests.Response, geopd: bool) -> pd.DataFrame: geopd : bool Indicates whether geopandas is installed and should be used to handle geometries. + body : dict, optional + Pre-parsed JSON body for ``resp``. When provided, skips the + ``resp.json()`` call — useful when the caller has already + decoded the body for its own use (avoids a second parse pass). Returns ------- gpd.GeoDataFrame or pd.DataFrame - A geopandas GeoDataFrame if geometry is included, or a pandas DataFrame - containing the feature properties and each row's service-specific id. - Returns an empty pandas DataFrame if no features are returned. + A ``GeoDataFrame`` when ``geopd`` is True; otherwise a plain + ``DataFrame`` carrying the feature properties plus an ``id`` + column and a ``geometry`` column (coordinates list) where the + response includes them. Returns an empty ``DataFrame`` when no + features are returned. + + Notes + ----- + The non-geopandas branch builds the frame directly from each + feature's ``properties`` dict, plus the top-level ``id`` and + ``geometry.coordinates`` columns — but adds the ``id`` and + ``geometry`` columns only when at least one feature actually + carries them. This skips the GeoJSON envelope entirely, so + newly-added Feature-level fields (e.g. ``geometry.type`` after + USGS migrated to full GeoJSON geometry objects) can't leak into + the result frame; no reactive drop-list needs maintenance every + time the upstream schema grows. """ - # Check if it's an empty response - body = resp.json() + if body is None: + body = resp.json() if not body.get("numberReturned"): - return pd.DataFrame() + # Preserve the GeoDataFrame type on empty short-circuit so a + # downstream ``pd.concat([empty_page, geo_page])`` doesn't + # downgrade the geopd-installed user's result to a plain + # DataFrame (stripping geometry/CRS). + return gpd.GeoDataFrame() if geopd else pd.DataFrame() + + # Defensive: a 200 with ``numberReturned > 0`` but missing + # ``features`` is a real schema-drift shape (mirrors the guard in + # ``_handle_stats_nesting``). Treat as empty rather than crash with + # ``KeyError`` — the wrapped failure would otherwise look like a + # transient transport error to ``_paginate``'s exception handler. + features = body.get("features") or [] + if not features: + return gpd.GeoDataFrame() if geopd else pd.DataFrame() - # If geopandas not installed, return a pandas dataframe if not geopd: - df = pd.json_normalize(body["features"], sep="_") - df = df.drop( - columns=["type", "geometry", "AsGeoJSON(geometry)"], errors="ignore" - ) - df.columns = [col.replace("properties_", "") for col in df.columns] - df.rename(columns={"geometry_coordinates": "geometry"}, inplace=True) - df = df.loc[:, ~df.columns.duplicated()] + df = pd.json_normalize([f.get("properties") or {} for f in features], sep="_") + # Always materialize the ``id`` column (may be all-None) so + # ``_arrange_cols``'s ``df.rename(columns={"id": output_id})`` + # produces the documented service-specific output_id column + # (daily_id, channel_measurements_id, …) even if the upstream + # response carried no feature-level id. + df["id"] = [f.get("id") for f in features] + geoms = [(f.get("geometry") or {}).get("coordinates") for f in features] + if any(g is not None for g in geoms): + df["geometry"] = geoms return df # Organize json into geodataframe and make sure id column comes along. - df = gpd.GeoDataFrame.from_features(body["features"]) - df["id"] = pd.json_normalize(body["features"])["id"].values + df = gpd.GeoDataFrame.from_features(features) + df["id"] = pd.json_normalize(features)["id"].values df = df[["id"] + [col for col in df.columns if col != "id"]] # If no geometry present, then return pandas dataframe. A geodataframe @@ -644,97 +769,269 @@ def _get_resp_data(resp: requests.Response, geopd: bool) -> pd.DataFrame: return df -def _walk_pages( +@contextmanager +def _session(client: requests.Session | None) -> Iterator[requests.Session]: + """ + Yield a usable session, picking the best available source. + + Resolution order: + + 1. ``client`` if the caller supplied one (borrowed; not closed + here — the caller owns its lifecycle). + 2. The chunker's shared session if we're inside a ``ChunkedCall`` + fan-out (per :func:`chunking.get_active_session`). Borrowed; + ``ChunkedCall.resume`` closes it on exit. + 3. A fresh short-lived ``requests.Session`` opened here and closed + on context exit. + + Parameters + ---------- + client : requests.Session or None + A caller-owned session to borrow, or ``None`` to defer to the + chunker's shared session or a temporary one. + + Yields + ------ + requests.Session + The chosen session. + """ + if client is not None: + yield client + return + shared = get_active_session() + if shared is not None: + yield shared + return + with requests.Session() as new: + yield new + + +def _aggregate_paginated_response( + initial: requests.Response, + last: requests.Response, + total_elapsed: timedelta, +) -> requests.Response: + """ + Build a single response covering a paginated call. + + Returns a shallow copy of ``initial`` with ``.headers`` set to the + LAST page's (so downstream sees current ``x-ratelimit-remaining``) + and ``.elapsed`` set to total wall-clock. The canonical + ``initial.url`` is preserved (it's the user's original query). + Both ``initial`` and ``last`` are left unmutated, mirroring the + convention of + :func:`dataretrieval.waterdata.chunking._combine_chunk_responses`. + + Parameters + ---------- + initial : requests.Response + First-page response (the canonical one for ``md.url``). + last : requests.Response + Last-page response — supplies the headers to copy over. + total_elapsed : datetime.timedelta + Cumulative wall-clock across every page, including ``initial``. + + Returns + ------- + requests.Response + A shallow copy of ``initial`` with ``.headers`` set to a fresh + ``CaseInsensitiveDict`` and ``.elapsed`` set to the + cumulative wall-clock. ``initial.headers`` / ``initial.elapsed`` + are never mutated, so callers holding a pre-pagination + reference still see the original first-page values. Other + ``Response`` fields (``_content``, ``raw``, ``cookies``, + ``request``) are still aliased to ``initial`` by the shallow + copy — callers that mutate those will affect ``initial``. + """ + final = copy.copy(initial) + final.headers = CaseInsensitiveDict(last.headers) + final.elapsed = total_elapsed + return final + + +_Cursor = TypeVar("_Cursor") + + +def _paginate( + initial_req: requests.PreparedRequest, + *, geopd: bool, - req: requests.PreparedRequest, + parse_response: Callable[[requests.Response], tuple[pd.DataFrame, _Cursor | None]], + follow_up: Callable[[_Cursor, requests.Session], requests.Response], client: requests.Session | None = None, ) -> tuple[pd.DataFrame, requests.Response]: """ - Iterates through paginated API responses and aggregates the results - into a single DataFrame. + Drive a paginated request to completion. + + Common shape behind :func:`_walk_pages` and :func:`get_stats_data`: + send the initial request, then loop calling ``follow_up`` until + ``parse_response`` reports a ``None`` cursor, accumulating frames + and elapsed time. Any mid-pagination failure raises + ``RuntimeError`` wrapping the cause — the API exposes no resume + cursor, so the caller's only recovery is to retry the whole call. Parameters ---------- + initial_req : requests.PreparedRequest + First-page request to send. geopd : bool - Indicates whether geopandas is installed and should be used for handling - geometries. - req : requests.PreparedRequest - The initial HTTP request to send. - client : Optional[requests.Session], default None - An optional HTTP client to use for requests. If not provided, a new - client is created. + Whether ``geopandas`` is available — logged once at WARNING + level when ``False`` (matches historical behavior of both + callers). + parse_response : callable + ``resp -> (df, next_cursor_or_None)``. Returns the page's + DataFrame and the cursor (URL, token, …) used to drive + ``follow_up`` for the next page; ``None`` terminates the loop. + follow_up : callable + ``(cursor, session) -> requests.Response``. Builds and sends + the next-page request. + client : requests.Session, optional + Caller-borrowed session. ``None`` (default) means use the + chunker's shared session (if inside a chunked call) or open + a temporary one. Returns ------- - pd.DataFrame - A DataFrame containing the aggregated results from all pages. - requests.Response - The initial response object containing metadata about the first request. + df : pandas.DataFrame + Concatenation of every page's parsed frame. + response : requests.Response + A shallow copy of the first-page response, with ``.headers`` + rebuilt as a fresh ``CaseInsensitiveDict`` reflecting the last + page and ``.elapsed`` set to cumulative wall-clock. The + canonical URL is preserved from the first page. The original + first-page response is not mutated. Raises ------ RuntimeError - On a non-200 initial response (bare message from ``_error_body``) - or any failure on a subsequent page (wrapped message built by - ``_paginated_failure_message`` with the original exception - chained as ``__cause__``). + On a non-200 initial response (typed + :class:`~dataretrieval.waterdata.chunking.RateLimited` / + :class:`~dataretrieval.waterdata.chunking.ServiceUnavailable` + for 429/5xx, otherwise plain ``RuntimeError`` from + :func:`_error_body`), on an initial-page parse failure + (wrapped via :func:`_paginated_failure_message` with the + original exception on ``__cause__``), or any failure on a + subsequent page (same wrapping). requests.exceptions.RequestException Network-level failures on the *initial* request (e.g. - ``ConnectionError``, ``Timeout``) propagate unmodified to - preserve their specific type for callers that branch on it. - Equivalent failures on *subsequent* pages are caught and - re-raised as ``RuntimeError`` per the rule above. + ``ConnectionError``, ``Timeout``) propagate unmodified so + callers can branch on the specific type; equivalent failures + on subsequent pages are wrapped per above. """ - logger.info("Requesting: %s", req.url) - + logger.info("Requesting: %s", initial_req.url) if not geopd: logger.warning( "Geopandas not installed. Geometries will be flattened " "into pandas DataFrames." ) - # Get first response from client - # using GET or POST call - close_client = client is None - client = client or requests.Session() - try: - resp = client.send(req) + with _session(client) as sess: + resp = sess.send(initial_req) _raise_for_non_200(resp) - - # Store the initial response for metadata + # Keep the original-request response as the "canonical" one for + # ``md.url`` reproducibility; ``.headers`` and ``.elapsed`` get + # overwritten with latest/cumulative values below. initial_response = resp + total_elapsed = resp.elapsed - # Grab some aspects of the original request: headers and the - # request type (GET or POST) - method = req.method.upper() - headers = dict(req.headers) - content = req.body if method == "POST" else None - - # List to collect dataframes from each page - dfs = [_get_resp_data(resp, geopd=geopd)] - curr_url = _next_req_url(resp) - while curr_url: + try: + df, cursor = parse_response(resp) + except Exception as e: # noqa: BLE001 + # Initial-page parse failures (malformed JSON, missing + # ``features``, schema drift) get the same wrapped-message + # treatment as follow-up failures so callers see a + # consistent diagnostic regardless of which page broke. + logger.warning("Initial response parse failed.") + raise RuntimeError(_paginated_failure_message(0, e)) from e + dfs = [df] + while cursor is not None: try: - resp = client.request( - method, - curr_url, - headers=headers, - data=content if method == "POST" else None, - ) + resp = follow_up(cursor, sess) _raise_for_non_200(resp) - dfs.append(_get_resp_data(resp, geopd=geopd)) - curr_url = _next_req_url(resp) + df, cursor = parse_response(resp) + dfs.append(df) + total_elapsed += resp.elapsed except Exception as e: # noqa: BLE001 logger.warning( - "Request failed for URL: %s. Data download interrupted.", curr_url + "Request failed at cursor %r. Data download interrupted.", + cursor, ) raise RuntimeError(_paginated_failure_message(len(dfs), e)) from e - # Concatenate all pages at once for efficiency - return pd.concat(dfs, ignore_index=True), initial_response - finally: - if close_client: - client.close() + # Aggregate headers / elapsed onto a COPY of the initial + # response so the user's caller never sees an in-place + # mutation of the response object they may have inspected + # mid-pagination via a hook or test fixture. + final_response = _aggregate_paginated_response( + initial_response, resp, total_elapsed + ) + return pd.concat(dfs, ignore_index=True), final_response + + +def _walk_pages( + geopd: bool, + req: requests.PreparedRequest, + client: requests.Session | None = None, +) -> tuple[pd.DataFrame, requests.Response]: + """ + Iterate through paginated OGC API responses and aggregate into one + DataFrame. + + Thin wrapper that hands off to :func:`_paginate` with OGC-specific + strategies: pages are parsed via :func:`_get_resp_data` and the + next-page cursor is the URL from the response's ``links`` array + (per :func:`_next_req_url`). + + Parameters + ---------- + geopd : bool + Whether geopandas is installed (drives geometry handling). + req : requests.PreparedRequest + The initial HTTP request to send. + client : requests.Session, optional + Caller-borrowed session; ``None`` defers session management to + :func:`_paginate`. + + Returns + ------- + pd.DataFrame + A DataFrame containing the aggregated results from all pages. + requests.Response + Aggregated response — initial-request URL (for query identity), + final page's headers (so downstream sees current rate-limit + state), and cumulative ``elapsed`` summed across pages. + + Raises + ------ + RuntimeError + See :func:`_paginate`. + requests.exceptions.RequestException + See :func:`_paginate`. + """ + method = req.method # ``PreparedRequest.method`` is already upper-cased. + headers = dict(req.headers) + content = req.body if method == "POST" else None + + def parse_response(resp: requests.Response) -> tuple[pd.DataFrame, str | None]: + body = resp.json() + # Coerce falsy cursors (empty href, etc.) to None so + # _paginate's `while cursor is not None` terminates instead of + # spinning on a meaningless value. + return ( + _get_resp_data(resp, geopd=geopd, body=body), + _next_req_url(resp, body=body) or None, + ) + + def follow_up(cursor: str, sess: requests.Session) -> requests.Response: + return sess.request(method, cursor, headers=headers, data=content) + + return _paginate( + req, + geopd=geopd, + parse_response=parse_response, + follow_up=follow_up, + client=client, + ) def _deal_with_empty( @@ -864,7 +1161,6 @@ def _type_cols(df: pd.DataFrame) -> pd.DataFrame: "construction_date", "end", "end_utc", - "datetime", # unused "last_modified", "time", ] @@ -920,7 +1216,8 @@ def get_ogc_data( output_id : str The name of the output identifier to use in the request. service : str - The OGC service type (e.g., "wfs", "wms"). + The OGC API collection name (e.g., ``"daily"``, + ``"monitoring-locations"``, ``"continuous"``). Returns ------- @@ -957,17 +1254,18 @@ def get_ogc_data( return return_list, BaseMetadata(response) -@filters.chunked(build_request=_construct_api_requests) +@chunking.multi_value_chunked(build_request=_construct_api_requests) def _fetch_once( args: dict[str, Any], ) -> tuple[pd.DataFrame, requests.Response]: """Send one prepared-args OGC request; return the frame + response. - Filter chunking is added orthogonally by the ``@filters.chunked`` - decorator: with no filter (or an un-chunkable one) the decorator - passes ``args`` through to this body; with a chunkable filter it - fans out and calls this body once per sub-filter, then combines. - Either way the return shape is ``(frame, response)``. + ``@chunking.multi_value_chunked`` models every multi-value list + parameter and the cql-text filter as a chunkable axis, greedy-halves + the biggest chunk across all axes until each sub-request URL fits, + and iterates the cartesian product. With no chunkable inputs the + decorator passes args through unchanged. Either way the return + shape is ``(frame, response)``. """ req = _construct_api_requests(**args) return _walk_pages(geopd=GEOPANDAS, req=req) @@ -985,30 +1283,59 @@ def _handle_stats_nesting( ---------- body : Dict[str, Any] The JSON response body from the statistics service containing nested data. + geopd : bool, optional + Whether ``geopandas`` is available — when ``True`` the returned + frame is a ``GeoDataFrame``; when ``False`` (default) a plain + ``pd.DataFrame`` is returned with geometry flattened. Returns ------- pd.DataFrame A DataFrame containing the flattened statistical data. + + Notes + ----- + The non-geopandas branch uses the same schema-aware extraction as + :func:`_get_resp_data`: it builds the per-feature outer frame + directly from each feature's ``properties`` (minus the nested + ``data`` field, which is unrolled separately below via the + ``record_path`` json_normalize), then adds ``id`` and ``geometry`` + only when present. Skipping the GeoJSON envelope keeps newly-added + fields like ``geometry.type`` from leaking into the result. """ if body is None: - return pd.DataFrame() - + return gpd.GeoDataFrame() if geopd else pd.DataFrame() + + # An empty (or missing) features list — a real mid-pagination + # shape — would otherwise crash the downstream merge with + # ``KeyError: 'monitoring_location_id'`` because neither df nor + # dat would carry the merge key. Bail out with an empty frame — + # ``GeoDataFrame`` when geopd is available so the eventual + # ``pd.concat`` with non-empty geo pages doesn't downgrade to a + # plain DataFrame and strip geometry/CRS. + features = body.get("features") or [] + if not features: + return gpd.GeoDataFrame() if geopd else pd.DataFrame() + + # The geopd-missing warning is emitted once at the top of + # ``get_stats_data`` (parallel to ``_walk_pages``); doing it here + # would log per page. if not geopd: - logger.info( - "Geopandas not installed. Geometries will be flattened " - "into pandas DataFrames." - ) - - # If geopandas not installed, return a pandas dataframe - # otherwise return a geodataframe - if not geopd: - df = pd.json_normalize(body["features"]).drop( - columns=["type", "properties.data"], errors="ignore" - ) + outer_props = [ + {k: v for k, v in (f.get("properties") or {}).items() if k != "data"} + for f in features + ] + df = pd.json_normalize(outer_props, sep=".") df.columns = df.columns.str.split(".").str[-1] + # Stats features don't carry a top-level ``id`` field — the + # geopandas branch (``GeoDataFrame.from_features``) doesn't + # surface one either, so the non-geopd branch stays + # consistent by NOT adding an id column. + geoms = [(f.get("geometry") or {}).get("coordinates") for f in features] + if any(g is not None for g in geoms): + df["geometry"] = geoms else: - df = gpd.GeoDataFrame.from_features(body["features"]).drop( + df = gpd.GeoDataFrame.from_features(features).drop( columns=["data"], errors="ignore" ) @@ -1022,7 +1349,6 @@ def _handle_stats_nesting( ["features", "properties", "data", "parameter_code"], ["features", "properties", "data", "unit_of_measure"], ["features", "properties", "data", "parent_time_series_id"], - # ["features", "geometry", "coordinates"], ], meta_prefix="", errors="ignore", @@ -1138,75 +1464,41 @@ def get_stats_data( """ url = f"{STATISTICS_API_URL}/{service}" - - headers = _default_headers() - request = requests.Request( method="GET", url=url, - headers=headers, + headers=_default_headers(), params=args, ) req = request.prepare() - logger.info("Request: %s", req.url) - - # create temp client if not provided - # and close it after the request is done - close_client = client is None - client = client or requests.Session() - - try: - resp = client.send(req) - _raise_for_non_200(resp) - - # Store the initial response for metadata - initial_response = resp - - # Grab some aspects of the original request: headers and the - # request type (GET or POST) - method = req.method.upper() - headers = dict(req.headers) + method = req.method # ``PreparedRequest.method`` is already upper-cased. + headers = dict(req.headers) + def parse_response(resp: requests.Response) -> tuple[pd.DataFrame, str | None]: body = resp.json() - all_dfs = [_handle_stats_nesting(body, geopd=GEOPANDAS)] - - # Look for a next code in the response body - next_token = body["next"] - - while next_token: - args["next_token"] = next_token - - try: - resp = client.request( - method, - url=url, - params=args, - headers=headers, - ) - _raise_for_non_200(resp) - body = resp.json() - all_dfs.append(_handle_stats_nesting(body, geopd=GEOPANDAS)) - next_token = body["next"] - except Exception as e: # noqa: BLE001 - logger.warning( - "Request failed for URL: %s (next_token=%s). " - "Data download interrupted.", - url, - next_token, - ) - raise RuntimeError(_paginated_failure_message(len(all_dfs), e)) from e - - dfs = pd.concat(all_dfs, ignore_index=True) if len(all_dfs) > 1 else all_dfs[0] + # Coerce falsy cursors ("", 0) to None so _paginate terminates. + # USGS uses "next": null at end-of-stream, but defensive coerce + # protects against any "" sentinel a future schema might use. + return _handle_stats_nesting(body, geopd=GEOPANDAS), body.get("next") or None + + def follow_up(cursor: str, sess: requests.Session) -> requests.Response: + # Build a fresh params dict per page so the caller's ``args`` is + # never mutated. + return sess.request( + method, url=url, params={**args, "next_token": cursor}, headers=headers + ) - # . If expand percentiles is True, make each percentile - # its own row in the returned dataset. - if expand_percentiles: - dfs = _expand_percentiles(dfs) + df, response = _paginate( + req, + geopd=GEOPANDAS, + parse_response=parse_response, + follow_up=follow_up, + client=client, + ) - return dfs, BaseMetadata(initial_response) - finally: - if close_client: - client.close() + if expand_percentiles: + df = _expand_percentiles(df) + return df, BaseMetadata(response) def _check_profiles( @@ -1365,23 +1657,36 @@ def _get_args( local_vars: dict[str, Any], exclude: set[str] | None = None ) -> dict[str, Any]: """ - Standardize parameter filtering for WaterData API functions. - - Filters out internal function arguments ('service', 'output_id') - and None values from the provided local variables dictionary. - Additional variables can be excluded via the 'exclude' parameter. + Build the API-request kwargs dict from a getter's ``locals()``. + + Drops bookkeeping keys (``service``, ``output_id``, anything in + ``exclude``) and ``None``-valued kwargs, then normalizes the + remaining values: + + - ``monitoring_location_id`` is validated against the AGENCY-ID + format (per :func:`_check_monitoring_location_id`). + - ``properties`` is materialized to ``list[str]`` (a bare string + gets wrapped in a single-element list so downstream + ``",".join(properties)`` doesn't iterate per character). + - Any other ``Iterable[str]`` that isn't in ``_NO_NORMALIZE_PARAMS`` + is materialized to ``list[str]`` via + :func:`_normalize_str_iterable` so downstream code that branches + on ``isinstance(v, (list, tuple))`` works for ``pandas.Series``, + ``numpy.ndarray``, generators, etc. + - Scalars, strings, and ``_NO_NORMALIZE_PARAMS`` values pass through + unchanged. Parameters ---------- local_vars : dict[str, Any] - Dictionary of local variables, typically from `locals()`. + Dictionary of local variables, typically from ``locals()``. exclude : set[str], optional Additional keys to exclude from the resulting dictionary. Returns ------- dict[str, Any] - Filtered dictionary of arguments for API requests. + Filtered and normalized arguments for API requests. """ to_exclude = {"service", "output_id"} if exclude: diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py new file mode 100644 index 00000000..d9a54a7d --- /dev/null +++ b/tests/waterdata_chunking_test.py @@ -0,0 +1,1271 @@ +"""Tests for ``dataretrieval.waterdata.chunking``. + +These tests exercise the joint planner with a fake ``build_request`` +whose URL byte length is a deterministic function of its inputs: + +- non-chunkable args contribute ``base_bytes``, +- every multi-value list contributes ``len(",".join(map(str, v)))``, +- the ``filter`` kwarg contributes ``len(filter)``. + +That isolates planner behaviour from the real HTTP request builder. +The one exception is +``test_joint_planner_url_construction_long_filter_and_long_sites``, +which uses the real ``_construct_api_requests`` so URL-encoding +surprises (``%``, ``+``, ``/``, ``&``, …) can't pass against a fake +and then fail in production. +""" + +import datetime +import sys +from unittest import mock +from urllib.parse import quote_plus + +import pandas as pd +import pytest + +if sys.version_info < (3, 10): + pytest.skip("Skip entire module on Python < 3.10", allow_module_level=True) + +from dataretrieval.waterdata import chunking as _chunking +from dataretrieval.waterdata.chunking import ( + _LIST_SEP, + _OR_SEP, + _QUOTA_HEADER, + ChunkInterrupted, + ChunkPlan, + QuotaExhausted, + RateLimited, + RequestExceedsQuota, + RequestTooLarge, + ServiceInterrupted, + ServiceUnavailable, + _chunked_session, + _extract_axes, + _read_remaining, + multi_value_chunked, +) +from dataretrieval.waterdata.utils import _construct_api_requests + + +class _FakeReq: + __slots__ = ("url", "body") + + def __init__(self, url, body=None): + self.url = url + self.body = body + + +def _fake_build(*, base=200, **kwargs): + """Fake build_request: URL length deterministic in its inputs. + + Mirrors the GET-routed shape: payload goes in the URL, body is None. + List/string values are URL-encoded via ``quote_plus`` so the fake's + byte count matches what the real ``_construct_api_requests`` would + produce; otherwise an alphanumeric test could pass against the fake + but fail in production once values containing ``%``, ``+``, ``/``, + ``&`` etc. (which expand under encoding) reach the same code path. + """ + bytes_ = base + for v in kwargs.values(): + if isinstance(v, (list, tuple)): + bytes_ += len(quote_plus(",".join(map(str, v)))) + elif isinstance(v, str): + bytes_ += len(quote_plus(v)) + return _FakeReq("x" * bytes_) + + +def test_never_chunk_covers_all_date_range_params(): + """``_NEVER_CHUNK`` and ``_DATE_RANGE_PARAMS`` are maintained in + separate modules (chunker vs request builder) for layering reasons, + but every date-range param MUST be excluded from chunking — a + range value isn't an enumerable set to split. Guard against drift: + adding a new param to ``_DATE_RANGE_PARAMS`` without also adding + it to ``_NEVER_CHUNK`` would silently let the chunker try to + comma-join an interval string.""" + from dataretrieval.waterdata.chunking import _NEVER_CHUNK + from dataretrieval.waterdata.utils import _DATE_RANGE_PARAMS + + missing = _DATE_RANGE_PARAMS - _NEVER_CHUNK + assert not missing, ( + f"_DATE_RANGE_PARAMS contains entries not in _NEVER_CHUNK: " + f"{sorted(missing)}. Add them to chunking._NEVER_CHUNK." + ) + + +def test_extract_axes_picks_up_list_dims_and_filter(): + """Every multi-value list parameter becomes one axis with ``","`` + joiner; the cql-text filter becomes one axis with ``" OR "`` joiner + and its atoms are the top-level OR-clauses.""" + args = { + "monitoring_location_id": ["USGS-A", "USGS-B"], + "parameter_code": ["00060", "00065"], + "filter": "a='1' OR b='2' OR c='3'", + } + axes = _extract_axes(args) + by_key = {ax.arg_key: ax for ax in axes} + assert set(by_key) == {"monitoring_location_id", "parameter_code", "filter"} + assert by_key["monitoring_location_id"].joiner == _LIST_SEP + assert by_key["monitoring_location_id"].atoms == ("USGS-A", "USGS-B") + assert by_key["parameter_code"].joiner == _LIST_SEP + assert by_key["filter"].joiner == _OR_SEP + assert by_key["filter"].atoms == ("a='1'", "b='2'", "c='3'") + + +def test_extract_axes_skips_singletons_and_never_chunk_params(): + """Length-1 lists and ``_NEVER_CHUNK`` params (``bbox``, ``limit``, + date intervals, ...) produce no axes — there's nothing to split.""" + args = { + "monitoring_location_id": ["USGS-A"], # length 1 + "bbox": [-95, 40, -90, 45], + "limit": 100, + "filter": "a='1'", # one clause, no OR to split + } + assert _extract_axes(args) == [] + + +def test_chunk_plan_returns_passthrough_when_no_chunkable_axes(): + """Scalar args with nothing to chunk → passthrough, even at a + URL limit the request technically exceeds (the server may 414, + but ``ChunkPlan`` has nothing to split).""" + args = {"monitoring_location_id": "scalar-only"} + plan = ChunkPlan(args, _fake_build, url_limit=10) + assert plan.axes == [] + assert plan.total == 1 + + +def test_chunk_plan_greedy_halving_targets_largest_axis_chunk(): + """The biggest chunk across all axes halves first — when one list + axis dominates URL bytes, only it gets split until it stops being + the largest.""" + args = { + "monitoring_location_id": ["X" * 30, "Y" * 30, "Z" * 30, "W" * 30], + "parameter_code": ["00060", "00065"], + } + # full URL ≈ 200 + 123 + 12 = 335; force splitting the heavy axis only. + plan = ChunkPlan(args, _fake_build, url_limit=310) + assert len(plan.chunks["monitoring_location_id"]) > 1 + assert len(plan.chunks["parameter_code"]) == 1 + + +def test_chunk_plan_raises_request_too_large_at_singleton_floor(): + """Limit below the singleton-per-axis floor → ``RequestTooLarge``; + there's nothing left to shrink.""" + args = {"monitoring_location_id": ["A", "B"]} + # base=200 alone exceeds limit=100; chunking can't help. + with pytest.raises(RequestTooLarge, match="smallest reducible"): + ChunkPlan(args, _fake_build, url_limit=100) + + +def test_chunk_plan_fans_out_filter_when_list_alone_cannot_fit(): + """When the request can only fit by chunking BOTH a list axis AND + the filter axis, the plan ends up with chunk counts >1 on at + least one of the two axis kinds.""" + clauses = [f"f='{i}'" for i in range(10)] + args = { + "monitoring_location_id": ["A" * 10, "B" * 10, "C" * 10, "D" * 10], + "filter": " OR ".join(clauses), + } + plan = ChunkPlan(args, _fake_build, url_limit=240) + # At least one axis must end up split. + assert any(len(plan.chunks[ax.arg_key]) > 1 for ax in plan.axes) + + +def test_chunk_plan_minimizes_total_sub_requests(): + """When both axes need shrinking, picking smaller filter chunks + frees URL budget for larger list chunks, and vice versa. The + planner should pick the allocation with the *fewest* total + sub-requests, not just the first allocation that fits.""" + # 16 short clauses (no inflation under URL encoding so the math is + # tractable). Each clause = 5 bytes (e.g. "f='0'"); full filter ≈ + # 16*5 + 15*4 = 140 bytes raw. + clauses = [f"f='{i}'" for i in range(16)] + args = { + "sites": ["S" * 30 for _ in range(8)], # 8 sites @ 30 chars + "filter": " OR ".join(clauses), + } + # Tight limit forces both axes to participate. + plan = ChunkPlan(args, _fake_build, url_limit=380) + # Plan must beat the bail-floor-style worst case (8 singletons × 16 + # filter chunks = 128 sub-requests) by a healthy margin. + assert plan.total < 128 + + +def test_chunk_plan_raises_when_smallest_plan_doesnt_fit(): + """If even the most aggressive joint plan (singleton lists + + singleton filter clauses) still exceeds the limit, surface + RequestTooLarge — there's nothing left to shrink.""" + args = { + "monitoring_location_id": ["A" * 10, "B" * 10], + "filter": "x='12345' OR x='67890'", # min clause is 9 chars + } + # Base 200 + singleton site (10) + singleton clause (9) = 219; limit + # below 219 → no joint plan can fit. + with pytest.raises(RequestTooLarge): + ChunkPlan(args, _fake_build, url_limit=210) + + +def test_chunk_plan_passthrough_when_request_fits(): + """URL under limit → trivial passthrough plan (no axes, total=1), + and ``iter_sub_args`` yields exactly one sub-args dict equal to + the original args.""" + args = {"monitoring_location_id": ["A", "B", "C"], "limit": 100} + plan = ChunkPlan(args, _fake_build, url_limit=8000) + assert plan.axes == [] + assert plan.total == 1 + subs = list(plan.iter_sub_args()) + assert len(subs) == 1 + assert subs[0] == args + + +def test_multi_value_chunked_passes_through_when_url_fits(): + """No planning needed → decorator calls underlying function exactly once + with the original args.""" + calls = [] + + @multi_value_chunked(build_request=_fake_build, url_limit=8000) + def fetch(args): + calls.append(args) + return pd.DataFrame(), mock.Mock( + elapsed=datetime.timedelta(seconds=0.1), headers={} + ) + + fetch({"monitoring_location_id": ["A", "B"]}) + assert len(calls) == 1 + assert calls[0]["monitoring_location_id"] == ["A", "B"] + + +def test_multi_value_chunked_emits_cartesian_product(): + """Two chunkable axes, each split into 2 chunks → exactly 4 sub-requests, + each pairing one chunk from each axis.""" + calls = [] + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + def fetch(args): + calls.append({k: v for k, v in args.items() if k in ("sites", "pcodes")}) + return pd.DataFrame(), mock.Mock( + elapsed=datetime.timedelta(seconds=0.1), headers={} + ) + + fetch( + { + "sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10], + "pcodes": ["P1" * 10, "P2" * 10, "P3" * 10, "P4" * 10], + } + ) + # Both heavy → planner should split both axes. Confirm a cartesian shape: + # every unique site-chunk pairs with every unique pcode-chunk. + sites_seen = {tuple(c["sites"]) for c in calls} + pcodes_seen = {tuple(c["pcodes"]) for c in calls} + assert len(calls) == len(sites_seen) * len(pcodes_seen) + assert len(sites_seen) > 1 + assert len(pcodes_seen) > 1 + + +def test_multi_value_chunked_emits_3d_cartesian_product(): + """Three chunkable axes, each forced to split → exhaustive cartesian + product across all three. Verifies the halving loop in + ``ChunkPlan._plan`` handles N>2 axes uniformly and the ``ChunkedCall`` + ``itertools.product`` enumerates every combination exactly once.""" + calls = [] + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + def fetch(args): + calls.append(tuple(tuple(args[k]) for k in ("sites", "pcodes", "stats"))) + return pd.DataFrame(), mock.Mock( + elapsed=datetime.timedelta(seconds=0.1), headers={} + ) + + fetch( + { + "sites": ["S" * 12 + str(i) for i in range(4)], + "pcodes": ["P" * 12 + str(i) for i in range(4)], + "stats": ["T" * 12 + str(i) for i in range(4)], + } + ) + + # Three independent axes — every (site_chunk, pcode_chunk, stat_chunk) + # triple must appear exactly once. Confirm: + sites_seen = {c[0] for c in calls} + pcodes_seen = {c[1] for c in calls} + stats_seen = {c[2] for c in calls} + + assert len(sites_seen) > 1, "sites axis was not split" + assert len(pcodes_seen) > 1, "pcodes axis was not split" + assert len(stats_seen) > 1, "stats axis was not split" + + # Cartesian shape: # sub-requests == product of unique chunks across axes + expected = len(sites_seen) * len(pcodes_seen) * len(stats_seen) + assert len(calls) == expected, ( + f"expected {expected} cartesian-product sub-requests, got {len(calls)}" + ) + # And no triple repeats (exhaustive enumeration, no duplicates). + assert len(set(calls)) == len(calls) + # The chunked values, when unioned across calls, recover the original list. + assert {x for tup in sites_seen for x in tup} == { + "S" * 12 + str(i) for i in range(4) + } + assert {x for tup in pcodes_seen for x in tup} == { + "P" * 12 + str(i) for i in range(4) + } + assert {x for tup in stats_seen for x in tup} == { + "T" * 12 + str(i) for i in range(4) + } + + +def test_multi_value_chunked_lazy_url_limit(monkeypatch): + """``url_limit=None`` → resolve chunking._WATERDATA_URL_BYTE_LIMIT at call + time, so tests that patch the constant affect this decorator too.""" + calls = [] + + @multi_value_chunked(build_request=_fake_build) # url_limit defaults to None + def fetch(args): + calls.append(args) + return pd.DataFrame(), mock.Mock( + elapsed=datetime.timedelta(seconds=0.1), headers={} + ) + + monkeypatch.setattr(_chunking, "_WATERDATA_URL_BYTE_LIMIT", 240) + # 4 sites of 10 chars → exceeds 240 → planner splits. + fetch({"sites": ["S" * 10 + str(i) for i in range(4)]}) + assert len(calls) > 1, "patched constant should drive chunking" + + +def test_chunked_session_shared_across_sub_requests(): + """Every sub-request of one chunked call sees the same + ``requests.Session`` on the ``_chunked_session`` ContextVar, so + downstream paginated helpers (``_walk_pages``) can reuse the + connection pool instead of handshaking fresh on each sub-request.""" + sessions_seen = [] + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + def fetch(args): + sessions_seen.append(_chunked_session.get()) + return pd.DataFrame(), mock.Mock( + elapsed=datetime.timedelta(seconds=0.1), headers={} + ) + + # Outside a chunked call: no session published. + assert _chunked_session.get() is None + + fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + # Plan must actually fan out — otherwise the test isn't exercising + # the shared-session path. + assert len(sessions_seen) > 1 + # Every sub-request saw a Session, not None. + assert all(s is not None for s in sessions_seen) + # And it was the same object every time. + assert len({id(s) for s in sessions_seen}) == 1 + # On exit the ContextVar is reset to its default. + assert _chunked_session.get() is None + + +def test_chunked_session_isolated_per_resume(): + """A follow-up ``resume`` after an interruption opens a fresh + session — the previous one was closed when its ``resume`` returned. + The ContextVar is reset between calls so leakage can't carry + a closed session into the retry.""" + state = {"i": 0, "blow_up": True} + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + def fetch(args): + i = state["i"] + state["i"] += 1 + if i == 1 and state["blow_up"]: + raise RateLimited("429: Too many requests.") + return ( + pd.DataFrame({"sites": list(args["sites"])}), + mock.Mock( + elapsed=datetime.timedelta(seconds=0.1), + headers={_QUOTA_HEADER: "500"}, + ), + ) + + with pytest.raises(QuotaExhausted) as excinfo: + fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + # First resume's session is closed; ContextVar is reset. + assert _chunked_session.get() is None + + state["blow_up"] = False + excinfo.value.call.resume() + # Second resume's session is also cleaned up. + assert _chunked_session.get() is None + + +def _quota_response(remaining: int | str | None) -> mock.Mock: + """A mock requests.Response-like object whose ``x-ratelimit-remaining`` + header reflects the given value (None → header absent).""" + resp = mock.Mock(elapsed=datetime.timedelta(seconds=0.1)) + resp.headers = {} if remaining is None else {_QUOTA_HEADER: str(remaining)} + return resp + + +def test_read_remaining_parses_header(): + assert _read_remaining(_quota_response(42)) == 42 + + +def test_read_remaining_returns_none_when_header_missing(): + """No rate-limit header → ``None`` so ``ChunkedCall`` can branch + on ``is None`` instead of comparing against a magic sentinel.""" + assert _read_remaining(_quota_response(None)) is None + + +def test_read_remaining_returns_none_on_malformed_header(): + """Non-integer header value → ``None`` so a parse failure doesn't + trip the quota check.""" + assert _read_remaining(_quota_response("not-a-number")) is None + + +def test_request_exceeds_quota_after_first_chunk(): + """Plan totals 4 sub-requests. The first response reports + ``x-ratelimit-remaining=1`` — only 2 sub-requests fit total + (the one just issued + 1 more). The wrapper must raise + ``RequestExceedsQuota`` *before* issuing chunk 2, and the + exception must carry a ``.call`` handle so the first chunk's + already-fetched data is recoverable.""" + calls: list[dict] = [] + + def fetch(args): + calls.append(args) + return pd.DataFrame({"sites": list(args["sites"])}), _quota_response(1) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + + with pytest.raises(RequestExceedsQuota) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + err = excinfo.value + assert err.planned_chunks == 4 + assert err.available == 2 # remaining=1 + the chunk we just spent + assert err.deficit == 2 + assert len(calls) == 1, "only the first chunk should have been issued" + # The originating ChunkedCall is exposed on .call so the first + # chunk's already-fetched data is recoverable. + assert err.call is not None + assert err.call.completed_chunks == 1 + assert not err.call.partial_frame.empty + + +def test_request_exceeds_quota_message_reports_deficit(): + """The error must surface planned / available / deficit so callers + know precisely how far over budget the call is.""" + e = RequestExceedsQuota(planned_chunks=10, available=4, deficit=6) + msg = str(e) + assert "10" in msg + assert "4" in msg + assert "6" in msg + + +def test_request_exceeds_quota_not_raised_when_plan_fits(): + """If ``x-ratelimit-remaining`` is large enough to cover the rest + of the plan, ``ChunkedCall`` proceeds normally.""" + remaining_seq = iter([100, 99, 98, 97]) + + def fetch(args): + return ( + pd.DataFrame({"sites": list(args["sites"])}), + _quota_response(next(remaining_seq)), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + df, _ = decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert len(df) == 4 + + +def test_no_quota_check_when_header_absent(): + """Without an ``x-ratelimit-remaining`` header ``ChunkedCall`` + has no quota signal and must NOT synthesize a + ``RequestExceedsQuota``; every planned sub-request runs.""" + + def fetch(args): + return pd.DataFrame({"sites": list(args["sites"])}), _quota_response(None) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + df, _ = decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert len(df) == 4 + + +def test_quota_exhausted_on_mid_call_429(): + """Mid-call 429 (a concurrent caller drained the window) surfaces + as ``QuotaExhausted`` carrying the partial frame plus the chunk + offset so callers can resume after the window resets.""" + state = {"i": 0} + + def fetch(args): + i = state["i"] + state["i"] += 1 + if i == 2: + # Match _walk_pages's wrapping: a generic mid-pagination + # RuntimeError with the typed RateLimited as __cause__. + try: + raise RateLimited("429: Too many requests made.") + except RateLimited as cause: + raise RuntimeError( + "Paginated request failed after collecting 0 page(s): " + "429: Too many requests made." + ) from cause + return ( + pd.DataFrame({"i": [i], "sites": list(args["sites"])}), + _quota_response(500), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(QuotaExhausted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) + + err = excinfo.value + assert err.completed_chunks == 2 # chunks 0 and 1 completed; 429 hit on i=2 + assert err.total_chunks == 5 + assert err.partial_frame is not None + assert set(err.partial_frame["i"]) == {0, 1} + + +def test_quota_exhausted_on_first_chunk_429_has_no_partial_response(): + """A 429 on the very first sub-request means no responses have + completed; ``partial_response`` is ``None`` (and ``partial_frame`` + is empty) so callers can branch on that to distinguish "abort + before any data arrived" from "abort after partial collection".""" + + def fetch(args): + raise RateLimited("429: Too many requests made.") + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(QuotaExhausted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10]}) + err = excinfo.value + assert err.completed_chunks == 0 + assert err.partial_response is None + assert err.partial_frame.empty + + +def test_quota_exhausted_resume_picks_up_where_429_stopped(): + """After a mid-call 429 ``ChunkedCall`` raises ``QuotaExhausted``; + once the window resets, ``e.call.resume()`` re-issues only the + sub-requests that hadn't completed and returns the full combined + result. Chunks completed before the 429 are not re-fetched.""" + # The fake fetch 429s on the third call, then succeeds on every + # subsequent call. We track which sub-args have been issued so we + # can assert chunks 0/1 aren't re-fetched on resume. + fetched_sites: list[tuple[str, ...]] = [] + rate_limited_once = {"fired": False} + + def fetch(args): + if len(fetched_sites) == 2 and not rate_limited_once["fired"]: + rate_limited_once["fired"] = True + raise RateLimited("429: Too many requests made.") + site_tuple = tuple(args["sites"]) + fetched_sites.append(site_tuple) + return ( + pd.DataFrame({"sites": list(site_tuple)}), + _quota_response(500), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10] + + # First attempt: 429 on the third sub-request. + with pytest.raises(QuotaExhausted) as excinfo: + decorated({"sites": sites}) + err = excinfo.value + assert err.completed_chunks == 2 + pre_resume_count = len(fetched_sites) + assert pre_resume_count == 2 # chunks 0 and 1 completed + + # Resume: re-issues only the still-pending sub-requests. + df, _ = err.call.resume() + + # Three more fetches happened on resume (chunks 2, 3, 4); chunks 0 + # and 1 were not re-fetched. + assert len(fetched_sites) - pre_resume_count == 3, ( + f"expected 3 new fetches on resume (chunks 2, 3, 4); got " + f"{len(fetched_sites) - pre_resume_count}" + ) + # Every original site appears in the combined frame exactly once. + assert sorted(df["sites"].tolist()) == sorted(sites) + + +def test_quota_exhausted_resume_can_reraise_on_persistent_429(): + """If the window is still empty when the caller resumes, + ``call.resume()`` raises ``QuotaExhausted`` again — the + ``ChunkedCall``'s in-flight state carries forward, so a + subsequent resume after a longer wait still picks up cleanly.""" + state = {"attempts": 0} + + def fetch(args): + i = state["attempts"] + state["attempts"] += 1 + # First attempt 429s on chunk 2. Resume attempt 429s on what + # would be chunk 2 again (still the first un-completed + # sub-request). + if i == 2 or i == 3: + raise RateLimited("429: Too many requests made.") + return ( + pd.DataFrame({"i": [i], "sites": list(args["sites"])}), + _quota_response(500), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10] + + with pytest.raises(QuotaExhausted) as first: + decorated({"sites": sites}) + with pytest.raises(QuotaExhausted) as second: + first.value.call.resume() + + # Both exceptions report the same completed_chunks count — the + # second resume didn't make progress (it 429'd on the same chunk). + assert first.value.completed_chunks == 2 + assert second.value.completed_chunks == 2 + + +def test_resume_produces_dataset_identical_to_uninterrupted_run(): + """End-to-end resume equivalence: the same chunked query run twice + — once straight through, once with a mid-stream 429 + + ``call.resume()`` — must yield byte-identical combined frames. + Guards against off-by-one errors in the resume cursor (re-fetching + the chunk that 429'd, or skipping past it) and any ordering drift + ``_combine_chunk_frames`` might introduce when its input list is + built incrementally.""" + + def make_fetch(rate_limit_at_call: int | None): + """Build a fresh fetch_once whose Nth call raises ``RateLimited`` + (once) and whose every other call returns a deterministic frame + keyed by the sub-args's sites.""" + state = {"calls": 0, "tripped": False} + + def fetch(args): + state["calls"] += 1 + if state["calls"] == rate_limit_at_call and not state["tripped"]: + state["tripped"] = True + raise RateLimited("429: Too many requests made.") + sites = list(args["sites"]) + return ( + pd.DataFrame( + { + "id": sites, + "first_site": [sites[0]] * len(sites), + "chunk_size": [len(sites)] * len(sites), + } + ), + _quota_response(500), + ) + + return fetch + + # 16 sites at url_limit=240 forces several chunks; the chunking + # plan is deterministic, so both runs traverse the same sub-args + # sequence. + sites = ["S" * 10 + str(i) for i in range(16)] + + # Run A: uninterrupted. + fetch_a = make_fetch(rate_limit_at_call=None) + decorated_a = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch_a) + df_a, _ = decorated_a({"sites": sites}) + + # Run B: trigger 429 on the third sub-request, then resume. + fetch_b = make_fetch(rate_limit_at_call=3) + decorated_b = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch_b) + with pytest.raises(QuotaExhausted) as excinfo: + decorated_b({"sites": sites}) + # The 429 must hit mid-stream — otherwise the test isn't exercising + # what we think it is. + assert 0 < excinfo.value.completed_chunks < excinfo.value.total_chunks + df_b, _ = excinfo.value.call.resume() + + # Sanity: both runs must have actually chunked (otherwise the + # 429-mid-stream branch wasn't exercised). + assert excinfo.value.total_chunks > 1 + + # The combined DataFrames must be byte-identical: same rows in the + # same order, same dtypes. ``check_like=False`` keeps row-order + # comparison strict so a permutation introduced by the resume path + # would still fail. + pd.testing.assert_frame_equal(df_a, df_b) + + # And every original site must be present exactly once. + assert sorted(df_a["id"].tolist()) == sorted(sites) + + +def test_chunker_passes_through_non_429_runtime_error(): + """A non-429 ``RuntimeError`` (e.g. a 500) is not a quota signal; + it must propagate unchanged so callers see the real cause.""" + state = {"i": 0} + + def fetch(args): + i = state["i"] + state["i"] += 1 + if i == 2: + raise RuntimeError("500: Internal server error.") + return ( + pd.DataFrame({"sites": list(args["sites"])}), + _quota_response(500), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(RuntimeError, match=r"^500:"): + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) + + +def test_chunker_wraps_service_unavailable_as_resumable(): + """A typed ``ServiceUnavailable`` (HTTP 5xx) is a transient + transport failure: ``ChunkedCall`` must wrap it as + ``ServiceInterrupted`` carrying the partial state, parallel to how + a 429 becomes ``QuotaExhausted``. Once the upstream recovers, + ``.call.resume()`` resumes only the still-pending sub-requests.""" + state = {"i": 0, "blow_up": True} + + def fetch(args): + i = state["i"] + state["i"] += 1 + if i == 2 and state["blow_up"]: + try: + raise ServiceUnavailable("503: Service unavailable.") + except ServiceUnavailable as cause: + raise RuntimeError(str(cause)) from cause + return ( + pd.DataFrame({"sites": list(args["sites"])}), + _quota_response(500), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(ServiceInterrupted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) + err = excinfo.value + # Resumable: handle on .call with already-completed work preserved. + assert err.call is not None + assert err.completed_chunks == 2 + assert err.total_chunks == 5 + assert not err.call.partial_frame.empty + # Upstream recovers; resuming completes the call. + state["blow_up"] = False + df, _ = err.call.resume() + assert set(df["sites"]) == {"S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10} + + +def test_chunk_interrupted_base_class_catches_both(): + """``ChunkInterrupted`` is the common base for 429/5xx + interruptions, so callers who want one retry policy across all + transient failures can catch the base class. ``QuotaExhausted`` + and ``ServiceInterrupted`` must both subclass it.""" + assert issubclass(QuotaExhausted, ChunkInterrupted) + assert issubclass(ServiceInterrupted, ChunkInterrupted) + # Sanity: ``ChunkInterrupted`` is itself a ``RuntimeError`` so + # bare ``except RuntimeError`` callers don't suddenly miss the + # wrapped failures after this refactor. + assert issubclass(ChunkInterrupted, RuntimeError) + + +def test_connection_error_wrapped_as_service_interrupted(): + """A bare ``requests.exceptions.ConnectionError`` (or any other + transport-level RequestException) doesn't inherit from + ``RuntimeError``; without the widened catch in ``_issue`` it + would escape uncaught and the user would lose the resumable + handle to ``.call.resume()``. Verify ``ChunkedCall`` wraps it as + ``ServiceInterrupted`` so partial progress is preserved.""" + import requests as _requests + + state = {"i": 0, "blow_up": True} + + def fetch(args): + i = state["i"] + state["i"] += 1 + if i == 2 and state["blow_up"]: + raise _requests.exceptions.ConnectionError("connection reset") + return ( + pd.DataFrame({"sites": list(args["sites"])}), + _quota_response(500), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(ServiceInterrupted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) + + err = excinfo.value + assert err.completed_chunks == 2 + assert err.call is not None + # The transport exception is on __cause__ so callers can drill in if needed. + assert isinstance(err.__cause__, _requests.exceptions.ConnectionError) + # Resume after the upstream recovers. + state["blow_up"] = False + df, _ = err.call.resume() + assert set(df["sites"]) == {"S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10} + + +def test_service_interrupted_exposes_partial_frame_and_response(): + """Both ``QuotaExhausted`` AND ``ServiceInterrupted`` carry + ``partial_frame`` / ``partial_response`` directly on the + exception. Previously only ``QuotaExhausted`` had them, so a + generic ``except ChunkInterrupted as exc: log(exc.partial_frame)`` + crashed with AttributeError on 5xx.""" + state = {"i": 0} + + def fetch(args): + i = state["i"] + state["i"] += 1 + if i == 2: + try: + raise ServiceUnavailable("503: Service unavailable.") + except ServiceUnavailable as cause: + raise RuntimeError(str(cause)) from cause + return ( + pd.DataFrame({"sites": list(args["sites"])}), + _quota_response(500), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(ServiceInterrupted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) + err = excinfo.value + # Direct attribute access works for both subclasses now. + assert hasattr(err, "partial_frame") + assert hasattr(err, "partial_response") + assert not err.partial_frame.empty + assert err.partial_response is not None + + +def test_partial_frame_snapshot_stable_across_resume(): + """``exc.partial_frame`` / ``exc.partial_response`` snapshot the + state at raise time. Calling ``exc.call.resume()`` advances the + underlying ``ChunkedCall`` but must NOT mutate the snapshot on + the exception — otherwise a diagnostic that reads + ``exc.partial_frame`` after a resume sees post-resume state under + a name that promises pre-resume state.""" + state = {"i": 0, "blow_up": True} + + def fetch(args): + i = state["i"] + state["i"] += 1 + if i == 2 and state["blow_up"]: + raise RateLimited("429: Too many requests.") + return ( + pd.DataFrame({"sites": list(args["sites"])}), + _quota_response(500), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(QuotaExhausted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) + err = excinfo.value + snapshot_rows = len(err.partial_frame) + assert snapshot_rows > 0 # two chunks worth of data captured + + # Resume; the live view on .call grows. + state["blow_up"] = False + err.call.resume() + assert len(err.call.partial_frame) > snapshot_rows + + # The exception's snapshot must NOT advance. + assert len(err.partial_frame) == snapshot_rows + + +def test_partial_frame_snapshot_is_a_copy_when_single_chunk(): + """``_combine_chunk_frames`` returns ``non_empty[0]`` verbatim on + its single-frame fast path. ``ChunkInterrupted.__init__`` must + therefore defensively ``.copy()`` so an in-place mutation of the + underlying chunk frame (e.g. user diagnostic code adding a + column on the live view) doesn't leak through the snapshot. + Companion to ``test_partial_frame_snapshot_stable_across_resume``, + which uses ≥2 completed chunks and so goes through + ``pd.concat`` (which already produces a fresh frame).""" + state = {"i": 0, "blow_up": True} + + def fetch(args): + i = state["i"] + state["i"] += 1 + if i == 1 and state["blow_up"]: + raise RateLimited("429: Too many requests.") + return ( + pd.DataFrame({"sites": list(args["sites"])}), + _quota_response(500), + ) + + # 4 sites at url_limit=240 → 2 sub-requests. The 429 fires on the + # SECOND sub-request, so the exception captures exactly ONE + # completed chunk — the path where _combine_chunk_frames aliases. + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(QuotaExhausted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + err = excinfo.value + assert err.completed_chunks == 1 + + snapshot_cols = list(err.partial_frame.columns) + # Mutate the underlying chunk in place — the snapshot must NOT + # reflect the mutation. + err.call._chunks[0][0]["extra"] = 0 + assert list(err.partial_frame.columns) == snapshot_cols + assert "extra" not in err.partial_frame.columns + + +def test_combine_chunk_responses_returns_independent_headers(): + """The aggregated response's ``.headers`` must be a fresh + ``CaseInsensitiveDict`` — mutations by downstream callers + (logging hooks, metadata extensions) must not back-propagate into + the underlying chunk response's headers, which still live on + ``ChunkedCall._chunks``.""" + from dataretrieval.waterdata.chunking import _combine_chunk_responses + + r0 = mock.Mock( + elapsed=datetime.timedelta(seconds=0.1), headers={"X-Foo": "0"}, url="u0" + ) + r1 = mock.Mock( + elapsed=datetime.timedelta(seconds=0.2), headers={"X-Foo": "1"}, url="u1" + ) + head = _combine_chunk_responses([r0, r1], canonical_url=None) + + # Aggregate carries the last chunk's headers... + assert head.headers["X-Foo"] == "1" + # ...but mutating the aggregate must not back-propagate. + head.headers["X-Trace-Id"] = "abc" + assert "X-Trace-Id" not in r1.headers + assert "X-Trace-Id" not in r0.headers + + +def test_paginate_terminates_on_empty_string_cursor(): + """``_paginate``'s loop predicate is ``while cursor is not None``. + Parse-response wrappers in ``_walk_pages`` / ``get_stats_data`` + coerce falsy non-None values to None so an empty-string next- + cursor (a real-but-unusual end-of-stream sentinel some pagination + APIs use) doesn't trap us in an infinite ``follow_up('')`` loop.""" + import datetime as _dt + from unittest import mock as _mock + + import requests as _requests + + from dataretrieval.waterdata import utils as _utils + + # Synthesize an OGC response with numberReturned > 0 and a "next" + # link whose href is an empty string — simulating a server-side + # sentinel that ``_next_req_url`` reads as ``""``. + body_with_empty_next = { + "numberReturned": 1, + "features": [{"id": "1", "properties": {"val": "a"}}], + "links": [{"rel": "next", "href": ""}], + } + resp = _mock.MagicMock(spec=_requests.Response) + resp.status_code = 200 + resp.url = "https://example.com/items?limit=1" + resp.elapsed = _dt.timedelta(seconds=0.1) + resp.headers = {} + resp.json.return_value = body_with_empty_next + + client = _mock.MagicMock(spec=_requests.Session) + client.send.return_value = resp + + req = _mock.MagicMock(spec=_requests.PreparedRequest) + req.method = "GET" + req.headers = {} + req.body = None + req.url = "https://example.com/items?limit=1" + + df, final = _utils._walk_pages(geopd=False, req=req, client=client) + + # Single send + zero follow-ups: the loop terminated on the empty cursor. + assert client.send.called + assert not client.request.called + assert len(df) == 1 + + +def test_combine_chunk_frames_does_not_collapse_none_ids(): + """``drop_duplicates(subset='id')`` treats NaN==NaN as duplicate, + so a blanket dedup would collapse every id-less row into one — + silent data loss. The function must dedupe only the id-bearing + rows and preserve id-less rows verbatim.""" + import numpy as np + + from dataretrieval.waterdata.chunking import _combine_chunk_frames + + # Frame A has real ids; frame B has feature-IDs of None for two + # different rows that must both survive. + df_a = pd.DataFrame({"id": ["x", "y"], "val": [1, 2]}) + df_b = pd.DataFrame({"id": [np.nan, np.nan], "val": [3, 4]}) + combined = _combine_chunk_frames([df_a, df_b]) + + # 4 rows preserved: 2 id-bearing + 2 id-less (NaN rows NOT merged). + assert len(combined) == 4 + assert sorted(combined["val"].tolist()) == [1, 2, 3, 4] + + +def test_combine_chunk_frames_still_dedupes_overlapping_ids(): + """The original dedup contract — overlapping OR-clause partitions + that produce duplicate-id rows across chunks must still collapse + to one row — has to keep working when ids ARE present.""" + from dataretrieval.waterdata.chunking import _combine_chunk_frames + + df_a = pd.DataFrame({"id": ["x", "y"], "val": [1, 2]}) + df_b = pd.DataFrame({"id": ["y", "z"], "val": [2, 3]}) + combined = _combine_chunk_frames([df_a, df_b]) + assert sorted(combined["id"].tolist()) == ["x", "y", "z"] + + +def test_retry_after_surfaces_on_quota_exhausted(): + """If the 429 response includes a ``Retry-After`` header, that + delay must travel from the typed transport exception + (``RateLimited.retry_after``) onto ``QuotaExhausted`` so callers + can honor the server's hint instead of guessing a wait.""" + state = {"i": 0} + + def fetch(args): + state["i"] += 1 + if state["i"] >= 3: + try: + raise RateLimited("429: Too many requests.", retry_after=42.0) + except RateLimited as cause: + raise RuntimeError(str(cause)) from cause + return ( + pd.DataFrame({"sites": list(args["sites"])}), + _quota_response(500), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(QuotaExhausted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) + assert excinfo.value.retry_after == 42.0 + + +def test_quota_exhausted_message_points_at_resume(): + """The error message must surface the chunk offset and the resume + affordance — ``partial_frame`` is a footgun without it.""" + e = QuotaExhausted( + completed_chunks=7, + total_chunks=20, + ) + msg = str(e) + assert "7/20" in msg + assert "429" in msg + assert ".call.resume()" in msg + + +def test_request_bytes_rejects_non_sizable_body(): + """``_request_bytes`` requires a deterministic byte count up front; + silently treating an unknown body as zero would under-chunk and let + the request blow past the server's POST-body limit. Generators, + iterables, and file-like objects must surface as ``TypeError``.""" + from dataretrieval.waterdata.chunking import _request_bytes + + class _FakeReqWithGenBody: + url = "https://example.com/foo" + body = (b"x" for _ in range(3)) + + with pytest.raises(TypeError, match="cannot size a request body"): + _request_bytes(_FakeReqWithGenBody()) + + +def test_request_bytes_handles_supported_body_types(): + """Sanity-check the supported body types: None (GET), bytes (raw + POST), str (JSON-as-string POST).""" + from dataretrieval.waterdata.chunking import _request_bytes + + class _Req: + def __init__(self, url, body): + self.url = url + self.body = body + + assert _request_bytes(_Req("ab", None)) == 2 + assert _request_bytes(_Req("ab", b"cd")) == 4 + assert _request_bytes(_Req("ab", "cd")) == 4 + assert _request_bytes(_Req("ab", bytearray(b"cd"))) == 4 + + +def test_multi_value_chunked_restores_canonical_url(): + """When chunking fans out, the aggregated response's ``.url`` must + reflect the *user's original* query (rebuilt from the unchunked + args), not the first chunk's URL. Callers logging ``md.url`` for + reproducibility need the full query.""" + sites = ["S" * 10 + str(i) for i in range(4)] + sub_urls: list[str] = [] + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + def fetch(args): + # Each sub-response carries the chunked sub_args's URL, so + # without canonical restoration the first chunk's URL would + # leak through to md.url. + sub_url = _fake_build(**args).url + sub_urls.append(sub_url) + resp = mock.Mock(elapsed=datetime.timedelta(seconds=0.1)) + resp.headers = {} + resp.url = sub_url + return pd.DataFrame(), resp + + _df, md = fetch({"sites": sites}) + + assert len(sub_urls) > 1, "test setup error: chunker didn't fan out" + # md.url must equal the URL the unchunked query would have produced. + assert md.url == _fake_build(sites=sites).url + # And differ from every sub-request's URL (each carries a smaller list). + assert all(md.url != u for u in sub_urls) + # The canonical URL is strictly bigger byte-wise than any sub-request. + assert all(len(md.url) > len(u) for u in sub_urls) + + +def test_extract_axes_skips_filter_passed_as_list(): + """Defensive guard: ``filter`` is documented as a string. If a caller + mistakenly passes it as a list, ``_extract_axes`` must NOT create a + comma-joined list axis for it — comma-joining CQL clauses inside + the URL would produce a malformed filter expression. The filter + axis is built only via top-level-OR splitting of the string form.""" + args = { + "monitoring_location_id": ["USGS-A", "USGS-B"], + "filter": ["a='1'", "a='2'"], # malformed input + "filter_lang": ["cql-text", "cql-json"], # ditto + } + keys = {ax.arg_key for ax in _extract_axes(args)} + assert keys == {"monitoring_location_id"} + + +def test_extract_axes_skips_scalar_contract_params(): + """``limit`` and ``skip_geometry`` are scalars by contract + (``int | None`` and ``bool | None`` respectively). If a caller smuggles + a list through type erasure (e.g. ``limit=["100","200"]`` after a + bad cast), ``_extract_axes`` must NOT treat it as a multi-value + axis. Chunking ``limit`` would silently fan into separate + paginated queries with different per-request caps; chunking + ``skip_geometry`` would emit sub-requests with conflicting + geometry-output settings.""" + args = { + "monitoring_location_id": ["USGS-A", "USGS-B"], + "limit": ["100", "200"], + "skip_geometry": ["true", "false"], + } + keys = {ax.arg_key for ax in _extract_axes(args)} + assert keys == {"monitoring_location_id"} + + +def test_joint_planner_url_construction_long_filter_and_long_sites(): + """Realistic stress: 20 datetime OR-clauses combined with 100 USGS + site IDs. Every sub-request URL built from the plan must fit the + 8000-byte limit, the joint planner must beat the naive "filter at + bail-floor, chunk lists" approach, and the partitioned filters + must union to the user's original filter expression. + + Uses the real ``_construct_api_requests`` builder so the test + catches URL-encoding surprises that a fake builder would miss. + """ + # Realistic AGENCY-ID site format: USGS-{8 digits}. 500 sites is + # enough to force the URL well past the 8000-byte server limit + # without any filter contribution. + sites = [f"USGS-{i:08d}" for i in range(500)] + # 20 datetime equality clauses; each ~30 bytes raw, more after URL + # encoding (the apostrophes and `:` characters expand). + clauses = [ + f"time='2024-{m:02d}-{d:02d}T00:00:00Z'" + for m in range(1, 6) + for d in (1, 8, 15, 22) + ] + assert len(clauses) == 20 + filter_expr = " OR ".join(clauses) + + args = { + "service": "daily", + "monitoring_location_id": sites, + "filter": filter_expr, + } + url_limit = 8000 + + plan = ChunkPlan(args, _construct_api_requests, url_limit) + assert plan.total > 1, "expected non-trivial plan for over-limit request" + + # Walk every sub-request the plan would issue and assert URL fits. + over_limit = [] + for sub_args in plan.iter_sub_args(): + req = _construct_api_requests(**sub_args) + url_len = len(req.url) + (len(req.body) if req.body else 0) + if url_len > url_limit: + over_limit.append((url_len, sub_args)) + assert not over_limit, ( + f"{len(over_limit)} sub-request(s) exceeded the URL limit; " + f"first: {over_limit[0]}" + ) + + # Each axis's chunks must union back to its original atoms exactly + # once — no clause or site dropped, no duplicates introduced. + for axis in plan.axes: + seen = [a for chunk in plan.chunks[axis.arg_key] for a in chunk] + assert sorted(seen) == sorted(axis.atoms), ( + f"axis {axis.arg_key} partition lost or duplicated atoms" + ) + + # Plan must beat the bail-floor-style worst case (singleton sites + # × all filter clauses singleton = 500 * 20 = 10,000) — uniform + # greedy halving of these inputs cuts that by at least 20×. + assert plan.total < 500, ( + f"joint plan emitted {plan.total} sub-requests (expected <500)" + ) + + +def test_combine_chunk_frames_all_empty_preserves_geo_type(): + """Regression: when every chunk returns an empty frame, + ``_combine_chunk_frames`` must not downgrade an empty + ``GeoDataFrame`` to a plain ``DataFrame``. The whole reason the + function drops empties before concat is to prevent that downgrade + — the all-empty short-circuit was independently dropping it.""" + pytest.importorskip("geopandas") + import geopandas as gpd + + from dataretrieval.waterdata.chunking import _combine_chunk_frames + + empty_gdfs = [gpd.GeoDataFrame() for _ in range(3)] + combined = _combine_chunk_frames(empty_gdfs) + assert isinstance(combined, gpd.GeoDataFrame), ( + f"all-empty combine returned {type(combined).__name__}; expected GeoDataFrame" + ) + + +def test_combine_chunk_frames_single_frame_is_safe_to_mutate(): + """Regression: the single-completed-chunk fast path returned the + underlying chunk frame verbatim, so a caller mutating + ``call.partial_frame`` (documented as a live view) would mutate + ``_chunks[0][0]`` in place. The fast path now returns a copy.""" + from dataretrieval.waterdata.chunking import _combine_chunk_frames + + chunk = pd.DataFrame({"id": ["A", "B"], "value": [1, 2]}) + returned = _combine_chunk_frames([chunk]) + returned["new_col"] = "x" + assert "new_col" not in chunk.columns + + +def test_iter_sub_args_passthrough_yields_a_copy(): + """Regression: the no-axes passthrough yielded ``self.args`` + directly while the chunked branch did ``dict(self.args)``. A + ``fetch_once`` that mutated the dict it received would silently + corrupt ``ChunkPlan.args``. The passthrough now copies too.""" + args = {"monitoring_location_id": ["USGS-A"], "limit": 100} + plan = ChunkPlan(args, _fake_build, url_limit=8000) + sub = next(plan.iter_sub_args()) + sub["monitoring_location_id"] = "mutated" + sub["new_key"] = "leaked" + assert plan.args["monitoring_location_id"] == ["USGS-A"] + assert "new_key" not in plan.args + + +def test_quota_check_fires_after_every_chunk_not_just_first(): + """Regression: ``_check_quota_after_first`` was gated on + ``len(_chunks) == 1`` so it only fired after chunk 0; a concurrent + caller draining the window mid-call (or a partially-rolled-over + quota on resume) went undetected. The check now fires after every + non-final chunk.""" + # 4-chunk plan. Chunks 0 and 1 report plenty of remaining quota; + # chunk 2's response reports remaining=0 with one chunk still + # pending. The check must fire after chunk 2, NOT silently let + # chunk 3 hit a mid-stream 429. + responses = iter([500, 500, 0]) + calls: list[dict] = [] + + def fetch(args): + calls.append(args) + return pd.DataFrame({"sites": list(args["sites"])}), _quota_response( + next(responses) + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(RequestExceedsQuota) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + err = excinfo.value + assert err.planned_chunks == 4 + # 3 completed + 0 remaining = 3 available; 1 pending; deficit 1. + assert err.available == 3 + assert err.deficit == 1 + assert len(calls) == 3, "only chunks 0-2 should have been issued" + # .call carries the in-flight call so the user can recover. + assert err.call is not None + assert err.call.completed_chunks == 3 diff --git a/tests/waterdata_filters_test.py b/tests/waterdata_filters_test.py index 545f7039..9d9d183e 100644 --- a/tests/waterdata_filters_test.py +++ b/tests/waterdata_filters_test.py @@ -7,11 +7,7 @@ import pytest from dataretrieval.waterdata.filters import ( - _CQL_FILTER_CHUNK_LEN, - _WATERDATA_URL_BYTE_LIMIT, _check_numeric_filter_pitfall, - _chunk_cql_or, - _effective_filter_budget, _split_top_level_or, ) from dataretrieval.waterdata.utils import _construct_api_requests @@ -35,11 +31,6 @@ def _fake_response(url="https://example.test", elapsed_ms=1): ) -def _build_request(**kwargs): - """Wrapper that matches the ``build_request`` callable shape.""" - return _construct_api_requests(**kwargs) - - def test_construct_filter_passthrough(): """`filter` is forwarded verbatim as a query parameter.""" expr = ( @@ -113,35 +104,6 @@ def test_split_top_level_or_single_clause(): ] -def test_chunk_cql_or_short_passthrough(): - expr = "time >= '2023-01-01T00:00:00Z'" - assert _chunk_cql_or(expr, max_len=1000) == [expr] - - -def test_chunk_cql_or_splits_into_multiple(): - clause = "(time >= '2023-01-01T00:00:00Z' AND time <= '2023-01-01T00:30:00Z')" - expr = " OR ".join([clause] * 200) - chunks = _chunk_cql_or(expr, max_len=1000) - # each chunk must be under the budget - assert all(len(c) <= 1000 for c in chunks) - # rejoined chunks must cover every clause - rejoined_clauses = sum(len(c.split(" OR ")) for c in chunks) - assert rejoined_clauses == 200 - # and must be a valid OR chain (each chunk is itself a top-level OR of clauses) - assert len(chunks) > 1 - - -def test_chunk_cql_or_unsplittable_returns_input(): - big = "value > 0 AND " + ("A " * 4000) - assert _chunk_cql_or(big, max_len=1000) == [big] - - -def test_chunk_cql_or_single_clause_over_budget_returns_input(): - huge_clause = "(value > " + "9" * 6000 + ")" - expr = f"{huge_clause} OR (value > 0)" - assert _chunk_cql_or(expr, max_len=1000) == [expr] - - @pytest.mark.parametrize( "service", [ @@ -167,41 +129,47 @@ def test_construct_filter_on_all_ogc_services(service): assert qs["filter-lang"] == ["cql-text"] -def test_long_filter_fans_out_into_multiple_requests(): - """An oversized top-level OR filter triggers multiple HTTP requests - whose results are concatenated.""" - from dataretrieval.waterdata import get_continuous - +def _filter_chunking_clauses(n: int = 300) -> str: + """Stock long filter used by the end-to-end fan-out tests below.""" clause = ( "(time >= '2023-01-{day:02d}T00:00:00Z' " "AND time <= '2023-01-{day:02d}T00:30:00Z')" ) - expr = " OR ".join(clause.format(day=(i % 28) + 1) for i in range(300)) - assert len(expr) > _CQL_FILTER_CHUNK_LEN + return " OR ".join(clause.format(day=(i % 28) + 1) for i in range(n)) - sent_filters = [] - def fake_construct_api_requests(**kwargs): - sent_filters.append(kwargs.get("filter")) - return _fake_prepared_request() +def _filter_size_aware_build(**kwargs): + """Fake ``_construct_api_requests`` whose returned URL length scales + with the request's ``filter`` value, so the joint planner naturally + triggers chunking on long filters.""" + return _fake_prepared_request( + url=f"https://example.test/?filter={kwargs.get('filter', '')}", + ) - def fake_walk_pages(*_args, **_kwargs): + +def test_long_filter_fans_out_into_multiple_requests(): + """An oversized top-level OR filter triggers multiple HTTP + sub-requests via the joint planner; every original clause is + preserved across sub-requests; results concatenate to one row per + sub-request given the one-row-per-chunk mock.""" + from dataretrieval.waterdata import get_continuous + + expr = _filter_chunking_clauses() + sent_filters: list[str] = [] + + def fake_walk_pages(*, geopd, req): idx = len(sent_filters) - frame = pd.DataFrame({"id": [f"chunk-{idx}"], "value": [idx]}) - return frame, _fake_response() + sent_filters.append(_query_params(req).get("filter", [None])[0]) + return pd.DataFrame({"id": [f"chunk-{idx}"], "value": [idx]}), _fake_response() with ( mock.patch( "dataretrieval.waterdata.utils._construct_api_requests", - side_effect=fake_construct_api_requests, + side_effect=_filter_size_aware_build, ), mock.patch( "dataretrieval.waterdata.utils._walk_pages", side_effect=fake_walk_pages ), - mock.patch( - "dataretrieval.waterdata.filters._effective_filter_budget", - return_value=_CQL_FILTER_CHUNK_LEN, - ), ): df, _ = get_continuous( monitoring_location_id="USGS-07374525", @@ -210,51 +178,38 @@ def fake_walk_pages(*_args, **_kwargs): filter_lang="cql-text", ) - # Mocking _effective_filter_budget bypasses the URL-length probe, so - # sent_filters contains only real chunk requests. Assert invariants: - # chunking happened, every original clause is preserved exactly once - # in order, each chunk stays under the budget, and the mock's - # one-row-per-chunk responses concatenate to a row per chunk. expected_parts = _split_top_level_or(expr) assert len(sent_filters) > 1 - rejoined_parts = [] + rejoined_parts: list[str] = [] for chunk in sent_filters: rejoined_parts.extend(_split_top_level_or(chunk)) assert rejoined_parts == expected_parts assert len(df) == len(sent_filters) - assert all(len(chunk) <= _CQL_FILTER_CHUNK_LEN for chunk in sent_filters) def test_long_filter_deduplicates_cross_chunk_overlap(): - """Features returned by multiple chunks (same feature `id`) are - deduplicated in the concatenated result.""" + """Features returned by multiple sub-requests with the same ``id`` + are deduplicated in the concatenated result.""" from dataretrieval.waterdata import get_continuous - clause = ( - "(time >= '2023-01-{day:02d}T00:00:00Z' " - "AND time <= '2023-01-{day:02d}T00:30:00Z')" - ) - expr = " OR ".join(clause.format(day=(i % 28) + 1) for i in range(300)) - + expr = _filter_chunking_clauses() call_count = {"n": 0} def fake_walk_pages(*_args, **_kwargs): call_count["n"] += 1 - frame = pd.DataFrame({"id": ["shared-feature"], "value": [1]}) - return frame, _fake_response() + return ( + pd.DataFrame({"id": ["shared-feature"], "value": [1]}), + _fake_response(), + ) with ( mock.patch( "dataretrieval.waterdata.utils._construct_api_requests", - return_value=_fake_prepared_request(), + side_effect=_filter_size_aware_build, ), mock.patch( "dataretrieval.waterdata.utils._walk_pages", side_effect=fake_walk_pages ), - mock.patch( - "dataretrieval.waterdata.filters._effective_filter_budget", - return_value=_CQL_FILTER_CHUNK_LEN, - ), ): df, _ = get_continuous( monitoring_location_id="USGS-07374525", @@ -263,56 +218,46 @@ def fake_walk_pages(*_args, **_kwargs): filter_lang="cql-text", ) - # Chunking must have happened (otherwise dedup wouldn't be exercised). - assert call_count["n"] > 1 - # Even though each chunk returned a feature, dedup by id collapses them. - assert len(df) == 1 + assert call_count["n"] > 1 # chunking must have happened + assert len(df) == 1 # dedup by ``id`` collapses the duplicates def test_empty_chunks_do_not_downgrade_geodataframe(): - """A mix of empty and non-empty chunk responses must not downgrade a - GeoDataFrame-typed result to a plain DataFrame. ``_get_resp_data`` - returns ``pd.DataFrame()`` on empty responses, which would otherwise - strip geometry/CRS from the concatenated output.""" + """A mix of empty and non-empty sub-request responses must not + downgrade a GeoDataFrame-typed result to a plain DataFrame. + ``_get_resp_data`` returns ``pd.DataFrame()`` on empty responses, + which would otherwise strip geometry/CRS from the concatenated + output.""" pytest.importorskip("geopandas") import geopandas as gpd from shapely.geometry import Point from dataretrieval.waterdata import get_continuous - clause = ( - "(time >= '2023-01-{day:02d}T00:00:00Z' " - "AND time <= '2023-01-{day:02d}T00:30:00Z')" - ) - expr = " OR ".join(clause.format(day=(i % 28) + 1) for i in range(300)) - + expr = _filter_chunking_clauses() call_count = {"n": 0} def fake_walk_pages(*_args, **_kwargs): call_count["n"] += 1 - # Chunk 2 returns empty; chunks 1 and 3 return GeoDataFrames. if call_count["n"] == 2: - frame = pd.DataFrame() - else: - frame = gpd.GeoDataFrame( + return pd.DataFrame(), _fake_response() + return ( + gpd.GeoDataFrame( {"id": [f"feat-{call_count['n']}"], "value": [call_count["n"]]}, geometry=[Point(call_count["n"], call_count["n"])], crs="EPSG:4326", - ) - return frame, _fake_response() + ), + _fake_response(), + ) with ( mock.patch( "dataretrieval.waterdata.utils._construct_api_requests", - return_value=_fake_prepared_request(), + side_effect=_filter_size_aware_build, ), mock.patch( "dataretrieval.waterdata.utils._walk_pages", side_effect=fake_walk_pages ), - mock.patch( - "dataretrieval.waterdata.filters._effective_filter_budget", - return_value=_CQL_FILTER_CHUNK_LEN, - ), ): df, _ = get_continuous( monitoring_location_id="USGS-07374525", @@ -321,119 +266,11 @@ def fake_walk_pages(*_args, **_kwargs): filter_lang="cql-text", ) - # The empty chunk must not have stripped the GeoDataFrame type. assert isinstance(df, gpd.GeoDataFrame) assert "geometry" in df.columns assert df.crs is not None -def test_effective_filter_budget_respects_url_limit(): - """The computed budget, once encoded, fits within the URL byte limit - alongside the other query params.""" - from urllib.parse import quote_plus - - filter_expr = "(time >= '2023-01-15T00:00:00Z' AND time <= '2023-01-15T00:30:00Z')" - args = { - "service": "continuous", - "monitoring_location_id": "USGS-02238500", - "parameter_code": "00060", - "filter": filter_expr, - "filter_lang": "cql-text", - } - raw_budget = _effective_filter_budget(args, filter_expr, _build_request) - - # Build a chunk exactly at the raw budget (padded with the clause repeated) - # and confirm the full URL it produces stays under the URL byte limit. - padded = (" OR ".join([filter_expr] * 200))[:raw_budget] - req = _construct_api_requests(**{**args, "filter": padded}) - assert len(req.url) <= _WATERDATA_URL_BYTE_LIMIT - # And the budget scales inversely with encoding ratio (sanity). - assert raw_budget < _WATERDATA_URL_BYTE_LIMIT - # Quick sanity on the encoding math itself. - assert len(quote_plus(padded)) <= _WATERDATA_URL_BYTE_LIMIT - - -def test_effective_filter_budget_uses_max_clause_ratio(): - """Heavy clauses clustered in one part of the filter must not be able - to push any chunk over the URL limit. The budget is computed against - the max per-clause encoding ratio, not the whole-filter average, so - a chunk of only-heaviest-clauses still fits.""" - from urllib.parse import quote_plus - - heavy = ( - "(time >= '2023-01-15T00:00:00Z' AND time <= '2023-01-15T00:30:00Z' " - "AND approval_status IN ('Approved','Provisional','Revised'))" - ) - light = "(time >= '2023-01-15T00:00:00Z' AND time <= '2023-01-15T00:30:00Z')" - # Heavy ratio < light ratio for these shapes; cluster them at opposite - # ends so the chunker must produce at least one light-only chunk. - clauses = [heavy] * 100 + [light] * 400 - expr = " OR ".join(clauses) - args = { - "service": "continuous", - "monitoring_location_id": "USGS-02238500", - "filter": expr, - "filter_lang": "cql-text", - } - budget = _effective_filter_budget(args, expr, _build_request) - chunks = _chunk_cql_or(expr, max_len=budget) - assert len(chunks) > 1 - - # Every chunk, once built into a full request, fits under the URL byte - # limit — even the all-light chunks that have a higher-than-average ratio. - for chunk in chunks: - req = _construct_api_requests(**{**args, "filter": chunk}) - assert len(req.url) <= _WATERDATA_URL_BYTE_LIMIT, ( - f"chunk url {len(req.url)} exceeds {_WATERDATA_URL_BYTE_LIMIT}" - ) - - # Budget should be tight enough that a chunk of only-light clauses - # (the heavier-encoding shape here) still fits. - assert len(quote_plus(light)) * (budget // len(light)) < _WATERDATA_URL_BYTE_LIMIT - - -def test_effective_filter_budget_passes_through_when_no_url_space(): - """If the non-filter URL already exceeds the byte limit, chunking - cannot make the request succeed. The budget helper should signal - pass-through (return a budget larger than the filter) so - ``_chunk_cql_or`` emits one chunk — one 414 from the server is - clearer than a burst of N guaranteed-414 sub-requests.""" - expr = " OR ".join( - ["(time >= '2023-01-15T00:00:00Z' AND time <= '2023-01-15T00:30:00Z')"] * 50 - ) - fake_build = mock.Mock( - return_value=_fake_prepared_request(url="https://example.test/" + "A" * 9000) - ) - budget = _effective_filter_budget({"filter": expr}, expr, fake_build) - # Budget is large enough that _chunk_cql_or returns the expression - # unchanged (passthrough) rather than producing many small chunks. - assert budget > len(expr) - assert _chunk_cql_or(expr, max_len=budget) == [expr] - - -def test_effective_filter_budget_shrinks_with_more_url_params(): - """Adding more scalar query params consumes URL bytes and should - shrink the raw filter budget accordingly. Use a filter large enough - to skip the short-circuit fast path so the probe actually runs.""" - clause = "(time >= '2023-01-15T00:00:00Z' AND time <= '2023-01-15T00:30:00Z')" - expr = " OR ".join([clause] * 100) - sparse_args = { - "service": "continuous", - "monitoring_location_id": "USGS-02238500", - "filter": expr, - "filter_lang": "cql-text", - } - dense_args = { - **sparse_args, - "parameter_code": "00060", - "statistic_id": "00003", - "last_modified": "2023-01-01T00:00:00Z/2023-12-31T23:59:59Z", - } - sparse_budget = _effective_filter_budget(sparse_args, expr, _build_request) - dense_budget = _effective_filter_budget(dense_args, expr, _build_request) - assert dense_budget < sparse_budget - - def test_cql_json_filter_is_not_chunked(): """Chunking applies only to cql-text; cql-json is passed through unchanged.""" from dataretrieval.waterdata import get_continuous diff --git a/tests/waterdata_test.py b/tests/waterdata_test.py index 18e78594..24eb6eff 100644 --- a/tests/waterdata_test.py +++ b/tests/waterdata_test.py @@ -49,7 +49,7 @@ reruns=2, reruns_delay=5, only_rerun=[ - r"RuntimeError:\s*(?:429|5\d\d):", # _raise_for_non_200 output + r"(?:RateLimited|RuntimeError):\s*(?:429|5\d\d):", # _raise_for_non_200 output r"ConnectionError", r"ReadTimeout|ConnectTimeout|Timeout", ], @@ -609,24 +609,20 @@ def test_get_channel(): class TestCheckMonitoringLocationId: - """Tests for _check_monitoring_location_id input validation. + """Tests for the AGENCY-ID-specific layer over ``_normalize_str_iterable``. + + Generic type/iterable normalization is covered by + ``TestNormalizeStrIterable`` below; this suite holds only the format + check (``AGENCY-NUMBER`` shape) and the public-API integration smokes. Regression tests for GitHub issue #188. """ def test_valid_string(self): - """A correctly formatted string passes and is returned unchanged.""" + """Happy-path smoke: the wrapper still routes through normalization + for a well-formed AGENCY-ID string.""" assert _check_monitoring_location_id("USGS-01646500") == "USGS-01646500" - def test_valid_list(self): - """A list of correctly formatted strings passes without error.""" - ids = ["USGS-01646500", "USGS-02238500"] - assert _check_monitoring_location_id(ids) == ids - - def test_none_passes(self): - """None is allowed (optional parameter).""" - assert _check_monitoring_location_id(None) is None - def test_integer_raises_type_error(self): """An integer ID raises TypeError with a helpful AGENCY-ID hint.""" with pytest.raises(TypeError, match="not int") as exc_info: @@ -635,11 +631,6 @@ def test_integer_raises_type_error(self): # helper alone doesn't carry. assert "USGS-01646500" in str(exc_info.value) - def test_integer_in_list_raises_type_error(self): - """An integer inside a list raises TypeError.""" - with pytest.raises(TypeError, match="not int"): - _check_monitoring_location_id(["USGS-01646500", 5129115]) - def test_missing_agency_prefix_raises_value_error(self): """A string without the AGENCY- prefix raises ValueError.""" with pytest.raises(ValueError, match="Invalid monitoring_location_id"): @@ -655,57 +646,19 @@ def test_get_daily_integer_id_raises(self): with pytest.raises(TypeError): get_daily(monitoring_location_id=5129115, parameter_code="00060") - def test_tuple_normalizes_to_list(self): - """A tuple of valid strings is accepted and normalized to list.""" - result = _check_monitoring_location_id(("USGS-01646500", "USGS-02238500")) - assert result == ["USGS-01646500", "USGS-02238500"] - assert isinstance(result, list) - - def test_pandas_series_normalizes_to_list(self): - """A pandas.Series of valid strings is accepted and normalized to list.""" - s = pd.Series(["USGS-01646500", "USGS-02238500"]) - result = _check_monitoring_location_id(s) - assert result == ["USGS-01646500", "USGS-02238500"] - assert isinstance(result, list) - - def test_pandas_index_normalizes_to_list(self): - """A pandas.Index of valid strings is accepted and normalized to list.""" - idx = pd.Index(["USGS-01646500", "USGS-02238500"]) - result = _check_monitoring_location_id(idx) - assert result == ["USGS-01646500", "USGS-02238500"] - assert isinstance(result, list) - - def test_numpy_array_normalizes_to_list(self): - """A numpy.ndarray of valid strings is accepted and normalized to list.""" - import numpy as np - - arr = np.array(["USGS-01646500", "USGS-02238500"]) - result = _check_monitoring_location_id(arr) - assert result == ["USGS-01646500", "USGS-02238500"] - assert isinstance(result, list) - - def test_numpy_int_array_raises_type_error(self): - """An iterable whose elements aren't strings (numpy int array) raises.""" - import numpy as np - - with pytest.raises(TypeError, match="elements must be strings"): - _check_monitoring_location_id(np.array([1, 2, 3])) - - def test_pandas_series_of_ints_raises_type_error(self): - """An iterable whose elements aren't strings (Series of ints) raises.""" - with pytest.raises(TypeError, match="elements must be strings"): - _check_monitoring_location_id(pd.Series([1, 2, 3])) - - def test_dict_raises_type_error(self): - """Mappings are rejected — iterating a dict yields keys, which is a footgun.""" - with pytest.raises(TypeError, match="not dict"): - _check_monitoring_location_id({"USGS-01646500": "site"}) - def test_get_daily_malformed_id_raises(self): """get_daily raises ValueError for a malformed string ID.""" with pytest.raises(ValueError): get_daily(monitoring_location_id="dog", parameter_code="00060") + def test_per_item_format_check_in_list(self): + """The AGENCY-ID format check runs on EVERY element of an + iterable, not just the first. Regression guard against a + future ``_check_id_format`` loop that bails after one valid + item or only checks the head.""" + with pytest.raises(ValueError, match="Invalid monitoring_location_id"): + _check_monitoring_location_id(["USGS-01646500", "badformat"]) + class TestNormalizeStrIterable: """Tests for the generic _normalize_str_iterable helper. diff --git a/tests/waterdata_utils_test.py b/tests/waterdata_utils_test.py index 78868c38..c135115c 100644 --- a/tests/waterdata_utils_test.py +++ b/tests/waterdata_utils_test.py @@ -1,3 +1,4 @@ +import json import logging from unittest import mock @@ -6,12 +7,15 @@ import requests import dataretrieval.waterdata.utils as _utils_module +from dataretrieval.waterdata.chunking import RateLimited, ServiceUnavailable from dataretrieval.waterdata.utils import ( _arrange_cols, _error_body, _format_api_dates, _get_args, _handle_stats_nesting, + _parse_retry_after, + _raise_for_non_200, _walk_pages, ) @@ -189,6 +193,110 @@ def test_walk_pages_raises_on_mid_pagination_429(): assert "rate-limit window" in msg # 429-specific guidance present +def test_walk_pages_wraps_initial_page_parse_error(): + """A 200 response whose body fails to parse on the FIRST page used + to escape ``_walk_pages`` as a raw ``JSONDecodeError``, while the + SAME failure on a subsequent page was wrapped via + ``_paginated_failure_message``. The asymmetry meant operators got + different exception types for the same logical bug depending on + which page hit it. The initial-parse wrapper closes the gap.""" + resp = mock.MagicMock() + resp.status_code = 200 + resp.url = "https://example.com/page1" + # Body is unparseable JSON (gateway HTML page, truncated stream). + resp.json.side_effect = json.JSONDecodeError("Expecting value", "...", 0) + + mock_client = mock.MagicMock(spec=requests.Session) + mock_client.send.return_value = resp + + mock_req = mock.MagicMock(spec=requests.PreparedRequest) + mock_req.method = "GET" + mock_req.headers = {} + mock_req.url = "https://example.com/page1" + + with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: + _walk_pages(geopd=False, req=mock_req, client=mock_client) + + # The JSONDecodeError causing it is on __cause__ so callers can drill in. + assert isinstance(excinfo.value.__cause__, json.JSONDecodeError) + + +def test_get_resp_data_handles_missing_features_key(): + """Regression: a 200 with ``numberReturned > 0`` but no + ``features`` key (real schema-drift shape) used to crash + ``_get_resp_data`` with ``KeyError`` — wrapped downstream by + ``_paginate`` as a generic transport error. ``_handle_stats_nesting`` + was already hardened against this; ``_get_resp_data`` now mirrors + that defensiveness and returns an empty frame instead.""" + from dataretrieval.waterdata.utils import _get_resp_data + + resp = mock.Mock() + resp.json.return_value = {"numberReturned": 1, "links": []} + df = _get_resp_data(resp, geopd=False) + assert df.empty + assert isinstance(df, pd.DataFrame) + + +def test_walk_pages_does_not_mutate_initial_response(): + """The aggregated response returned from ``_walk_pages`` is built + via ``_aggregate_paginated_response``, which returns a fresh copy. + Any caller that inspected ``initial_response.headers`` / + ``.elapsed`` before pagination completed (a Session response hook, + a logging middleware) must continue to see the original first-page + values — NOT the rewritten cumulative values.""" + import datetime as _dt + + page1 = mock.MagicMock() + page1.status_code = 200 + page1.url = "https://example.com/page1" + page1.elapsed = _dt.timedelta(seconds=1) + page1.headers = {"x-ratelimit-remaining": "999"} + page1.json.return_value = { + "numberReturned": 1, + "features": [{"id": "1", "properties": {"val": "a"}}], + "links": [{"rel": "next", "href": "https://example.com/page2"}], + } + page1_initial_headers_id = id(page1.headers) + page1_initial_elapsed = page1.elapsed + + page2 = mock.MagicMock() + page2.status_code = 200 + page2.url = "https://example.com/page2" + page2.elapsed = _dt.timedelta(seconds=2) + page2.headers = {"x-ratelimit-remaining": "998"} + page2.json.return_value = { + "numberReturned": 1, + "features": [{"id": "2", "properties": {"val": "b"}}], + "links": [], + } + + mock_client = mock.MagicMock(spec=requests.Session) + mock_client.send.return_value = page1 + mock_client.request.return_value = page2 + + mock_req = mock.MagicMock(spec=requests.PreparedRequest) + mock_req.method = "GET" + mock_req.headers = {} + mock_req.url = "https://example.com/page1" + + df, final = _walk_pages(geopd=False, req=mock_req, client=mock_client) + assert len(df) == 2 + + # The original first-page response object must be unmutated: + # both .headers (same dict object) and .elapsed unchanged. + assert id(page1.headers) == page1_initial_headers_id + assert page1.headers["x-ratelimit-remaining"] == "999" + assert page1.elapsed == page1_initial_elapsed + + # The returned aggregate carries page-2 headers + cumulative elapsed. + assert final.headers["x-ratelimit-remaining"] == "998" + assert final.elapsed == _dt.timedelta(seconds=3) + # And mutating the aggregate's headers doesn't leak into either page. + final.headers["X-Trace-Id"] = "abc" + assert "X-Trace-Id" not in page1.headers + assert "X-Trace-Id" not in page2.headers + + def _stats_initial_ok(): """A 200-OK initial stats response: empty data list, signals one more page.""" resp = mock.MagicMock() @@ -231,8 +339,12 @@ def _run_get_stats_data_with_failure(failure_resp_or_exc, monkeypatch): ) -def test_get_stats_data_raises_on_connection_error_mid_pagination(monkeypatch): - """get_stats_data variant of the connection-error-raises contract.""" +def test_get_stats_data_raises_on_mid_pagination_failure(monkeypatch): + """Wiring smoke: ``get_stats_data`` and ``_walk_pages`` share the + same ``_paginate`` strategy helper, so error-routing behaviour is + exercised by the ``_walk_pages`` triplet above. This single + ``get_stats_data`` mid-pagination case proves the stats-specific + follow-up callback is wired into ``_paginate`` correctly.""" with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: _run_get_stats_data_with_failure( requests.ConnectionError("stats-boom"), @@ -243,34 +355,6 @@ def test_get_stats_data_raises_on_connection_error_mid_pagination(monkeypatch): assert "stats-boom" in str(excinfo.value) -def test_get_stats_data_raises_on_5xx_mid_pagination(monkeypatch): - """get_stats_data variant of the 5xx-raises contract.""" - page2_503 = mock.MagicMock() - page2_503.status_code = 503 - page2_503.json.return_value = { - "code": "ServiceUnavailable", - "description": "upstream timeout", - } - page2_503.url = "https://example.com/stats?service=foo&next_token=tok2" - - with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: - _run_get_stats_data_with_failure(page2_503, monkeypatch) - - assert "503" in str(excinfo.value) or "ServiceUnavailable" in str(excinfo.value) - - -def test_get_stats_data_raises_on_mid_pagination_429(monkeypatch): - """get_stats_data variant of the 429-raises contract.""" - page2_429 = mock.MagicMock() - page2_429.status_code = 429 - page2_429.url = "https://example.com/stats?service=foo&next_token=tok2" - - with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: - _run_get_stats_data_with_failure(page2_429, monkeypatch) - - assert "429" in str(excinfo.value) - - def test_get_stats_data_warning_includes_next_token(caplog, monkeypatch): """The pagination-failure warning includes the next_token so operators can identify which page in the sequence failed. (Addresses Copilot's @@ -322,6 +406,87 @@ def test_handle_stats_nesting_tolerates_missing_drop_columns(): assert df["monitoring_location_id"].iloc[0] == "USGS-12345" +def test_handle_stats_nesting_returns_empty_on_empty_features(): + """A mid-pagination empty page ({\"features\": [], \"next\": }) + must not crash the downstream merge with + ``KeyError: 'monitoring_location_id'``. The function short- + circuits to an empty DataFrame so pagination can continue.""" + df = _handle_stats_nesting({"features": [], "next": None}, geopd=False) + assert df.empty + + +def test_handle_stats_nesting_empty_preserves_geopd_type(): + """When geopandas is available, the empty-features short-circuit + must return a ``GeoDataFrame`` rather than a plain ``DataFrame``. + Otherwise a subsequent ``pd.concat([empty, geo_page])`` downgrades + the final result to a plain ``DataFrame`` and strips geometry/CRS + — a real regression for geopd-installed users on stats queries + that hit an empty intermediate page.""" + # Monkeypatch a stub gpd into the utils module so the test runs + # whether or not geopandas is actually installed. + fake_gpd = mock.MagicMock() + + class _Sentinel: + pass + + fake_gpd.GeoDataFrame = lambda *a, **kw: _Sentinel() + with mock.patch.object(_utils_module, "gpd", fake_gpd, create=True): + result = _handle_stats_nesting({"features": []}, geopd=True) + assert isinstance(result, _Sentinel) + + +def test_get_resp_data_empty_preserves_geopd_type(): + """Same as the stats-side preservation: ``_get_resp_data``'s + ``numberReturned == 0`` short-circuit must return a + ``GeoDataFrame`` (not a plain ``DataFrame``) when geopd is True, + so paginating across a sparse intermediate page doesn't downgrade + the final concat result.""" + from dataretrieval.waterdata.utils import _get_resp_data + + fake_gpd = mock.MagicMock() + + class _Sentinel: + pass + + fake_gpd.GeoDataFrame = lambda *a, **kw: _Sentinel() + + resp = mock.MagicMock() + resp.json.return_value = {"numberReturned": 0, "features": [], "links": []} + with mock.patch.object(_utils_module, "gpd", fake_gpd, create=True): + result = _get_resp_data(resp, geopd=True) + assert isinstance(result, _Sentinel) + + +def test_handle_stats_nesting_tolerates_missing_features_key(): + """A 200 response with a body that doesn't carry ``features`` at + all (rare but seen in error envelopes) must also short-circuit + rather than KeyError before the schema-aware extraction even + runs.""" + df = _handle_stats_nesting({}, geopd=False) + assert df.empty + + +def test_get_resp_data_always_materializes_id_column(): + """``_get_resp_data`` must always materialize the ``id`` column + (NaN-filled when no feature carries one) so the downstream + ``_arrange_cols`` rename to the service-specific output_id + (``daily_id``, ``channel_measurements_id``, etc.) isn't a + silent no-op.""" + from dataretrieval.waterdata.utils import _get_resp_data + + resp = mock.MagicMock() + resp.json.return_value = { + "numberReturned": 2, + "features": [ + {"properties": {"val": "a"}}, # no top-level id + {"properties": {"val": "b"}}, # ditto + ], + } + df = _get_resp_data(resp, geopd=False) + assert "id" in df.columns + assert df["id"].isna().all() + + # --- _arrange_cols ---------------------------------------------------------- @@ -489,3 +654,79 @@ def test_error_body_still_parses_well_formed_json(): assert "400" in msg assert "BadRequest" in msg assert "missing parameter" in msg + + +def test_parse_retry_after_handles_none_and_empty(): + """Absent or empty header → ``None`` (no quota signal). The chunker + treats ``None`` as "fall back to my own retry policy," so this + branch must not return a misleading 0.""" + assert _parse_retry_after(None) is None + assert _parse_retry_after("") is None + assert _parse_retry_after(" ") is None + + +def test_parse_retry_after_parses_delta_seconds(): + """Integer and float forms of delta-seconds (the common shape USGS + sends) are parsed directly without touching the HTTP-date branch.""" + assert _parse_retry_after("120") == 120.0 + assert _parse_retry_after("0") == 0.0 + assert _parse_retry_after("42.5") == 42.5 + # Surrounding whitespace is stripped before parsing. + assert _parse_retry_after(" 30 ") == 30.0 + + +def test_parse_retry_after_clamps_negative_delta_to_zero(): + """A negative delta-seconds means the server is saying "retry now." + Returning the negative value would let callers pass it to + ``time.sleep`` and get a ``ValueError`` — clamp at the source.""" + assert _parse_retry_after("-10") == 0.0 + assert _parse_retry_after("-0.5") == 0.0 + + +def test_parse_retry_after_returns_none_for_unparseable(): + """Garbage values (including the RFC 1123 HTTP-date form that the + HTTP spec allows but USGS doesn't actually send) surface as + ``None``, letting the chunker fall back to its own retry policy + instead of guessing a delay.""" + assert _parse_retry_after("not-a-date") is None + assert _parse_retry_after("Wed, 21 Oct 2099 07:28:00 GMT") is None + + +def test_raise_for_non_200_raises_service_unavailable_for_5xx(): + """5xx must surface as the typed ``ServiceUnavailable`` (not bare + ``RuntimeError``) so the chunker can wrap it as a resumable + ``ServiceInterrupted`` rather than treating it as a fatal error.""" + resp = _make_response(503, "", reason="Service Unavailable") + resp.headers["Retry-After"] = "120" + with pytest.raises(ServiceUnavailable) as excinfo: + _raise_for_non_200(resp) + assert excinfo.value.retry_after == 120.0 + + +def test_raise_for_non_200_attaches_retry_after_to_rate_limited(): + """``Retry-After`` on a 429 response must travel onto + ``RateLimited.retry_after`` so the chunker can surface it on + ``QuotaExhausted.retry_after`` for callers to honor.""" + resp = _make_response(429, "", reason="Too Many Requests") + resp.headers["Retry-After"] = "60" + with pytest.raises(RateLimited) as excinfo: + _raise_for_non_200(resp) + assert excinfo.value.retry_after == 60.0 + + +def test_raise_for_non_200_still_raises_bare_runtimeerror_for_other_4xx(): + """4xx other than 429 (e.g. 400 Bad Request) is a programmer error + that retry won't fix. Must remain bare ``RuntimeError`` so the + chunker's classifier doesn't wrap it as resumable.""" + resp = _make_response( + 400, + '{"code": "BadRequest", "description": "missing parameter"}', + reason="Bad Request", + content_type="application/json", + ) + with pytest.raises(RuntimeError) as excinfo: + _raise_for_non_200(resp) + # Must be exactly RuntimeError — not RateLimited, not + # ServiceUnavailable. Both subclass RuntimeError, so a plain + # ``pytest.raises(RuntimeError)`` would match either. + assert type(excinfo.value) is RuntimeError