|
1 | | -from unittest.mock import AsyncMock, MagicMock |
| 1 | +from unittest.mock import AsyncMock, MagicMock, patch |
2 | 2 |
|
3 | 3 | import pytest |
4 | 4 |
|
@@ -63,39 +63,28 @@ def base_client( |
63 | 63 |
|
64 | 64 | @pytest.mark.asyncio |
65 | 65 | async def test_transport_async_context_manager() -> None: |
66 | | - class TestTransport(ClientTransport): |
67 | | - def __init__(self) -> None: |
68 | | - self.closed = False |
69 | | - |
70 | | - async def close(self) -> None: |
71 | | - self.closed = True |
72 | | - |
73 | | - TestTransport.__abstractmethods__ = set() # type: ignore[attr-defined] |
74 | | - |
75 | | - transport = TestTransport() |
76 | | - async with transport as t: |
77 | | - assert t is transport |
78 | | - |
79 | | - assert transport.closed |
| 66 | + with ( |
| 67 | + patch.object(ClientTransport, '__abstractmethods__', set()), |
| 68 | + patch.object(ClientTransport, 'close', new_callable=AsyncMock), |
| 69 | + ): |
| 70 | + transport = ClientTransport() |
| 71 | + async with transport as t: |
| 72 | + assert t is transport |
| 73 | + transport.close.assert_not_awaited() |
| 74 | + transport.close.assert_awaited_once() |
80 | 75 |
|
81 | 76 |
|
82 | 77 | @pytest.mark.asyncio |
83 | 78 | async def test_transport_async_context_manager_on_exception() -> None: |
84 | | - class TestTransport(ClientTransport): |
85 | | - def __init__(self) -> None: |
86 | | - self.closed = False |
87 | | - |
88 | | - async def close(self) -> None: |
89 | | - self.closed = True |
90 | | - |
91 | | - TestTransport.__abstractmethods__ = set() # type: ignore[attr-defined] |
92 | | - |
93 | | - transport = TestTransport() |
94 | | - with pytest.raises(RuntimeError, match='boom'): |
95 | | - async with transport: |
96 | | - raise RuntimeError('boom') |
97 | | - |
98 | | - assert transport.closed |
| 79 | + with ( |
| 80 | + patch.object(ClientTransport, '__abstractmethods__', set()), |
| 81 | + patch.object(ClientTransport, 'close', new_callable=AsyncMock), |
| 82 | + ): |
| 83 | + transport = ClientTransport() |
| 84 | + with pytest.raises(RuntimeError, match='boom'): |
| 85 | + async with transport: |
| 86 | + raise RuntimeError('boom') |
| 87 | + transport.close.assert_awaited_once() |
99 | 88 |
|
100 | 89 |
|
101 | 90 | @pytest.mark.asyncio |
|
0 commit comments