Skip to content

Commit 2bde844

Browse files
committed
Add ASYNC430: lint rule for pytest.raises(ExceptionGroup)
Adds a new rule to detect usage of pytest.raises(ExceptionGroup) or pytest.raises(BaseExceptionGroup) in async functions, suggesting the use of pytest.RaisesGroup instead. This is recommended because RaisesGroup provides better support for exception group testing in async contexts, matching the structure of exception groups more accurately. Closes #430
1 parent d7efe31 commit 2bde844

3 files changed

Lines changed: 109 additions & 0 deletions

File tree

flake8_async/visitors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
visitor111,
3939
visitor118,
4040
visitor123,
41+
visitor430,
4142
visitor_utility,
4243
visitors,
4344
)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Visitor to check for pytest.raises(ExceptionGroup) usage.
2+
3+
ASYNC430: Suggests using pytest.RaisesGroup instead of pytest.raises(ExceptionGroup).
4+
"""
5+
6+
from __future__ import annotations
7+
8+
import ast
9+
from typing import TYPE_CHECKING, Any
10+
11+
from .flake8asyncvisitor import Flake8AsyncVisitor
12+
from .helpers import error_class
13+
14+
if TYPE_CHECKING:
15+
from collections.abc import Mapping
16+
17+
18+
@error_class
19+
class Visitor430(Flake8AsyncVisitor):
20+
error_codes: Mapping[str, str] = {
21+
"ASYNC430": (
22+
"Using `pytest.raises(ExceptionGroup)` is discouraged, consider using "
23+
"`pytest.RaisesGroup` instead."
24+
)
25+
}
26+
27+
def __init__(self, *args: Any, **kwargs: Any):
28+
super().__init__(*args, **kwargs)
29+
self.imports_pytest: bool = False
30+
self.imports_exceptiongroup: bool = False
31+
self.async_function = False
32+
33+
def visit_AsyncFunctionDef(
34+
self, node: ast.AsyncFunctionDef | ast.FunctionDef | ast.Lambda
35+
):
36+
self.save_state(node, "async_function")
37+
self.async_function = isinstance(node, ast.AsyncFunctionDef)
38+
39+
visit_FunctionDef = visit_AsyncFunctionDef
40+
visit_Lambda = visit_AsyncFunctionDef
41+
42+
def visit_Import(self, node: ast.Import) -> None:
43+
for alias in node.names:
44+
if alias.name == "pytest":
45+
self.imports_pytest = True
46+
47+
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
48+
if node.module == "pytest":
49+
self.imports_pytest = True
50+
elif node.module == "builtins" or node.module is None:
51+
# Check for `from builtins import ExceptionGroup`
52+
for alias in node.names:
53+
if alias.name in ("ExceptionGroup", "BaseExceptionGroup"):
54+
self.imports_exceptiongroup = True
55+
56+
def visit_Call(self, node: ast.Call) -> None:
57+
if not self.async_function:
58+
return
59+
60+
func_name = ast.unparse(node.func)
61+
62+
# Check for pytest.raises(ExceptionGroup) or pytest.raises(BaseExceptionGroup)
63+
if not (
64+
func_name == "pytest.raises"
65+
or (self.imports_pytest and func_name == "raises")
66+
):
67+
return
68+
69+
# Check first argument (exception type)
70+
if not node.args:
71+
return
72+
73+
first_arg = node.args[0]
74+
if isinstance(first_arg, ast.Name) and first_arg.id in (
75+
"ExceptionGroup",
76+
"BaseExceptionGroup",
77+
):
78+
self.error(node)
79+
elif isinstance(first_arg, ast.Attribute) and first_arg.attr in (
80+
"ExceptionGroup",
81+
"BaseExceptionGroup",
82+
):
83+
# Handle pytest.raises(pytest.ExceptionGroup) or similar
84+
self.error(node)

tests/eval_files/async430.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# type: ignore
2+
import pytest
3+
4+
5+
async def test_pytest_raises_exceptiongroup():
6+
with pytest.raises(ExceptionGroup): # ASYNC430: 9
7+
pass
8+
9+
10+
async def test_pytest_raises_baseexceptiongroup():
11+
with pytest.raises(BaseExceptionGroup): # ASYNC430: 9
12+
pass
13+
14+
15+
async def test_pytest_raises_other():
16+
# Should not error
17+
with pytest.raises(ValueError):
18+
pass
19+
20+
21+
async def test_pytest_raises_group():
22+
# Should not error - this is what we want users to use
23+
with pytest.RaisesGroup(ExceptionGroup):
24+
pass

0 commit comments

Comments
 (0)