Skip to content

Commit c0bdc78

Browse files
committed
Add test cases for asgi_middleware
1 parent a5f1bf6 commit c0bdc78

1 file changed

Lines changed: 166 additions & 0 deletions

File tree

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import inspect
2+
3+
import pytest
4+
from unittest.mock import AsyncMock, patch, MagicMock
5+
6+
from aikido_zen.background_process.commands import process_check_firewall_lists
7+
from aikido_zen.context import current_context
8+
from aikido_zen.thread.thread_cache import get_cache
9+
from aikido_zen.storage.firewall_lists import FirewallLists
10+
from aikido_zen.sources.functions.asgi_middleware import InternalASGIMiddleware
11+
from aikido_zen.vulnerabilities.attack_wave_detection.attack_wave_detector import (
12+
AttackWaveDetector,
13+
)
14+
15+
16+
# --- Fixtures ---
17+
@pytest.fixture(autouse=True)
18+
def run_around_tests():
19+
get_cache().reset()
20+
current_context.set(None)
21+
yield
22+
get_cache().reset()
23+
current_context.set(None)
24+
25+
26+
TEST_ASGI_SCOPE = {
27+
"type": "http",
28+
"client": ["1.1.1.1"],
29+
"method": "GET",
30+
"headers": [],
31+
"scheme": "http",
32+
"server": "127.0.0.1",
33+
"query_string": b"",
34+
}
35+
36+
37+
@pytest.fixture(autouse=True)
38+
def mock_asgi_app():
39+
class ASGIMock:
40+
def __init__(self):
41+
self.called = 0
42+
43+
async def app(self, scope, receive, send):
44+
self.called += 1
45+
46+
def was_called(self):
47+
return self.called > 0
48+
49+
return ASGIMock()
50+
51+
52+
class MyMockComms:
53+
def __init__(self):
54+
self.firewall_lists = FirewallLists()
55+
self.conn_manager = MagicMock()
56+
self.conn_manager.firewall_lists = self.firewall_lists
57+
self.conn_manager.attack_wave_detector = AttackWaveDetector()
58+
self.attacks = []
59+
60+
def send_data_to_bg_process(self, action, obj, receive=False, timeout_in_sec=0.1):
61+
if action != "CHECK_FIREWALL_LISTS":
62+
return {"success": False}
63+
res = process_check_firewall_lists(self.conn_manager, obj)
64+
return {
65+
"success": True,
66+
"data": res,
67+
}
68+
69+
70+
def patch_firewall_lists(func):
71+
async def wrapper(*args, **kwargs):
72+
with patch("aikido_zen.background_process.comms.get_comms") as mock_comms:
73+
comms = MyMockComms()
74+
mock_comms.return_value = comms
75+
76+
sig = inspect.signature(func)
77+
if "attacks" in sig.parameters:
78+
kwargs["attacks"] = comms.attacks
79+
if "firewall_lists" in sig.parameters:
80+
kwargs["firewall_lists"] = comms.firewall_lists
81+
82+
return await func(*args, **kwargs)
83+
84+
return wrapper
85+
86+
87+
# --- Test Cases ---
88+
@pytest.mark.asyncio
89+
async def test_middleware_ignores_non_http_scope(mock_asgi_app):
90+
middleware = InternalASGIMiddleware(mock_asgi_app.app, "test_source")
91+
scope = {"type": "websocket"}
92+
receive = AsyncMock()
93+
send = AsyncMock()
94+
95+
assert not mock_asgi_app.was_called()
96+
await middleware(scope, receive, send)
97+
assert mock_asgi_app.was_called()
98+
99+
send.assert_not_called()
100+
receive.assert_not_called()
101+
102+
103+
@pytest.mark.asyncio
104+
async def test_middleware_bypasses_blocked_ip(mock_asgi_app):
105+
middleware = InternalASGIMiddleware(mock_asgi_app.app, "test_source")
106+
scope = {"type": "http", "client": ["192.168.1.1"]}
107+
receive = AsyncMock()
108+
send = AsyncMock()
109+
110+
cache = get_cache()
111+
cache.config.set_bypassed_ips(["192.168.1.1"])
112+
113+
await middleware(scope, receive, send)
114+
assert mock_asgi_app.was_called()
115+
116+
117+
@pytest.mark.asyncio
118+
@patch_firewall_lists
119+
async def test_middleware_blocks_request_if_intercepted(firewall_lists):
120+
firewall_lists.set_blocked_ips(
121+
[{"source": "test", "description": "Blocked for testing", "ips": ["1.1.1.1"]}]
122+
)
123+
124+
async def app(scope, receive, send):
125+
pass
126+
127+
middleware = InternalASGIMiddleware(app, "uvicorn")
128+
scope = {
129+
"type": "http",
130+
"client": ["1.1.1.1"],
131+
"method": "GET",
132+
"headers": [],
133+
"scheme": "http",
134+
"server": "127.0.0.1",
135+
"query_string": b"",
136+
}
137+
receive = AsyncMock()
138+
send = AsyncMock()
139+
140+
await middleware(scope, receive, send)
141+
142+
send.assert_any_call(
143+
{
144+
"type": "http.response.start",
145+
"status": 403,
146+
"headers": [(b"content-type", b"text/plain")],
147+
}
148+
)
149+
send.assert_any_call(
150+
{
151+
"type": "http.response.body",
152+
"body": b"Your IP address is blocked due to Blocked for testing (Your IP: 1.1.1.1)",
153+
"more_body": False,
154+
}
155+
)
156+
157+
158+
@pytest.mark.asyncio
159+
async def test_middleware_increments_total_hits(mock_asgi_app):
160+
middleware = InternalASGIMiddleware(mock_asgi_app.app, "uvicorn")
161+
receive = AsyncMock()
162+
send = AsyncMock()
163+
164+
await middleware(TEST_ASGI_SCOPE, receive, send)
165+
assert get_cache().stats.total_hits == 1
166+
assert mock_asgi_app.was_called()

0 commit comments

Comments
 (0)