diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index cbd28f8b0a3f..e46e66bd11d3 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -49,6 +49,7 @@ from apache_beam.portability.api import beam_fn_api_pb2_grpc from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor +from apache_beam.utils.byte_limited_queue import ByteLimitedQueue if TYPE_CHECKING: import apache_beam.coders.slow_stream @@ -455,10 +456,20 @@ class _GrpcDataChannel(DataChannel): def __init__(self, data_buffer_time_limit_ms=0): # type: (int) -> None + def _element_weight(element): + if isinstance(element, beam_fn_api_pb2.Elements.Data): + return len(element.data) + elif isinstance(element, beam_fn_api_pb2.Elements.Timers): + return len(element.timers) + return 0 + self._data_buffer_time_limit_ms = data_buffer_time_limit_ms - self._to_send = queue.Queue() # type: queue.Queue[DataOrTimers] + self._to_send = ByteLimitedQueue( + maxsize=10000, maxweight=100 << 20, + weighing_fn=_element_weight) # type: queue.Queue[DataOrTimers] self._received = collections.defaultdict( - lambda: queue.Queue(maxsize=5) + lambda: ByteLimitedQueue( + maxsize=5, maxweight=100 << 20, weighing_fn=_element_weight) ) # type: DefaultDict[str, queue.Queue[DataOrTimers]] # Keep a cache of completed instructions. Data for completed instructions diff --git a/sdks/python/apache_beam/utils/byte_limited_queue.py b/sdks/python/apache_beam/utils/byte_limited_queue.py new file mode 100644 index 000000000000..9bb4d469ae2a --- /dev/null +++ b/sdks/python/apache_beam/utils/byte_limited_queue.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
# See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A thread-safe queue that limits capacity by total byte size."""

import queue
import time
from typing import Any
from typing import Callable


class ByteLimitedQueue(queue.Queue):
  """A queue.Queue that limits by both element count and total weight.

  A single element is allowed to exceed the maxweight to avoid deadlock: an
  empty queue always accepts the next item, whatever its weight.

  On Python versions that provide ``Queue.shutdown`` (3.13+), ``put`` honours
  the shutdown state the same way ``queue.Queue.put`` does: it raises
  ``queue.ShutDown`` instead of enqueueing, including for producers that were
  blocked waiting for room when the queue was shut down.
  """
  def __init__(
      self,
      weighing_fn,  # type: Callable[[Any], int]
      maxsize=0,  # type: int
      maxweight=0,  # type: int
  ):
    # type: (...) -> None

    """Initializes a ByteLimitedQueue.

    Args:
      weighing_fn: A Callable that accepts an item and returns its integer
        weight. It is called exactly once per item, at put time; the result
        is stored alongside the item so that removal subtracts the same value.
      maxsize: The maximum number of items allowed in the queue. If 0 or
        negative, there is no limit on the number of elements.
      maxweight: The maximum accumulated weight allowed in the queue. If 0 or
        negative, there is no limit on the accumulated weight.
    """
    # The base class limit is disabled (maxsize=0); both limits are enforced
    # by this class in _is_full().
    super().__init__(maxsize=0)
    self.max_elements = maxsize
    self.max_weight = maxweight
    self.weighing_fn = weighing_fn
    self._byte_size = 0

  def _is_full(self, item_size):
    """Returns True if an item of weight item_size cannot be added right now.

    Must be called with the queue mutex held.
    """
    # An empty queue always admits one item, even an oversized one.
    if self._qsize() == 0:
      return False
    if self.max_elements > 0 and self._qsize() >= self.max_elements:
      return True
    if self.max_weight > 0 and self._byte_size + item_size > self.max_weight:
      return True
    return False

  def _raise_if_shut_down(self):
    """Raises queue.ShutDown if the queue has been shut down.

    Older Python versions have neither the ``is_shutdown`` attribute nor
    ``queue.ShutDown``; there this is a no-op and ``queue.ShutDown`` is never
    evaluated.
    """
    if getattr(self, 'is_shutdown', False):
      raise queue.ShutDown

  def put(self, item, block=True, timeout=None):
    """Puts an item into the queue, waiting for room if requested.

    Args:
      item: The element to enqueue.
      block: If False, raise queue.Full immediately when there is no room.
      timeout: With block=True, the maximum number of seconds to wait for
        room; None waits indefinitely.

    Raises:
      queue.Full: If there is no room (non-blocking) or the timeout expired.
      ValueError: If timeout is negative.
      queue.ShutDown: If the queue has been shut down (Python 3.13+).
    """
    # Weigh outside the lock; every item counts for at least 1.
    item_size = max(1, self.weighing_fn(item))
    with self.not_full:
      self._raise_if_shut_down()
      if not block:
        if self._is_full(item_size):
          raise queue.Full
      elif timeout is None:
        while self._is_full(item_size):
          self.not_full.wait()
          # shutdown() wakes all waiters; do not enqueue after it.
          self._raise_if_shut_down()
      elif timeout < 0:
        raise ValueError("'timeout' must be a non-negative number")
      else:
        endtime = time.monotonic() + timeout
        while self._is_full(item_size):
          remaining = endtime - time.monotonic()
          if remaining <= 0.0:
            raise queue.Full
          self.not_full.wait(remaining)
          self._raise_if_shut_down()

      # Store the weight with the item so _get() subtracts exactly what was
      # added, even if weighing_fn is not deterministic.
      self._put((item, item_size))
      self._byte_size += item_size
      self.unfinished_tasks += 1
      self.not_empty.notify()

  def _get(self):
    """Removes the next (item, weight) pair and returns only the item."""
    item, item_weight = super()._get()
    self._byte_size -= item_weight
    return item

  def full(self):
    """Returns True if even a minimum-weight item could not be added now.

    The inherited implementation only consults the base class ``maxsize``,
    which is always 0 here, so it would report False unconditionally. As with
    queue.Queue.full(), the answer may be stale by the time the caller acts
    on it.
    """
    with self.mutex:
      return self._is_full(1)

  def byte_size(self):
    """Return the total byte weight of elements in the queue."""
    with self.mutex:
      return self._byte_size


# ---------------------------------------------------------------------------
# diff --git a/sdks/python/apache_beam/utils/byte_limited_queue_test.py b/sdks/python/apache_beam/utils/byte_limited_queue_test.py
# new file mode 100644
# index 000000000000..e6349f3af00b
# --- /dev/null
# +++ b/sdks/python/apache_beam/utils/byte_limited_queue_test.py
# @@ -0,0 +1,168 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for byte-limited queue."""

import queue
import sys
import threading
import time
import unittest

from apache_beam.utils.byte_limited_queue import ByteLimitedQueue


class FakeItem:
  """A test element whose weight is fixed at construction time."""
  def __init__(self, size):
    self._size = size

  def weight(self):
    return self._size


class ByteLimitedQueueTest(unittest.TestCase):
  """Tests for the count and weight limits enforced by ByteLimitedQueue."""
  def test_unbounded(self):
    bq = ByteLimitedQueue(lambda x: x.weight())
    for i in range(200):
      bq.put(FakeItem(i))
    # Add 1 since weight of zero is set to 1
    self.assertEqual(bq.byte_size(), sum(range(200)) + 1)
    self.assertEqual(bq.qsize(), 200)

  def test_put_and_get(self):
    bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=200)
    bq.put(FakeItem(50))
    bq.put(FakeItem(140))
    self.assertEqual(bq.byte_size(), 190)
    self.assertEqual(bq.qsize(), 2)
    # Putting another 20 would exceed 200.
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(20), block=False)
    # Exactly reaching the limit is allowed.
    bq.put(FakeItem(10), block=False)
    self.assertEqual(bq.byte_size(), 200)
    self.assertEqual(bq.qsize(), 3)

    self.assertEqual(bq.get().weight(), 50)
    self.assertEqual(bq.byte_size(), 150)
    self.assertEqual(bq.qsize(), 2)
    bq.put(FakeItem(20), block=False)

  def test_dual_limit(self):
    # Queue limits: at most 3 items, OR at most 100 weight.
    bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=3, maxweight=100)
    bq.put(FakeItem(30))
    bq.put(FakeItem(40))
    bq.put(FakeItem(20))
    self.assertEqual(bq.byte_size(), 90)
    self.assertEqual(bq.qsize(), 3)
    # Full on element count (maxsize=3), even though the weight would fit.
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(10), block=False)
    self.assertEqual(bq.get().weight(), 30)
    self.assertEqual(bq.get().weight(), 40)
    bq.put(FakeItem(10))
    # Full on byte count
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(90), block=False)
    self.assertEqual(bq.get().weight(), 20)
    bq.put(FakeItem(90), block=False)

  @unittest.skipIf(sys.version_info < (3, 13), 'Queue.ShutDown added in 3.13.')
  def test_multithreading(self):
    bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=0, maxweight=100)
    received = []

    def producer():
      for i in range(101):
        bq.put(FakeItem(i))

    def consumer():
      while True:
        try:
          received.append(bq.get().weight())
        except queue.ShutDown:
          break

    t1 = threading.Thread(target=producer)
    t2 = threading.Thread(target=producer)
    t3 = threading.Thread(target=consumer)

    t1.start()
    t2.start()
    t3.start()

    # Shut down only after both producers have finished putting.
    t1.join()
    t2.join()
    bq.shutdown()

    t3.join()

    self.assertEqual(len(received), 202)
    self.assertEqual(sum(received), 2 * sum(range(101)))

  def test_multithreading_timeout(self):
    bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=0, maxweight=10)
    bq.put(FakeItem(10))

    # The queue is completely full. A timeout put should raise queue.Full.
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(5), timeout=0.01)

    def delayed_consumer():
      time.sleep(0.05)
      bq.get()

    # Start a thread that will free up space after 50ms.
    t = threading.Thread(target=delayed_consumer)
    t.start()

    # The put should succeed once the consumer runs; use a high timeout to
    # avoid flakiness.
    bq.put(FakeItem(5), timeout=60)
    t.join()
    # Only the newly added item should remain.
    self.assertEqual(bq.qsize(), 1)
    self.assertEqual(bq.byte_size(), 5)

  def test_negative_timeout(self):
    bq = ByteLimitedQueue(lambda x: x.weight())
    # Putting an item with a negative timeout should raise ValueError.
    with self.assertRaises(ValueError):
      bq.put(FakeItem(5), timeout=-1)

  def test_single_element_override(self):
    bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=10)
    # An item of size 50 exceeds maxweight 10, but should be admitted
    # immediately without blocking since the queue is currently empty.
    bq.put(FakeItem(50), block=False)
    self.assertEqual(bq.qsize(), 1)
    self.assertEqual(bq.byte_size(), 50)

  def test_inconsistent_weighing_fn(self):
    # Return a different weight for the same item.
    weights = [10, 5]
    bq = ByteLimitedQueue(lambda x: weights.pop(0), maxweight=100)

    bq.put(1)
    self.assertEqual(bq.byte_size(), 10)

    # Upon popping, the weighing function (if called) would have returned 5,
    # but the stored weight prevents corruption and cleanly reduces the size to
    # 0.
    bq.get()
    self.assertEqual(bq.byte_size(), 0)


if __name__ == '__main__':
  unittest.main()