diff --git a/README.md b/README.md index f9fe422b..a468c652 100644 --- a/README.md +++ b/README.md @@ -109,22 +109,13 @@ Visit the [API Reference](https://doi-usgs.github.io/dataretrieval-python/reference/waterdata.html) for more information and examples on available services and input parameters. -**NEW:** This module implements -[logging](https://docs.python.org/3/howto/logging.html#logging-basic-tutorial) -so you can view the URL requests sent to the USGS Water Data APIs and the -number of requests remaining each hour. These messages can be helpful for -troubleshooting and support. To enable logging in your Python console or -notebook: +For verbose troubleshooting and support — including the request URL sent to the +API — enable debug-level +[logging](https://docs.python.org/3/howto/logging.html#logging-basic-tutorial): ```python import logging -logging.basicConfig(level=logging.INFO) -``` -To log messages to a file, you can specify a filename in the -`basicConfig` call: - -```python -logging.basicConfig(filename='waterdata.log', level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) ``` ### Water Quality Portal (WQP) diff --git a/dataretrieval/waterdata/_progress.py b/dataretrieval/waterdata/_progress.py new file mode 100644 index 00000000..7263d555 --- /dev/null +++ b/dataretrieval/waterdata/_progress.py @@ -0,0 +1,264 @@ +"""A single self-updating status line for paginated / chunked Water Data queries. + +Water Data getters fan out two ways the caller can't see: large multi-value +requests are split into URL-length-safe *chunks* (``chunking`` module), and each +request follows ``next`` links across an unknown number of *pages* +(``utils._paginate``). This module surfaces that work as one line on stderr, +rewritten in place as data arrives:: + + Retrieving: daily · 6 pages · 2,881 rows · 995/1,000 requests remaining + +It replaces the per-page ``logger.info`` calls that previously narrated the same +events one line at a time. + +The active reporter lives in a :class:`~contextvars.ContextVar` rather than being +threaded through every signature: progress is a cross-cutting concern that the +chunk orchestrator (outer, chunk counts) and the page-walking loop (inner, +page/row/rate-limit counts) both update without knowing about each other. Call +:func:`progress_context` to activate one and :func:`current` to reach it. + +By default the line is shown for interactive use — an interactive terminal or a +Jupyter/IPython kernel (like ``tqdm``) — while redirected logs and CI stay clean. +``API_USGS_PROGRESS`` forces it on (``1``/``true``) or off (``0``/``false``). +""" + +from __future__ import annotations + +import contextvars +import os +import sys +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TextIO + + +def _group_int(value: str) -> str: + """Comma-group a plain ASCII integer string; pass anything else through. + + (``str.isdigit`` alone is True for non-decimal unicode digits that ``int`` + rejects, hence the ``isascii`` guard.) + """ + return f"{int(value):,}" if value.isascii() and value.isdigit() else value + + +# The reporter active for the current query. A ContextVar (not a module global) +# so the chunk orchestrator and the page loop resolve to the same reporter +# within one query, and an unrelated query in another context can't clobber its +# state. (It does not give concurrent queries sharing one stderr separate +# lines — they would still interleave.) +_active: contextvars.ContextVar[ProgressReporter | None] = contextvars.ContextVar( + "waterdata_progress", default=None +) + +# Where to register for an API key. Surfaced once when a query runs without an +# API key configured (no API_USGS_PAT), since unauthenticated callers hit much +# lower rate limits (see the API_USGS_PAT note in the README). +SIGNUP_URL = "https://api.waterdata.usgs.gov/signup/" + +# Process-level latch so the "no API key" pointer is shown at most once. +_api_key_hint_shown = False + + +def _in_jupyter_kernel() -> bool: + """True when running inside a Jupyter/IPython *kernel* (notebook, lab, + qtconsole). + + A kernel's ``stderr`` isn't a TTY, but it honors carriage-return rewrites in + the cell output area — the same mechanism ``tqdm`` rides on — so the line is + worth showing there. The plain IPython terminal REPL is a + ``TerminalInteractiveShell`` (already a TTY), so only the ZMQ kernel needs + this extra signal. Detected without importing IPython: if it isn't already + imported, we aren't in a shell. + """ + ipython = sys.modules.get("IPython") + if ipython is None: + return False + shell = ipython.get_ipython() + return shell is not None and type(shell).__name__ == "ZMQInteractiveShell" + + +def _enabled_default(stream: TextIO) -> bool: + """Whether to draw the line by default. + + ``API_USGS_PROGRESS`` wins when set. Otherwise show it for interactive use — + a TTY or a Jupyter/IPython kernel — and stay quiet for redirected output, + logs, and CI. + """ + override = os.getenv("API_USGS_PROGRESS") + if override is not None: + return override.strip().lower() not in {"", "0", "false", "no", "off"} + if _in_jupyter_kernel(): + return True + return hasattr(stream, "isatty") and stream.isatty() + + +class ProgressReporter: + """Accumulates query progress and rewrites a single status line in place. + + Every update method is a no-op when the reporter is disabled, so call sites + need no ``if enabled`` guards. The line is redrawn with a leading carriage + return and padded to erase the previous (possibly longer) contents; + :meth:`close` terminates it with a newline so the final state persists. + """ + + def __init__( + self, + *, + service: str | None = None, + stream: TextIO | None = None, + enabled: bool | None = None, + ) -> None: + self._stream = stream if stream is not None else sys.stderr + self.enabled = _enabled_default(self._stream) if enabled is None else enabled + # The service/collection being retrieved (e.g. "daily", "peaks"), + # shown as the line's leading label. + self.service = service + self.total_chunks = 1 + self.current_chunk = 0 + self.pages = 0 + self.rows = 0 + self.rate_remaining: str | None = None + # The hourly request quota (``x-ratelimit-limit``), shown as the + # denominator when the server reports it. + self.rate_limit: str | None = None + self._last_len = 0 + # Whether anything was actually written to the stream — drives whether + # close() needs a terminating newline. (``current_chunk`` is a poor + # proxy: ``start_chunk`` sets it even when it doesn't render.) + self._rendered = False + self._closed = False + + def set_chunks(self, total: int) -> None: + """Record how many filter chunks this query was split into.""" + self.total_chunks = max(int(total), 1) + + def start_chunk(self, index: int) -> None: + """Mark the start of chunk ``index`` (1-based) and redraw. + + Only redraws when actually chunking (``total_chunks > 1``); a + single-chunk plan has nothing chunk-specific to show yet, so it + avoids a premature "0 pages" frame before the first page arrives. + """ + self.current_chunk = index + if self.total_chunks > 1: + self._render() + + def add_page(self, rows: int = 0) -> None: + """Record one fetched page carrying ``rows`` rows and redraw.""" + self.pages += 1 + self.rows += int(rows) + self._render() + + def set_rate_remaining( + self, value: str | int | None, limit: str | int | None = None + ) -> None: + """Update the rate-limit display from the response headers. + + ``value`` is ``x-ratelimit-remaining``; ``limit`` is the optional + ``x-ratelimit-limit`` quota, shown as the denominator. Empty/missing + values are ignored so a page that omits a header doesn't blank out the + last known value. + """ + if value not in (None, ""): + self.rate_remaining = str(value) + if limit not in (None, ""): + self.rate_limit = str(limit) + + def _format(self) -> str: + parts: list[str] = [] + if self.total_chunks > 1: + parts.append(f"chunk {self.current_chunk}/{self.total_chunks}") + parts.append(f"{self.pages} page" + ("" if self.pages == 1 else "s")) + if self.rows: + parts.append(f"{self.rows:,} rows") + if self.rate_remaining is not None: + remaining = _group_int(self.rate_remaining) + if self.rate_limit is not None: + limit = _group_int(self.rate_limit) + segment = f"{remaining}/{limit} requests remaining" + else: + segment = f"{remaining} requests remaining" + parts.append(segment) + if self.service: + return f"Retrieving: {self.service} · " + " · ".join(parts) + return "Progress: " + " · ".join(parts) + + def _render(self) -> None: + if not self.enabled or self._closed: + return + try: + line = self._format() + pad = max(self._last_len - len(line), 0) + self._stream.write("\r" + line + " " * pad) + self._stream.flush() + self._last_len = len(line) + self._rendered = True + except Exception: # noqa: BLE001 + # Progress output is best-effort cosmetics; a broken pipe (output + # piped to ``head``), a closed stream, or an encoding error must + # never disturb — let alone truncate — the query. Disable so we + # don't retry on every subsequent page. + self.enabled = False + + def close(self) -> None: + """Finalize the line with a trailing newline so it persists on screen. + + If no API key is configured (no ``API_USGS_PAT``), append a one-time + pointer to API-key registration, since unauthenticated callers hit much + lower rate limits. + """ + if self._closed: + return + self._closed = True + if not (self.enabled and self._rendered): + return + try: + self._stream.write("\n") + self._maybe_hint_api_key() + self._stream.flush() + except Exception: # noqa: BLE001 + self.enabled = False + + def _maybe_hint_api_key(self) -> None: + global _api_key_hint_shown + if _api_key_hint_shown or os.getenv("API_USGS_PAT"): + return + # Set the once-per-process latch only after a successful write, so a + # failed write (broken pipe) doesn't silently burn the hint for every + # later query in the process. + self._stream.write( + f"No API key detected — register for higher rate limits at {SIGNUP_URL}\n" + ) + _api_key_hint_shown = True + + +@contextmanager +def progress_context( + *, + service: str | None = None, + stream: TextIO | None = None, + enabled: bool | None = None, +) -> Iterator[ProgressReporter]: + """Activate a :class:`ProgressReporter` for the duration of a query. + + ``service`` labels the line (e.g. ``"Retrieving: daily ..."``). If a reporter + is already active (a nested call), the existing one is yielded unchanged so + the outermost query owns the single line; only the outermost context closes + it (and ``service``/``stream``/``enabled`` of a nested call are ignored). + """ + existing = _active.get() + if existing is not None: + yield existing + return + reporter = ProgressReporter(service=service, stream=stream, enabled=enabled) + token = _active.set(reporter) + try: + yield reporter + finally: + _active.reset(token) + reporter.close() + + +def current() -> ProgressReporter | None: + """Return the reporter active for the current query, or ``None``.""" + return _active.get() diff --git a/dataretrieval/waterdata/api.py b/dataretrieval/waterdata/api.py index 106501dd..6f24d80f 100644 --- a/dataretrieval/waterdata/api.py +++ b/dataretrieval/waterdata/api.py @@ -2338,7 +2338,7 @@ def get_samples( req = PreparedRequest() req.prepare_url(url, params=params) - logger.info("Request: %s", req.url) + logger.debug("Request: %s", req.url) response = requests.get( url, params=params, verify=ssl_check, headers=_default_headers() @@ -2410,7 +2410,7 @@ def get_samples_summary( req = PreparedRequest() req.prepare_url(url, params=params) - logger.info("Request: %s", req.url) + logger.debug("Request: %s", req.url) response = requests.get( url, params=params, verify=ssl_check, headers=_default_headers() diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index c6eb9945..a6fee155 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -49,6 +49,7 @@ import requests from requests.structures import CaseInsensitiveDict +from . import _progress from .filters import ( _check_numeric_filter_pitfall, _is_chunkable, @@ -1126,10 +1127,15 @@ def resume(self) -> tuple[pd.DataFrame, requests.Response]: (checked after the first sub-request). """ with requests.Session() as session, _publish_session(session): + reporter = _progress.current() + if reporter is not None: + reporter.set_chunks(self.plan.total) completed = len(self._chunks) for i, sub_args in enumerate(self.plan.iter_sub_args()): if i < completed: continue + if reporter is not None: + reporter.start_chunk(i + 1) self._issue(sub_args) frames = [frame for frame, _ in self._chunks] responses = [resp for _, resp in self._chunks] diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index 58d4673d..dd908143 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -17,7 +17,7 @@ from dataretrieval import __version__ from dataretrieval.utils import BaseMetadata -from dataretrieval.waterdata import chunking +from dataretrieval.waterdata import _progress, chunking from dataretrieval.waterdata.chunking import ( _QUOTA_HEADER, RateLimited, @@ -40,6 +40,15 @@ # Set up logger for this module logger = logging.getLogger(__name__) +# Whether geopandas is present is a static, environment-level fact, so warn once +# here at import time rather than per query/chunk. That avoids the warning +# repeating on every call and avoids it interleaving with the progress line's +# carriage-return rewrites. +if not GEOPANDAS: + logger.warning( + "Geopandas not installed. Geometries will be flattened into pandas DataFrames." + ) + BASE_URL = "https://api.waterdata.usgs.gov" OGC_API_VERSION = "v0" OGC_API_URL = f"{BASE_URL}/ogcapi/{OGC_API_VERSION}" @@ -657,9 +666,7 @@ def _next_req_url( Notes ----- - - If the environment variable "API_USGS_PAT" is set, logs the remaining - requests for the current hour. - - Logs the next URL if found at info level. + - Returns None when the response carries no features. - Expects the response JSON to contain a "links" list with objects having "rel" and "href" keys. - Checks for the "next" relation in the "links" to determine the next URL. @@ -668,17 +675,9 @@ def _next_req_url( body = resp.json() if not body.get("numberReturned"): return None - header_info = resp.headers - if os.getenv("API_USGS_PAT", ""): - logger.info( - "Remaining requests this hour: %s", - header_info.get(_QUOTA_HEADER, ""), - ) for link in body.get("links", []): if link.get("rel") == "next": - next_url = link.get("href") - logger.info("Next URL: %s", next_url) - return next_url + return link.get("href") return None @@ -855,7 +854,6 @@ def _aggregate_paginated_response( def _paginate( initial_req: requests.PreparedRequest, *, - geopd: bool, parse_response: Callable[[requests.Response], tuple[pd.DataFrame, _Cursor | None]], follow_up: Callable[[_Cursor, requests.Session], requests.Response], client: requests.Session | None = None, @@ -874,10 +872,6 @@ def _paginate( ---------- initial_req : requests.PreparedRequest First-page request to send. - geopd : bool - Whether ``geopandas`` is available — logged once at WARNING - level when ``False`` (matches historical behavior of both - callers). parse_response : callable ``resp -> (df, next_cursor_or_None)``. Returns the page's DataFrame and the cursor (URL, token, …) used to drive @@ -918,13 +912,8 @@ def _paginate( callers can branch on the specific type; equivalent failures on subsequent pages are wrapped per above. """ - logger.info("Requesting: %s", initial_req.url) - if not geopd: - logger.warning( - "Geopandas not installed. Geometries will be flattened " - "into pandas DataFrames." - ) - + logger.debug("Requesting: %s", initial_req.url) + reporter = _progress.current() with _session(client) as sess: resp = sess.send(initial_req) _raise_for_non_200(resp) @@ -944,6 +933,12 @@ def _paginate( logger.warning("Initial response parse failed.") raise RuntimeError(_paginated_failure_message(0, e)) from e dfs = [df] + if reporter is not None: + reporter.set_rate_remaining( + resp.headers.get(_QUOTA_HEADER), + limit=resp.headers.get("x-ratelimit-limit"), + ) + reporter.add_page(rows=len(df)) while cursor is not None: try: resp = follow_up(cursor, sess) @@ -951,6 +946,12 @@ def _paginate( df, cursor = parse_response(resp) dfs.append(df) total_elapsed += resp.elapsed + if reporter is not None: + reporter.set_rate_remaining( + resp.headers.get(_QUOTA_HEADER), + limit=resp.headers.get("x-ratelimit-limit"), + ) + reporter.add_page(rows=len(df)) except Exception as e: # noqa: BLE001 logger.warning( "Request failed at cursor %r. Data download interrupted.", @@ -1027,7 +1028,6 @@ def follow_up(cursor: str, sess: requests.Session) -> requests.Response: return _paginate( req, - geopd=geopd, parse_response=parse_response, follow_up=follow_up, client=client, @@ -1244,7 +1244,8 @@ def get_ogc_data( convert_type = args.pop("convert_type", False) args = {k: v for k, v in args.items() if v is not None} - return_list, response = _fetch_once(args) + with _progress.progress_context(service=service): + return_list, response = _fetch_once(args) return_list = _deal_with_empty(return_list, properties, service) if convert_type: return_list = _type_cols(return_list) @@ -1317,9 +1318,8 @@ def _handle_stats_nesting( if not features: return gpd.GeoDataFrame() if geopd else pd.DataFrame() - # The geopd-missing warning is emitted once at the top of - # ``get_stats_data`` (parallel to ``_walk_pages``); doing it here - # would log per page. + # The geopd-missing warning is emitted once at import (see top of module); + # doing it here would log per page. if not geopd: outer_props = [ {k: v for k, v in (f.get("properties") or {}).items() if k != "data"} @@ -1488,13 +1488,15 @@ def follow_up(cursor: str, sess: requests.Session) -> requests.Response: method, url=url, params={**args, "next_token": cursor}, headers=headers ) - df, response = _paginate( - req, - geopd=GEOPANDAS, - parse_response=parse_response, - follow_up=follow_up, - client=client, - ) + # The stats path doesn't go through ``multi_value_chunked``, so it opens + # its own progress context; ``_paginate`` reports pages/rate-limit into it. + with _progress.progress_context(service=service): + df, response = _paginate( + req, + parse_response=parse_response, + follow_up=follow_up, + client=client, + ) if expand_percentiles: df = _expand_percentiles(df) diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py new file mode 100644 index 00000000..14a98839 --- /dev/null +++ b/tests/waterdata_progress_test.py @@ -0,0 +1,365 @@ +"""Tests for the Water Data single-line progress reporter. + +Covers ProgressReporter rendering / no-op behavior, TTY + environment-variable +gating, progress_context nesting, and that the pagination loop in +``_walk_pages`` reports pages and the rate-limit header through an active +reporter. +""" + +import io +import sys +import types +from unittest import mock + +import pytest +import requests + +from dataretrieval.waterdata import _progress +from dataretrieval.waterdata._progress import ( + ProgressReporter, + current, + progress_context, +) +from dataretrieval.waterdata.utils import _walk_pages + + +@pytest.fixture(autouse=True) +def _reset_api_key_hint_latch(monkeypatch): + """The 'no API key' pointer is latched once per process; reset it so each + test sees a clean slate regardless of order.""" + monkeypatch.setattr(_progress, "_api_key_hint_shown", False) + + +# -- ProgressReporter rendering ------------------------------------------------ + + +def test_disabled_reporter_writes_nothing(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=False) + reporter.set_chunks(3) + reporter.start_chunk(1) + reporter.add_page(rows=5) + reporter.set_rate_remaining("100") + reporter.close() + assert stream.getvalue() == "" + + +def test_renders_pages_rows_and_rate_limit(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.set_rate_remaining("4870") + reporter.add_page(rows=1234) + out = stream.getvalue() + assert out.lstrip("\r").startswith("Progress: ") + assert "1 page" in out + assert "1,234 rows" in out + assert "4,870 requests remaining" in out + + +def test_page_count_is_pluralized(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.add_page() + assert "1 page" in stream.getvalue() and "1 pages" not in stream.getvalue() + reporter.add_page() + assert "2 pages" in stream.getvalue() + + +def test_chunk_segment_only_shown_when_multiple_chunks(): + single = io.StringIO() + reporter = ProgressReporter(stream=single, enabled=True) + reporter.set_chunks(1) + reporter.add_page() + assert "chunk" not in single.getvalue() + + many = io.StringIO() + reporter = ProgressReporter(stream=many, enabled=True) + reporter.set_chunks(5) + reporter.start_chunk(2) + assert "chunk 2/5" in many.getvalue() + + +def test_missing_rate_limit_does_not_blank_last_known_value(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.set_rate_remaining("500") + reporter.set_rate_remaining(None) + reporter.set_rate_remaining("") + reporter.add_page() + assert "500 requests remaining" in stream.getvalue() + + +def test_renders_remaining_over_limit(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.set_rate_remaining("952", limit="1000") + reporter.add_page(rows=1) + assert "952/1,000 requests remaining" in stream.getvalue() + + +def test_no_slash_when_limit_absent(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.set_rate_remaining("4870") # remaining only, no limit header + reporter.add_page() + out = stream.getvalue() + assert "4,870 requests remaining" in out + assert "/" not in out + + +def test_service_label_leads_the_line(): + stream = io.StringIO() + reporter = ProgressReporter(service="daily", stream=stream, enabled=True) + reporter.add_page(rows=5) + assert stream.getvalue().lstrip("\r").startswith("Retrieving: daily · ") + + +def test_close_terminates_active_line_with_newline(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.add_page() + reporter.close() + assert stream.getvalue().endswith("\n") + + +def test_close_without_activity_writes_nothing(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.close() + assert stream.getvalue() == "" + + +class _RaisingStream: + """A stream whose writes always fail, e.g. a broken pipe (output | head).""" + + def write(self, *_): + raise BrokenPipeError("broken pipe") + + def flush(self): + pass + + +def test_reporter_swallows_stream_errors_and_disables(monkeypatch): + monkeypatch.delenv("API_USGS_PAT", raising=False) + reporter = ProgressReporter(stream=_RaisingStream(), enabled=True) + reporter.add_page(rows=1) # render write raises -> must be swallowed + reporter.close() # newline + hint writes raise -> must be swallowed + assert reporter.enabled is False + + +# -- API-key pointer ----------------------------------------------------------- + + +def test_hints_api_key_when_no_key_configured(monkeypatch): + monkeypatch.delenv("API_USGS_PAT", raising=False) + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.add_page(rows=5) + reporter.close() + assert _progress.SIGNUP_URL in stream.getvalue() + + +def test_hint_fires_even_when_rate_limit_was_seen(monkeypatch): + # Anonymous responses still carry a rate-limit header, so absence of a key + # — not absence of the header — is what drives the pointer. + monkeypatch.delenv("API_USGS_PAT", raising=False) + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.set_rate_remaining("891") + reporter.add_page(rows=5) + reporter.close() + assert _progress.SIGNUP_URL in stream.getvalue() + + +def test_no_hint_when_api_key_present(monkeypatch): + monkeypatch.setenv("API_USGS_PAT", "secret") + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.add_page(rows=5) # no rate-limit, but a key is configured + reporter.close() + assert _progress.SIGNUP_URL not in stream.getvalue() + + +def test_no_hint_when_disabled(monkeypatch): + monkeypatch.delenv("API_USGS_PAT", raising=False) + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=False) + reporter.add_page(rows=5) + reporter.close() + assert stream.getvalue() == "" + + +def test_api_key_hint_shown_at_most_once(monkeypatch): + monkeypatch.delenv("API_USGS_PAT", raising=False) + + first = io.StringIO() + r1 = ProgressReporter(stream=first, enabled=True) + r1.add_page(rows=5) + r1.close() + assert _progress.SIGNUP_URL in first.getvalue() + + second = io.StringIO() + r2 = ProgressReporter(stream=second, enabled=True) + r2.add_page(rows=5) + r2.close() + assert _progress.SIGNUP_URL not in second.getvalue() + + +# -- enable/disable gating ----------------------------------------------------- + + +def test_default_disabled_for_non_tty(monkeypatch): + monkeypatch.delenv("API_USGS_PROGRESS", raising=False) + monkeypatch.setattr(_progress, "_in_jupyter_kernel", lambda: False) + # io.StringIO.isatty() returns False. + assert ProgressReporter(stream=io.StringIO()).enabled is False + + +def test_env_var_forces_on(monkeypatch): + monkeypatch.setenv("API_USGS_PROGRESS", "1") + assert ProgressReporter(stream=io.StringIO()).enabled is True + + +def test_env_var_forces_off_even_on_tty(monkeypatch): + monkeypatch.setenv("API_USGS_PROGRESS", "0") + tty = mock.MagicMock() + tty.isatty.return_value = True + assert ProgressReporter(stream=tty).enabled is False + + +def _fake_ipython(shell_class_name): + """A stand-in IPython module whose get_ipython() returns a shell of the + given class name (e.g. 'ZMQInteractiveShell' for a Jupyter kernel).""" + shell = type(shell_class_name, (), {})() + return types.SimpleNamespace(get_ipython=lambda: shell) + + +def test_enabled_in_jupyter_kernel(monkeypatch): + # A Jupyter kernel's stderr isn't a TTY, but the line should still show + # (it honors \r in the cell output, like tqdm). + monkeypatch.delenv("API_USGS_PROGRESS", raising=False) + monkeypatch.setitem(sys.modules, "IPython", _fake_ipython("ZMQInteractiveShell")) + assert ProgressReporter(stream=io.StringIO()).enabled is True + + +def test_terminal_ipython_without_tty_stays_disabled(monkeypatch): + # The terminal REPL is its own TTY; the kernel signal must not force the + # line on for a non-TTY (e.g. redirected) stream. + monkeypatch.delenv("API_USGS_PROGRESS", raising=False) + monkeypatch.setitem( + sys.modules, "IPython", _fake_ipython("TerminalInteractiveShell") + ) + assert ProgressReporter(stream=io.StringIO()).enabled is False + + +def test_env_var_off_overrides_jupyter_kernel(monkeypatch): + monkeypatch.setenv("API_USGS_PROGRESS", "0") + monkeypatch.setitem(sys.modules, "IPython", _fake_ipython("ZMQInteractiveShell")) + assert ProgressReporter(stream=io.StringIO()).enabled is False + + +# -- progress_context ---------------------------------------------------------- + + +def test_progress_context_sets_and_clears_current(monkeypatch): + monkeypatch.delenv("API_USGS_PROGRESS", raising=False) + assert current() is None + with progress_context(enabled=False) as reporter: + assert current() is reporter + assert current() is None + + +def test_nested_context_reuses_outer_reporter(): + with progress_context(enabled=False) as outer: + with progress_context(enabled=False) as inner: + assert inner is outer + # Inner exit must not deactivate the outer reporter. + assert current() is outer + assert current() is None + + +# -- integration with _walk_pages --------------------------------------------- + + +def _resp(features, *, next_url=None, rate_remaining=None): + resp = mock.MagicMock() + links = [{"rel": "next", "href": next_url}] if next_url else [] + resp.json.return_value = { + "numberReturned": len(features), + "features": features, + "links": links, + } + headers = {} + if rate_remaining is not None: + headers["x-ratelimit-remaining"] = rate_remaining + resp.headers = headers + resp.status_code = 200 + return resp + + +def test_walk_pages_reports_pages_and_rate_limit(): + resp1 = _resp( + [{"id": "1", "properties": {"v": "a"}}], + next_url="https://example.com/p2", + rate_remaining="4999", + ) + resp2 = _resp([{"id": "2", "properties": {"v": "b"}}], rate_remaining="4998") + + client = mock.MagicMock(spec=requests.Session) + client.send.return_value = resp1 + client.request.return_value = resp2 + + req = mock.MagicMock(spec=requests.PreparedRequest) + req.method = "GET" + req.headers = {} + req.url = "https://example.com/p1" + + stream = io.StringIO() + with progress_context(service="daily", stream=stream, enabled=True): + df, _ = _walk_pages(geopd=False, req=req, client=client) + + assert len(df) == 2 + out = stream.getvalue() + # The service set on the context reaches _paginate's render via the contextvar. + assert "Retrieving: daily ·" in out + assert "2 pages" in out + assert "4,998 requests remaining" in out + assert out.endswith("\n") + + +def test_walk_pages_without_context_does_not_error(): + # No active reporter: pagination must still work and stay silent. + resp = _resp([{"id": "1", "properties": {"v": "a"}}]) + client = mock.MagicMock(spec=requests.Session) + client.send.return_value = resp + + req = mock.MagicMock(spec=requests.PreparedRequest) + req.method = "GET" + req.headers = {} + req.url = "https://example.com/p1" + + df, _ = _walk_pages(geopd=False, req=req, client=client) + assert len(df) == 1 + assert current() is None + + +def test_broken_progress_stream_does_not_truncate_pagination(): + # A render failure (broken pipe) lands inside _walk_pages' per-page try; + # it must NOT be mistaken for a failed request and silently drop pages. + resp1 = _resp( + [{"id": "1", "properties": {"v": "a"}}], next_url="https://example.com/p2" + ) + resp2 = _resp([{"id": "2", "properties": {"v": "b"}}]) + client = mock.MagicMock(spec=requests.Session) + client.send.return_value = resp1 + client.request.return_value = resp2 + + req = mock.MagicMock(spec=requests.PreparedRequest) + req.method = "GET" + req.headers = {} + req.url = "https://example.com/p1" + + with progress_context(stream=_RaisingStream(), enabled=True): + df, _ = _walk_pages(geopd=False, req=req, client=client) + + assert len(df) == 2 # both pages returned despite the broken progress stream