diff --git a/aikido_zen/background_process/cloud_connection_manager/update_service_config.py b/aikido_zen/background_process/cloud_connection_manager/update_service_config.py index e09f2c3d1..8956b8c37 100644 --- a/aikido_zen/background_process/cloud_connection_manager/update_service_config.py +++ b/aikido_zen/background_process/cloud_connection_manager/update_service_config.py @@ -21,6 +21,7 @@ def update_service_config(connection_manager, res): blocked_uids=res.get("blockedUserIds", []), bypassed_ips=res.get("allowedIPAddresses", []), received_any_stats=res.get("receivedAnyStats", True), + excluded_uids_from_rate_limiting=res.get("excludedUserIdsFromRateLimiting", []), ) # Handle outbound request blocking configuration diff --git a/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py b/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py index be6ddfed7..4a69879b8 100644 --- a/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py +++ b/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py @@ -234,3 +234,48 @@ def test_update_service_config_block_new_outgoing_requests_only(): assert connection_manager.conf.outbound_domains == { "existing.com": "allow" } # Not changed + + +def test_update_service_config_excluded_user_ids_from_rate_limiting(): + """Test that excludedUserIdsFromRateLimiting is correctly applied to config""" + connection_manager = MagicMock() + connection_manager.conf = ServiceConfig( + endpoints=[], + last_updated_at=0, + blocked_uids=set(), + bypassed_ips=[], + received_any_stats=False, + ) + connection_manager.block = False + + res = { + "success": True, + "excludedUserIdsFromRateLimiting": ["user1", "user2"], + } + + update_service_config(connection_manager, res) + + assert connection_manager.conf.excluded_uids_from_rate_limiting == { + "user1", + "user2", + } + + +def test_update_service_config_excluded_user_ids_defaults_to_empty(): + """Test that excluded_uids_from_rate_limiting defaults to empty set when field is absent""" + connection_manager = MagicMock() + connection_manager.conf = ServiceConfig( + endpoints=[], + last_updated_at=0, + blocked_uids=set(), + bypassed_ips=[], + received_any_stats=False, + excluded_uids_from_rate_limiting=["user1"], + ) + connection_manager.block = False + + res = {"success": True} + + update_service_config(connection_manager, res) + + assert connection_manager.conf.excluded_uids_from_rate_limiting == set() diff --git a/aikido_zen/background_process/service_config.py b/aikido_zen/background_process/service_config.py index a72f6dbaf..3c30b70cc 100644 --- a/aikido_zen/background_process/service_config.py +++ b/aikido_zen/background_process/service_config.py @@ -17,10 +17,16 @@ def __init__( blocked_uids, bypassed_ips, received_any_stats: bool, + excluded_uids_from_rate_limiting=None, ): # Init the class using update function : self.update( - endpoints, last_updated_at, blocked_uids, bypassed_ips, received_any_stats + endpoints, + last_updated_at, + blocked_uids, + bypassed_ips, + received_any_stats, + excluded_uids_from_rate_limiting, ) self.block_new_outgoing_requests = False self.outbound_domains = {} @@ -32,10 +38,14 @@ def update( blocked_uids, bypassed_ips, received_any_stats: bool, + excluded_uids_from_rate_limiting=None, ): self.last_updated_at = last_updated_at self.received_any_stats = bool(received_any_stats) self.blocked_uids = set(blocked_uids) + self.excluded_uids_from_rate_limiting = set( + excluded_uids_from_rate_limiting or [] + ) self.set_endpoints(endpoints) self.set_bypassed_ips(bypassed_ips) diff --git a/aikido_zen/background_process/service_config_test.py b/aikido_zen/background_process/service_config_test.py index d34611693..d477032fa 100644 --- a/aikido_zen/background_process/service_config_test.py +++ b/aikido_zen/background_process/service_config_test.py @@ -319,3 +319,49 @@ def test_service_config_with_empty_allowlist(): assert admin_endpoint["route"] == "/admin" assert isinstance(admin_endpoint["allowedIPAddresses"], list) assert len(admin_endpoint["allowedIPAddresses"]) == 0 + + +def test_excluded_uids_from_rate_limiting_defaults_to_empty(): + config = ServiceConfig( + endpoints=[], + last_updated_at=0, + blocked_uids=set(), + bypassed_ips=[], + received_any_stats=False, + ) + assert config.excluded_uids_from_rate_limiting == set() + + +def test_excluded_uids_from_rate_limiting_stored_as_set(): + config = ServiceConfig( + endpoints=[], + last_updated_at=0, + blocked_uids=set(), + bypassed_ips=[], + received_any_stats=False, + excluded_uids_from_rate_limiting=["user1", "user2"], + ) + assert config.excluded_uids_from_rate_limiting == {"user1", "user2"} + + +def test_excluded_uids_from_rate_limiting_updated_via_update(): + config = ServiceConfig( + endpoints=[], + last_updated_at=0, + blocked_uids=set(), + bypassed_ips=[], + received_any_stats=False, + excluded_uids_from_rate_limiting=["user1"], + ) + assert "user1" in config.excluded_uids_from_rate_limiting + + config.update( + endpoints=[], + last_updated_at=0, + blocked_uids=set(), + bypassed_ips=[], + received_any_stats=False, + excluded_uids_from_rate_limiting=["user2", "user3"], + ) + assert config.excluded_uids_from_rate_limiting == {"user2", "user3"} + assert "user1" not in config.excluded_uids_from_rate_limiting diff --git a/aikido_zen/ratelimiting/__init__.py b/aikido_zen/ratelimiting/__init__.py index dcf0927ac..e2d5dc12a 100644 --- a/aikido_zen/ratelimiting/__init__.py +++ b/aikido_zen/ratelimiting/__init__.py @@ -24,6 +24,12 @@ def should_ratelimit_request( max_requests = int(endpoint["rateLimiting"]["maxRequests"]) windows_size_in_ms = int(endpoint["rateLimiting"]["windowSizeInMS"]) + if ( + user + and user.get("id") in connection_manager.conf.excluded_uids_from_rate_limiting + ): + return {"block": False} + if group: allowed = connection_manager.rate_limiter.is_allowed( get_key_for_group(endpoint, group), diff --git a/aikido_zen/ratelimiting/init_test.py b/aikido_zen/ratelimiting/init_test.py index 966d8c7ac..9dfefab4a 100644 --- a/aikido_zen/ratelimiting/init_test.py +++ b/aikido_zen/ratelimiting/init_test.py @@ -15,7 +15,7 @@ def user(): return {"id": "user123"} -def create_connection_manager(endpoints=[], bypassed_ips=[]): +def create_connection_manager(endpoints=[], bypassed_ips=[], excluded_uids=[]): cm = MagicMock() cm.conf = ServiceConfig( endpoints=endpoints, @@ -23,6 +23,7 @@ def create_connection_manager(endpoints=[], bypassed_ips=[]): blocked_uids=[], bypassed_ips=bypassed_ips, received_any_stats=True, + excluded_uids_from_rate_limiting=excluded_uids, ) cm.rate_limiter = RateLimiter( max_items=5000, time_to_live_in_ms=120 * 60 * 1000 # 120 minutes @@ -478,6 +479,79 @@ def test_works_with_multiple_rate_limit_groups_and_different_users(): } +def test_excluded_user_bypasses_user_rate_limit(): + endpoint = { + "method": "POST", + "route": "/login", + "forceProtectionOff": False, + "rateLimiting": { + "enabled": True, + "maxRequests": 3, + "windowSizeInMS": 1000, + }, + } + cm = create_connection_manager([endpoint], excluded_uids=["user123"]) + route_metadata = create_route_metadata() + + # Excluded user should never be blocked, even past maxRequests + for _ in range(5): + assert should_ratelimit_request( + route_metadata, "1.2.3.4", {"id": "user123"}, cm + ) == {"block": False} + + +def test_non_excluded_user_still_rate_limited(): + endpoint = { + "method": "POST", + "route": "/login", + "forceProtectionOff": False, + "rateLimiting": { + "enabled": True, + "maxRequests": 3, + "windowSizeInMS": 1000, + }, + } + cm = create_connection_manager([endpoint], excluded_uids=["other_user"]) + route_metadata = create_route_metadata() + + assert should_ratelimit_request( + route_metadata, "1.2.3.4", {"id": "user123"}, cm + ) == {"block": False} + assert should_ratelimit_request( + route_metadata, "1.2.3.4", {"id": "user123"}, cm + ) == {"block": False} + assert should_ratelimit_request( + route_metadata, "1.2.3.4", {"id": "user123"}, cm + ) == {"block": False} + assert should_ratelimit_request( + route_metadata, "1.2.3.4", {"id": "user123"}, cm + ) == { + "block": True, + "trigger": "user", + } + + +def test_excluded_user_bypasses_group_rate_limit(): + endpoint = { + "method": "POST", + "route": "/login", + "forceProtectionOff": False, + "rateLimiting": { + "enabled": True, + "maxRequests": 3, + "windowSizeInMS": 1000, + }, + } + cm = create_connection_manager([endpoint], excluded_uids=["user123"]) + route_metadata = create_route_metadata() + + # Excluded user should never be blocked, even past maxRequests, even with a group set + for _ in range(5): + assert should_ratelimit_request( + route_metadata, "1.2.3.4", {"id": "user123"}, cm, "group1" + ) == {"block": False} + + def test_rate_limits_by_group_if_user_is_not_set(): cm = create_connection_manager( [ diff --git a/aikido_zen/thread/thread_cache.py b/aikido_zen/thread/thread_cache.py index 0072bac3c..60acb17ab 100644 --- a/aikido_zen/thread/thread_cache.py +++ b/aikido_zen/thread/thread_cache.py @@ -43,6 +43,7 @@ def reset(self): bypassed_ips=[], last_updated_at=-1, received_any_stats=False, + excluded_uids_from_rate_limiting=set(), ) self.middleware_installed = False self.hostnames.clear()