Skip to content

Commit 2648c5e

Browse files
refactor(server)!: add build_user function to DefaultContextBuilder to allow A2A user creation customization (#925)
# Description - Add build_user function to the DefaultContextBuilder to allow user customization - Renamed CallContextBuilder / DefaultCallContextBuilder -> ServerCallContextBuilder / DefaultServerCallContextBuilder - Centralizes HTTP context-building logic into a new a2a.server.routes.common module, eliminating duplication between the JSON-RPC and REST dispatchers. Fixes #924 🦕
1 parent 5d22186 commit 2648c5e

16 files changed

Lines changed: 338 additions & 165 deletions

src/a2a/compat/v0_3/grpc_handler.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from a2a.server.context import ServerCallContext
2424
from a2a.server.request_handlers.grpc_handler import (
2525
_ERROR_CODE_MAP,
26-
CallContextBuilder,
27-
DefaultCallContextBuilder,
26+
DefaultGrpcServerCallContextBuilder,
27+
GrpcServerCallContextBuilder,
2828
)
2929
from a2a.server.request_handlers.request_handler import RequestHandler
3030
from a2a.types.a2a_pb2 import AgentCard
@@ -44,7 +44,7 @@ def __init__(
4444
self,
4545
agent_card: AgentCard,
4646
request_handler: RequestHandler,
47-
context_builder: CallContextBuilder | None = None,
47+
context_builder: GrpcServerCallContextBuilder | None = None,
4848
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
4949
| None = None,
5050
):
@@ -61,7 +61,9 @@ def __init__(
6161
"""
6262
self.agent_card = agent_card
6363
self.handler03 = RequestHandler03(request_handler=request_handler)
64-
self.context_builder = context_builder or DefaultCallContextBuilder()
64+
self._context_builder = (
65+
context_builder or DefaultGrpcServerCallContextBuilder()
66+
)
6567
self.card_modifier = card_modifier
6668

6769
async def _handle_unary(
@@ -72,7 +74,7 @@ async def _handle_unary(
7274
) -> TResponse:
7375
"""Centralized error handling and context management for unary calls."""
7476
try:
75-
server_context = self.context_builder.build(context)
77+
server_context = self._context_builder.build(context)
7678
result = await handler_func(server_context)
7779
self._set_extension_metadata(context, server_context)
7880
except A2AError as e:
@@ -88,7 +90,7 @@ async def _handle_stream(
8890
) -> AsyncIterable[TResponse]:
8991
"""Centralized error handling and context management for streaming calls."""
9092
try:
91-
server_context = self.context_builder.build(context)
93+
server_context = self._context_builder.build(context)
9294
async for item in handler_func(server_context):
9395
yield item
9496
self._set_extension_metadata(context, server_context)

src/a2a/compat/v0_3/jsonrpc_adapter.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from starlette.requests import Request
1212

1313
from a2a.server.request_handlers.request_handler import RequestHandler
14-
from a2a.server.routes import CallContextBuilder
1514
from a2a.types.a2a_pb2 import AgentCard
1615

1716
_package_starlette_installed = True
@@ -38,6 +37,10 @@
3837
from a2a.server.jsonrpc_models import (
3938
JSONRPCError as CoreJSONRPCError,
4039
)
40+
from a2a.server.routes.common import (
41+
DefaultServerCallContextBuilder,
42+
ServerCallContextBuilder,
43+
)
4144
from a2a.utils import constants
4245
from a2a.utils.errors import ExtendedAgentCardNotConfiguredError
4346
from a2a.utils.helpers import maybe_await, validate_version
@@ -67,7 +70,7 @@ def __init__( # noqa: PLR0913
6770
agent_card: 'AgentCard',
6871
http_handler: 'RequestHandler',
6972
extended_agent_card: 'AgentCard | None' = None,
70-
context_builder: 'CallContextBuilder | None' = None,
73+
context_builder: 'ServerCallContextBuilder | None' = None,
7174
card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None,
7275
extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None,
7376
):
@@ -78,7 +81,9 @@ def __init__( # noqa: PLR0913
7881
self.handler = RequestHandler03(
7982
request_handler=http_handler,
8083
)
81-
self._context_builder = context_builder
84+
self._context_builder = (
85+
context_builder or DefaultServerCallContextBuilder()
86+
)
8287

8388
def supports_method(self, method: str) -> bool:
8489
"""Returns True if the v0.3 adapter supports the given method name."""
@@ -126,11 +131,7 @@ async def handle_request(
126131
CoreInvalidRequestError(data=str(e)),
127132
)
128133

129-
call_context = (
130-
self._context_builder.build(request)
131-
if self._context_builder
132-
else ServerCallContext()
133-
)
134+
call_context = self._context_builder.build(request)
134135
call_context.tenant = (
135136
getattr(specific_request.params, 'tenant', '')
136137
if hasattr(specific_request, 'params')

src/a2a/compat/v0_3/rest_adapter.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@
3434
from a2a.compat.v0_3 import conversions
3535
from a2a.compat.v0_3.rest_handler import REST03Handler
3636
from a2a.server.context import ServerCallContext
37-
from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder
37+
from a2a.server.routes.common import (
38+
DefaultServerCallContextBuilder,
39+
ServerCallContextBuilder,
40+
)
3841
from a2a.utils.error_handlers import (
3942
rest_error_handler,
4043
rest_stream_error_handler,
@@ -60,7 +63,7 @@ def __init__( # noqa: PLR0913
6063
agent_card: 'AgentCard',
6164
http_handler: 'RequestHandler',
6265
extended_agent_card: 'AgentCard | None' = None,
63-
context_builder: 'CallContextBuilder | None' = None,
66+
context_builder: 'ServerCallContextBuilder | None' = None,
6467
card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None,
6568
extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None,
6669
):
@@ -71,7 +74,9 @@ def __init__( # noqa: PLR0913
7174
self.handler = REST03Handler(
7275
agent_card=agent_card, request_handler=http_handler
7376
)
74-
self._context_builder = context_builder or DefaultCallContextBuilder()
77+
self._context_builder = (
78+
context_builder or DefaultServerCallContextBuilder()
79+
)
7580

7681
@rest_error_handler
7782
async def _handle_request(

src/a2a/server/request_handlers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
try:
2121
from a2a.server.request_handlers.grpc_handler import (
22+
DefaultGrpcServerCallContextBuilder,
2223
GrpcHandler, # type: ignore
24+
GrpcServerCallContextBuilder,
2325
)
2426
except ImportError as e:
2527
_original_error = e
@@ -39,8 +41,10 @@ def __init__(self, *args, **kwargs):
3941

4042

4143
__all__ = [
44+
'DefaultGrpcServerCallContextBuilder',
4245
'DefaultRequestHandler',
4346
'GrpcHandler',
47+
'GrpcServerCallContextBuilder',
4448
'RequestHandler',
4549
'build_error_response',
4650
'prepare_response_object',

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import a2a.types.a2a_pb2_grpc as a2a_grpc
2525

2626
from a2a import types
27-
from a2a.auth.user import UnauthenticatedUser
27+
from a2a.auth.user import UnauthenticatedUser, User
2828
from a2a.extensions.common import (
2929
HTTP_EXTENSION_HEADER,
3030
get_requested_extensions,
@@ -41,15 +41,32 @@
4141

4242
logger = logging.getLogger(__name__)
4343

44-
# For now we use a trivial wrapper on the grpc context object
4544

46-
47-
class CallContextBuilder(ABC):
48-
"""A class for building ServerCallContexts using the Starlette Request."""
45+
class GrpcServerCallContextBuilder(ABC):
46+
"""Interface for building ServerCallContext from gRPC context."""
4947

5048
@abstractmethod
5149
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
52-
"""Builds a ServerCallContext from a gRPC Request."""
50+
"""Builds a ServerCallContext from a gRPC ServicerContext."""
51+
52+
53+
class DefaultGrpcServerCallContextBuilder(GrpcServerCallContextBuilder):
54+
"""Default implementation of GrpcServerCallContextBuilder."""
55+
56+
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
57+
"""Builds a ServerCallContext from a gRPC ServicerContext."""
58+
state = {'grpc_context': context}
59+
return ServerCallContext(
60+
user=self.build_user(context),
61+
state=state,
62+
requested_extensions=get_requested_extensions(
63+
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
64+
),
65+
)
66+
67+
def build_user(self, context: grpc.aio.ServicerContext) -> User:
68+
"""Builds a User from a gRPC ServicerContext."""
69+
return UnauthenticatedUser()
5370

5471

5572
def _get_metadata_value(
@@ -67,22 +84,6 @@ def _get_metadata_value(
6784
]
6885

6986

70-
class DefaultCallContextBuilder(CallContextBuilder):
71-
"""A default implementation of CallContextBuilder."""
72-
73-
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
74-
"""Builds the ServerCallContext."""
75-
user = UnauthenticatedUser()
76-
state = {'grpc_context': context}
77-
return ServerCallContext(
78-
user=user,
79-
state=state,
80-
requested_extensions=get_requested_extensions(
81-
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
82-
),
83-
)
84-
85-
8687
_ERROR_CODE_MAP = {
8788
types.InvalidRequestError: grpc.StatusCode.INVALID_ARGUMENT,
8889
types.MethodNotFoundError: grpc.StatusCode.NOT_FOUND,
@@ -110,7 +111,7 @@ def __init__(
110111
self,
111112
agent_card: AgentCard,
112113
request_handler: RequestHandler,
113-
context_builder: CallContextBuilder | None = None,
114+
context_builder: GrpcServerCallContextBuilder | None = None,
114115
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
115116
| None = None,
116117
):
@@ -120,14 +121,17 @@ def __init__(
120121
agent_card: The AgentCard describing the agent's capabilities.
121122
request_handler: The underlying `RequestHandler` instance to
122123
delegate requests to.
123-
context_builder: The CallContextBuilder object. If none the
124-
DefaultCallContextBuilder is used.
124+
context_builder: The GrpcContextBuilder used to construct the
125+
ServerCallContext passed to the request_handler. If None the
126+
DefaultGrpcContextBuilder is used.
125127
card_modifier: An optional callback to dynamically modify the public
126128
agent card before it is served.
127129
"""
128130
self.agent_card = agent_card
129131
self.request_handler = request_handler
130-
self.context_builder = context_builder or DefaultCallContextBuilder()
132+
self._context_builder = (
133+
context_builder or DefaultGrpcServerCallContextBuilder()
134+
)
131135
self.card_modifier = card_modifier
132136

133137
async def _handle_unary(
@@ -451,6 +455,6 @@ def _build_call_context(
451455
context: grpc.aio.ServicerContext,
452456
request: message.Message,
453457
) -> ServerCallContext:
454-
server_context = self.context_builder.build(context)
458+
server_context = self._context_builder.build(context)
455459
server_context.tenant = getattr(request, 'tenant', '')
456460
return server_context

src/a2a/server/routes/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
"""A2A Routes."""
22

33
from a2a.server.routes.agent_card_routes import create_agent_card_routes
4-
from a2a.server.routes.jsonrpc_dispatcher import (
5-
CallContextBuilder,
6-
DefaultCallContextBuilder,
4+
from a2a.server.routes.common import (
5+
DefaultServerCallContextBuilder,
6+
ServerCallContextBuilder,
77
)
88
from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes
99
from a2a.server.routes.rest_routes import create_rest_routes
1010

1111

1212
__all__ = [
13-
'CallContextBuilder',
14-
'DefaultCallContextBuilder',
13+
'DefaultServerCallContextBuilder',
14+
'ServerCallContextBuilder',
1515
'create_agent_card_routes',
1616
'create_jsonrpc_routes',
1717
'create_rest_routes',

src/a2a/server/routes/common.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from abc import ABC, abstractmethod
2+
from typing import TYPE_CHECKING, Any
3+
4+
5+
if TYPE_CHECKING:
6+
from starlette.authentication import BaseUser
7+
from starlette.requests import Request
8+
else:
9+
try:
10+
from starlette.authentication import BaseUser
11+
from starlette.requests import Request
12+
except ImportError:
13+
Request = Any
14+
BaseUser = Any
15+
16+
from a2a.auth.user import UnauthenticatedUser, User
17+
from a2a.extensions.common import (
18+
HTTP_EXTENSION_HEADER,
19+
get_requested_extensions,
20+
)
21+
from a2a.server.context import ServerCallContext
22+
23+
24+
class StarletteUser(User):
25+
"""Adapts a Starlette BaseUser to the A2A User interface."""
26+
27+
def __init__(self, user: BaseUser):
28+
self._user = user
29+
30+
@property
31+
def is_authenticated(self) -> bool:
32+
"""Returns whether the current user is authenticated."""
33+
return self._user.is_authenticated
34+
35+
@property
36+
def user_name(self) -> str:
37+
"""Returns the user name of the current user."""
38+
return self._user.display_name
39+
40+
41+
class ServerCallContextBuilder(ABC):
42+
"""A class for building ServerCallContexts using the Starlette Request."""
43+
44+
@abstractmethod
45+
def build(self, request: Request) -> ServerCallContext:
46+
"""Builds a ServerCallContext from a Starlette Request."""
47+
48+
49+
class DefaultServerCallContextBuilder(ServerCallContextBuilder):
50+
"""A default implementation of ServerCallContextBuilder."""
51+
52+
def build(self, request: Request) -> ServerCallContext:
53+
"""Builds a ServerCallContext from a Starlette Request.
54+
55+
Args:
56+
request: The incoming Starlette Request object.
57+
58+
Returns:
59+
A ServerCallContext instance populated with user and state
60+
information from the request.
61+
"""
62+
state = {}
63+
if 'auth' in request.scope:
64+
state['auth'] = request.auth
65+
state['headers'] = dict(request.headers)
66+
return ServerCallContext(
67+
user=self.build_user(request),
68+
state=state,
69+
requested_extensions=get_requested_extensions(
70+
request.headers.getlist(HTTP_EXTENSION_HEADER)
71+
),
72+
)
73+
74+
def build_user(self, request: Request) -> User:
75+
"""Builds a User from a Starlette Request.
76+
77+
Args:
78+
request: The incoming Starlette Request object.
79+
80+
Returns:
81+
A User instance populated with user information from the request.
82+
"""
83+
if 'user' in request.scope:
84+
return StarletteUser(request.user)
85+
return UnauthenticatedUser()

0 commit comments

Comments
 (0)