-
Notifications
You must be signed in to change notification settings - Fork 4.6k
[Python] Bound the memory used for fnapi outbound data messages and receiving messages. #38407
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
[Python] Bound the memory used for fnapi outbound data messages and receiving messages. #38407
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| # | ||
| # Licensed to the Apache Software Foundation (ASF) under one or more | ||
| # contributor license agreements. See the NOTICE file distributed with | ||
| # this work for additional information regarding copyright ownership. | ||
| # The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| # (the "License"); you may not use this file except in compliance with | ||
| # the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
| """A thread-safe queue that limits capacity by total byte size.""" | ||
|
|
||
| import queue | ||
| import time | ||
| from typing import Any | ||
| from typing import Callable | ||
|
|
||
|
|
||
class ByteLimitedQueue(queue.Queue):
  """A queue.Queue that limits by both element count and total weight.

  Every item is weighed exactly once, when it is put, and the weight is stored
  next to the item so that ``get`` subtracts exactly what ``put`` added.
  Weights below 1 are counted as 1.

  A single element is allowed to exceed the maxweight when the queue is empty,
  to avoid deadlock on oversized elements.

  A blocked ``put`` is not guaranteed to complete as soon as enough room opens
  up: ``queue.Queue.get`` wakes a single waiting producer, which is not
  necessarily one whose item fits. Progress is still made as long as the queue
  keeps being drained.

  On Python versions that provide ``queue.Queue.shutdown`` (3.13+), ``put``
  raises ``queue.ShutDown`` once the queue has been shut down, including for
  producers that were blocked when the shutdown happened.
  """
  def __init__(
      self,
      weighing_fn: Callable[[Any], int],
      maxsize: int = 0,
      maxweight: int = 0,
  ) -> None:
    """Initializes a ByteLimitedQueue.

    Args:
      weighing_fn: A Callable that accepts an item and returns its integer
        weight.
      maxsize: The maximum number of items allowed in the queue. If 0 or
        negative, there is no limit on the number of elements.
      maxweight: The maximum accumulated weight allowed in the queue. If 0 or
        negative, there is no limit on the accumulated weight.
    """
    # The base queue is created unbounded; both limits are enforced in put().
    super().__init__(maxsize=0)
    self.max_elements = maxsize
    self.max_weight = maxweight
    self.weighing_fn = weighing_fn
    # Sum of the stored weights of all queued items; guarded by self.mutex.
    self._byte_size = 0

  def _is_full(self, item_size):
    """Returns whether an item of weight item_size must wait for room.

    Must be called with self.mutex held.
    """
    # An empty queue always admits one item, however heavy, so that an
    # oversized item cannot block forever.
    if self._qsize() == 0:
      return False
    if self.max_elements > 0 and self._qsize() >= self.max_elements:
      return True
    if self.max_weight > 0 and self._byte_size + item_size > self.max_weight:
      return True
    return False

  def _raise_if_shut_down(self):
    """Raises queue.ShutDown if the queue has been shut down.

    Must be called with self.mutex held. On Python versions without
    Queue.shutdown the is_shutdown attribute does not exist and this is a
    no-op, so queue.ShutDown is only referenced where it is defined.
    """
    if getattr(self, 'is_shutdown', False):
      raise queue.ShutDown

  def put(self, item, block=True, timeout=None):
    """Puts an item into the queue, waiting for room if necessary.

    Note that a blocked put may not complete as soon as enough room opens up,
    since each get wakes only one waiting producer, which might not be the one
    whose item fits.

    Args:
      item: The element to enqueue.
      block: If False, raise queue.Full immediately instead of waiting.
      timeout: Maximum number of seconds to wait when blocking. None waits
        indefinitely.

    Raises:
      queue.Full: If there is no room and block is False, or the timeout
        expired.
      ValueError: If timeout is negative.
      queue.ShutDown: If the queue has been shut down (Python 3.13+).
    """
    # Weigh outside the lock; user code should not run while holding it.
    item_size = max(1, self.weighing_fn(item))
    with self.not_full:
      self._raise_if_shut_down()
      if not block:
        if self._is_full(item_size):
          raise queue.Full
      elif timeout is None:
        while self._is_full(item_size):
          self.not_full.wait()
          # shutdown() notifies all waiters; re-check after every wake-up.
          self._raise_if_shut_down()
      elif timeout < 0:
        raise ValueError("'timeout' must be a non-negative number")
      else:
        endtime = time.monotonic() + timeout
        while self._is_full(item_size):
          remaining = endtime - time.monotonic()
          if remaining <= 0.0:
            raise queue.Full
          self.not_full.wait(remaining)
          self._raise_if_shut_down()

      # Store the weight with the item so _get() can subtract the same value
      # even if weighing_fn is not deterministic.
      self._put((item, item_size))
      self._byte_size += item_size
      self.unfinished_tasks += 1
      self.not_empty.notify()

  def _get(self):
    """Removes the next entry, releases its weight and returns the item."""
    item, item_weight = super()._get()
    self._byte_size -= item_weight
    return item

  def byte_size(self):
    """Return the total byte weight of elements in the queue."""
    with self.mutex:
      return self._byte_size
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,168 @@ | ||
| # | ||
| # Licensed to the Apache Software Foundation (ASF) under one or more | ||
| # contributor license agreements. See the NOTICE file distributed with | ||
| # this work for additional information regarding copyright ownership. | ||
| # The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| # (the "License"); you may not use this file except in compliance with | ||
| # the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
| """Unit tests for byte-limited queue.""" | ||
|
|
||
| import queue | ||
| import sys | ||
| import threading | ||
| import time | ||
| import unittest | ||
|
|
||
| from apache_beam.utils.byte_limited_queue import ByteLimitedQueue | ||
|
|
||
|
|
||
class FakeItem:
  """Test element that reports a fixed weight chosen by the caller."""
  def __init__(self, size):
    # Weight handed back verbatim by weight().
    self._reported_weight = size

  def weight(self):
    """Returns the weight supplied at construction."""
    return self._reported_weight
|
|
||
|
|
||
class ByteLimitedQueueTest(unittest.TestCase):
  """Tests for the element-count and weight limits of ByteLimitedQueue."""
  def test_unbounded(self):
    """With no limits configured every put is admitted."""
    bq = ByteLimitedQueue(lambda x: x.weight())
    for i in range(200):
      bq.put(FakeItem(i))
    # Add 1 since weight of zero is set to 1
    self.assertEqual(bq.byte_size(), sum(range(200)) + 1)
    self.assertEqual(bq.qsize(), 200)

  def test_put_and_get(self):
    """The weight limit rejects puts and frees room again on get."""
    bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=200)
    bq.put(FakeItem(50))
    bq.put(FakeItem(140))
    self.assertEqual(bq.byte_size(), 190)
    self.assertEqual(bq.qsize(), 2)
    # Putting another 20 would exceed 200.
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(20), block=False)
    # Exactly reaching the limit is allowed.
    bq.put(FakeItem(10), block=False)
    self.assertEqual(bq.byte_size(), 200)
    self.assertEqual(bq.qsize(), 3)

    self.assertEqual(bq.get().weight(), 50)
    self.assertEqual(bq.byte_size(), 150)
    self.assertEqual(bq.qsize(), 2)
    # The item rejected earlier now fits.
    bq.put(FakeItem(20), block=False)
    self.assertEqual(bq.byte_size(), 170)
    self.assertEqual(bq.qsize(), 3)

  def test_dual_limit(self):
    """Either limit on its own is enough to reject a put."""
    # Queue limits: at most 3 items, OR at most 100 weight.
    bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=3, maxweight=100)
    bq.put(FakeItem(30))
    bq.put(FakeItem(40))
    bq.put(FakeItem(20))
    self.assertEqual(bq.byte_size(), 90)
    self.assertEqual(bq.qsize(), 3)
    # Full on element count (size=3), even though the weight would fit.
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(10), block=False)
    self.assertEqual(bq.get().weight(), 30)
    self.assertEqual(bq.get().weight(), 40)
    bq.put(FakeItem(10))
    # Full on byte count
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(90), block=False)
    self.assertEqual(bq.get().weight(), 20)
    bq.put(FakeItem(90), block=False)
    self.assertEqual(bq.byte_size(), 100)
    self.assertEqual(bq.qsize(), 2)

  @unittest.skipIf(sys.version_info < (3, 13), 'Queue.ShutDown added in 3.13.')
  def test_multithreading(self):
    """Two producers and one consumer move every item through the queue."""
    bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=0, maxweight=100)
    received = []

    def producer():
      for i in range(101):
        bq.put(FakeItem(i))

    def consumer():
      while True:
        try:
          received.append(bq.get().weight())
        except queue.ShutDown:
          break

    t1 = threading.Thread(target=producer)
    t2 = threading.Thread(target=producer)
    t3 = threading.Thread(target=consumer)

    t1.start()
    t2.start()
    t3.start()

    # Shut down only after both producers are done, so no put is pending.
    t1.join()
    t2.join()
    bq.shutdown()

    t3.join()

    self.assertEqual(len(received), 202)
    self.assertEqual(sum(received), 2 * sum(range(101)))

  def test_multithreading_timeout(self):
    """A timed put fails while full and succeeds once room is freed."""
    bq = ByteLimitedQueue(lambda x: x.weight(), maxsize=0, maxweight=10)
    bq.put(FakeItem(10))

    # The queue is completely full. A timeout put should raise queue.Full.
    with self.assertRaises(queue.Full):
      bq.put(FakeItem(5), timeout=0.01)

    def delayed_consumer():
      time.sleep(0.05)
      bq.get()

    # Start a thread that will free up space after 50ms.
    t = threading.Thread(target=delayed_consumer)
    t.start()

    # The put should succeed once the consumer runs; use a high timeout to
    # avoid flakiness.
    bq.put(FakeItem(5), timeout=60)
    t.join()
    self.assertEqual(bq.byte_size(), 5)
    self.assertEqual(bq.qsize(), 1)

  def test_negative_timeout(self):
    """A negative timeout is rejected with ValueError."""
    bq = ByteLimitedQueue(lambda x: x.weight())
    # Putting an item with a negative timeout should raise ValueError.
    with self.assertRaises(ValueError):
      bq.put(FakeItem(5), timeout=-1)

  def test_single_element_override(self):
    """An empty queue admits one item even if it exceeds maxweight."""
    bq = ByteLimitedQueue(lambda x: x.weight(), maxweight=10)
    # An item of size 50 exceeds maxweight 10, but should be admitted
    # immediately without blocking since the queue is currently empty!
    bq.put(FakeItem(50), block=False)
    self.assertEqual(bq.qsize(), 1)
    self.assertEqual(bq.byte_size(), 50)

  def test_inconsistent_weighing_fn(self):
    """The weight recorded at put time is the one released at get time."""
    # Return a different weight for the same item.
    weights = [10, 5]
    bq = ByteLimitedQueue(lambda x: weights.pop(0), maxweight=100)

    bq.put(1)
    self.assertEqual(bq.byte_size(), 10)

    # Upon popping, the weighing function (if called) would have returned 5,
    # but the stored weight prevents corruption and cleanly reduces the size to
    # 0.
    bq.get()
    self.assertEqual(bq.byte_size(), 0)
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this because it is simpler, or because we don't want to shut down earlier in order to avoid data loss?
We could also raise NotImplementedError on shutdown if we don't want to implement this contract.