Skip to content

Commit 76b40e0

Browse files
feat: add async context manager support to ClientTransport - Apply suggestion from reviewer
1 parent ca97b56 commit 76b40e0

1 file changed

Lines changed: 19 additions & 30 deletions

File tree

tests/client/test_base_client.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import AsyncMock, MagicMock
1+
from unittest.mock import AsyncMock, MagicMock, patch
22

33
import pytest
44

@@ -63,39 +63,28 @@ def base_client(
6363

6464
@pytest.mark.asyncio
6565
async def test_transport_async_context_manager() -> None:
66-
class TestTransport(ClientTransport):
67-
def __init__(self) -> None:
68-
self.closed = False
69-
70-
async def close(self) -> None:
71-
self.closed = True
72-
73-
TestTransport.__abstractmethods__ = set() # type: ignore[attr-defined]
74-
75-
transport = TestTransport()
76-
async with transport as t:
77-
assert t is transport
78-
79-
assert transport.closed
66+
with (
67+
patch.object(ClientTransport, '__abstractmethods__', set()),
68+
patch.object(ClientTransport, 'close', new_callable=AsyncMock),
69+
):
70+
transport = ClientTransport()
71+
async with transport as t:
72+
assert t is transport
73+
transport.close.assert_not_awaited()
74+
transport.close.assert_awaited_once()
8075

8176

8277
@pytest.mark.asyncio
8378
async def test_transport_async_context_manager_on_exception() -> None:
84-
class TestTransport(ClientTransport):
85-
def __init__(self) -> None:
86-
self.closed = False
87-
88-
async def close(self) -> None:
89-
self.closed = True
90-
91-
TestTransport.__abstractmethods__ = set() # type: ignore[attr-defined]
92-
93-
transport = TestTransport()
94-
with pytest.raises(RuntimeError, match='boom'):
95-
async with transport:
96-
raise RuntimeError('boom')
97-
98-
assert transport.closed
79+
with (
80+
patch.object(ClientTransport, '__abstractmethods__', set()),
81+
patch.object(ClientTransport, 'close', new_callable=AsyncMock),
82+
):
83+
transport = ClientTransport()
84+
with pytest.raises(RuntimeError, match='boom'):
85+
async with transport:
86+
raise RuntimeError('boom')
87+
transport.close.assert_awaited_once()
9988

10089

10190
@pytest.mark.asyncio

0 commit comments

Comments
 (0)