Skip to content

Commit 677bb71

Browse files
refactor: use method names for JSON-RPC dispatching (#932)
# Description This PR changes the routing logic in the JsonRpcDispatcher, enforcing it on the method name, rather that the object type.
1 parent 2648c5e commit 677bb71

2 files changed

Lines changed: 388 additions & 24 deletions

File tree

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
DefaultServerCallContextBuilder,
3232
ServerCallContextBuilder,
3333
)
34-
from a2a.types import A2ARequest
3534
from a2a.types.a2a_pb2 import (
3635
AgentCard,
3736
CancelTaskRequest,
@@ -349,7 +348,7 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911,
349348
else:
350349
try:
351350
raw_result = await self._process_non_streaming_request(
352-
request_id, specific_request, call_context
351+
specific_request, call_context
353352
)
354353
handler_result = JSONRPC20Response(
355354
result=raw_result, _id=request_id
@@ -385,7 +384,7 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911,
385384
async def _process_streaming_request(
386385
self,
387386
request_id: str | int | None,
388-
request_obj: A2ARequest,
387+
request_obj: Any,
389388
context: ServerCallContext,
390389
) -> AsyncGenerator[dict[str, Any], None]:
391390
"""Processes streaming requests (SendStreamingMessage or SubscribeToTask).
@@ -399,11 +398,12 @@ async def _process_streaming_request(
399398
An `AsyncGenerator` object to stream results to the client.
400399
"""
401400
stream: AsyncGenerator | None = None
402-
if isinstance(request_obj, SendMessageRequest):
401+
method = context.state.get('method')
402+
if method == 'SendStreamingMessage':
403403
stream = self.request_handler.on_message_send_stream(
404404
request_obj, context
405405
)
406-
elif isinstance(request_obj, SubscribeToTaskRequest):
406+
elif method == 'SubscribeToTask':
407407
stream = self.request_handler.on_subscribe_to_task(
408408
request_obj, context
409409
)
@@ -538,55 +538,53 @@ async def _handle_get_extended_agent_card(
538538
@validate_version(constants.PROTOCOL_VERSION_1_0)
539539
async def _process_non_streaming_request( # noqa: PLR0911
540540
self,
541-
request_id: str | int | None,
542-
request_obj: A2ARequest,
541+
request_obj: Any,
543542
context: ServerCallContext,
544543
) -> dict[str, Any] | None:
545-
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
544+
"""Processes non-streaming requests.
546545
547546
Args:
548-
request_id: The ID of the request.
549547
request_obj: The proto request message.
550548
context: The ServerCallContext for the request.
551549
552550
Returns:
553551
A dict containing the result or error.
554552
"""
555-
match request_obj:
556-
case SendMessageRequest():
553+
method = context.state.get('method')
554+
match method:
555+
case 'SendMessage':
557556
return await self._handle_send_message(request_obj, context)
558-
case CancelTaskRequest():
557+
case 'CancelTask':
559558
return await self._handle_cancel_task(request_obj, context)
560-
case GetTaskRequest():
559+
case 'GetTask':
561560
return await self._handle_get_task(request_obj, context)
562-
case ListTasksRequest():
561+
case 'ListTasks':
563562
return await self._handle_list_tasks(request_obj, context)
564-
case TaskPushNotificationConfig():
563+
case 'CreateTaskPushNotificationConfig':
565564
return await self._handle_create_task_push_notification_config(
566565
request_obj, context
567566
)
568-
case GetTaskPushNotificationConfigRequest():
567+
case 'GetTaskPushNotificationConfig':
569568
return await self._handle_get_task_push_notification_config(
570569
request_obj, context
571570
)
572-
case ListTaskPushNotificationConfigsRequest():
571+
case 'ListTaskPushNotificationConfigs':
573572
return await self._handle_list_task_push_notification_configs(
574573
request_obj, context
575574
)
576-
case DeleteTaskPushNotificationConfigRequest():
577-
return await self._handle_delete_task_push_notification_config(
575+
case 'DeleteTaskPushNotificationConfig':
576+
await self._handle_delete_task_push_notification_config(
578577
request_obj, context
579578
)
580-
case GetExtendedAgentCardRequest():
579+
return None
580+
case 'GetExtendedAgentCard':
581581
return await self._handle_get_extended_agent_card(
582582
request_obj, context
583583
)
584584
case _:
585-
logger.error(
586-
'Unhandled validated request type: %s', type(request_obj)
587-
)
585+
logger.error('Unhandled method: %s', method)
588586
raise UnsupportedOperationError(
589-
message=f'Request type {type(request_obj).__name__} is unknown.'
587+
message=f'Method {method} is not supported.'
590588
)
591589

592590
def _create_response(

0 commit comments

Comments
 (0)