Skip to content

Commit 97058bb

Browse files
refactor(client)!: remove ClientTaskManager and Consumers from client (#916)
# Description This PR removes the client side TaskManager, as it represent a redundant duplication of the server-side TaskManager, and the client Consumers. Consumers can be replaced with [ClientCallInterceptor](https://github.com/a2aproject/a2a-python/blob/1.0-dev/src/a2a/client/interceptors.py). Fix #734
1 parent 3942c57 commit 97058bb

16 files changed

Lines changed: 68 additions & 517 deletions

itk/main.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,19 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
138138
nested_msg = wrap_instruction_to_request(call.instruction)
139139
request = SendMessageRequest(message=nested_msg)
140140

141-
results = []
141+
results: list[str] = []
142142
async for event in client.send_message(request):
143-
# Event is streaming response and task
143+
# Event is StreamResponse
144144
logger.info('Event: %s', event)
145-
stream_resp, task = event
145+
stream_resp = event
146146

147147
message = None
148148
if stream_resp.HasField('message'):
149149
message = stream_resp.message
150-
elif task and task.status.HasField('message'):
151-
message = task.status.message
150+
elif stream_resp.HasField(
151+
'task'
152+
) and stream_resp.task.status.HasField('message'):
153+
message = stream_resp.task.status.message
152154
elif stream_resp.HasField(
153155
'status_update'
154156
) and stream_resp.status_update.status.HasField('message'):

src/a2a/client/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
Client,
1212
ClientCallContext,
1313
ClientConfig,
14-
ClientEvent,
15-
Consumer,
1614
)
1715
from a2a.client.client_factory import ClientFactory, minimal_agent_card
1816
from a2a.client.errors import (
@@ -35,9 +33,7 @@
3533
'ClientCallContext',
3634
'ClientCallInterceptor',
3735
'ClientConfig',
38-
'ClientEvent',
3936
'ClientFactory',
40-
'Consumer',
4137
'CredentialService',
4238
'InMemoryContextCredentialStore',
4339
'create_text_message_object',

src/a2a/client/base_client.py

Lines changed: 12 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
Client,
66
ClientCallContext,
77
ClientConfig,
8-
ClientEvent,
9-
Consumer,
108
)
11-
from a2a.client.client_task_manager import ClientTaskManager
129
from a2a.client.interceptors import (
1310
AfterArgs,
1411
BeforeArgs,
@@ -42,10 +39,9 @@ def __init__(
4239
card: AgentCard,
4340
config: ClientConfig,
4441
transport: ClientTransport,
45-
consumers: list[Consumer],
4642
interceptors: list[ClientCallInterceptor],
4743
):
48-
super().__init__(consumers, interceptors)
44+
super().__init__(interceptors)
4945
self._card = card
5046
self._config = config
5147
self._transport = transport
@@ -56,7 +52,7 @@ async def send_message(
5652
request: SendMessageRequest,
5753
*,
5854
context: ClientCallContext | None = None,
59-
) -> AsyncIterator[ClientEvent]:
55+
) -> AsyncIterator[StreamResponse]:
6056
"""Sends a message to the agent.
6157
6258
This method handles both streaming and non-streaming (polling) interactions
@@ -68,7 +64,7 @@ async def send_message(
6864
context: Optional client call context.
6965
7066
Yields:
71-
An async iterator of `ClientEvent`
67+
An async iterator of `StreamResponse`
7268
"""
7369
self._apply_client_config(request)
7470
if not self._config.streaming or not self._card.capabilities.streaming:
@@ -84,19 +80,14 @@ async def send_message(
8480
# In non-streaming case we convert to a StreamResponse so that the
8581
# client always sees the same iterator.
8682
stream_response = StreamResponse()
87-
client_event: ClientEvent
8883
if response.HasField('task'):
8984
stream_response.task.CopyFrom(response.task)
90-
client_event = (stream_response, response.task)
9185
elif response.HasField('message'):
9286
stream_response.message.CopyFrom(response.message)
93-
client_event = (stream_response, None)
9487
else:
95-
# Response must have either task or message
9688
raise ValueError('Response has neither task nor message')
9789

98-
await self.consume(client_event, self._card)
99-
yield client_event
90+
yield stream_response
10091
return
10192

10293
async for event in self._execute_stream_with_interceptors(
@@ -130,8 +121,7 @@ async def _process_stream(
130121
self,
131122
stream: AsyncIterator[StreamResponse],
132123
before_args: BeforeArgs,
133-
) -> AsyncGenerator[ClientEvent]:
134-
tracker = ClientTaskManager()
124+
) -> AsyncGenerator[StreamResponse, None]:
135125
async for stream_response in stream:
136126
after_args = AfterArgs(
137127
result=stream_response,
@@ -140,12 +130,8 @@ async def _process_stream(
140130
context=before_args.context,
141131
)
142132
await self._intercept_after(after_args)
143-
intercepted_response = after_args.result
144-
client_event = await self._format_stream_event(
145-
intercepted_response, tracker
146-
)
147-
yield client_event
148-
if intercepted_response.HasField('message'):
133+
yield after_args.result
134+
if after_args.result.HasField('message'):
149135
return
150136

151137
async def get_task(
@@ -318,7 +304,7 @@ async def subscribe(
318304
request: SubscribeToTaskRequest,
319305
*,
320306
context: ClientCallContext | None = None,
321-
) -> AsyncIterator[ClientEvent]:
307+
) -> AsyncIterator[StreamResponse]:
322308
"""Resubscribes to a task's event stream.
323309
324310
This is only available if both the client and server support streaming.
@@ -328,7 +314,7 @@ async def subscribe(
328314
context: Optional client call context.
329315
330316
Yields:
331-
An async iterator of `ClientEvent` objects.
317+
An async iterator of `StreamResponse` objects.
332318
333319
Raises:
334320
NotImplementedError: If streaming is not supported by the client or server.
@@ -436,7 +422,7 @@ async def _execute_stream_with_interceptors(
436422
transport_call: Callable[
437423
[Any, ClientCallContext | None], AsyncIterator[StreamResponse]
438424
],
439-
) -> AsyncIterator[ClientEvent]:
425+
) -> AsyncIterator[StreamResponse]:
440426

441427
before_args = BeforeArgs(
442428
input=input_data,
@@ -446,7 +432,7 @@ async def _execute_stream_with_interceptors(
446432
)
447433
before_result = await self._intercept_before(before_args)
448434

449-
if before_result:
435+
if before_result is not None:
450436
after_args = AfterArgs(
451437
result=before_result['early_return'],
452438
method=method,
@@ -455,8 +441,7 @@ async def _execute_stream_with_interceptors(
455441
)
456442
await self._intercept_after(after_args, before_result['executed'])
457443

458-
tracker = ClientTaskManager()
459-
yield await self._format_stream_event(after_args.result, tracker)
444+
yield after_args.result
460445
return
461446

462447
stream = transport_call(before_args.input, before_args.context)
@@ -495,19 +480,3 @@ async def _intercept_after(
495480
await interceptor.after(args)
496481
if args.early_return:
497482
return
498-
499-
async def _format_stream_event(
500-
self, stream_response: StreamResponse, tracker: ClientTaskManager
501-
) -> ClientEvent:
502-
client_event: ClientEvent
503-
if stream_response.HasField('message'):
504-
client_event = (stream_response, None)
505-
await self.consume(client_event, self._card)
506-
return client_event
507-
508-
await tracker.process(stream_response)
509-
updated_task = tracker.get_task_or_raise()
510-
client_event = (stream_response, updated_task)
511-
512-
await self.consume(client_event, self._card)
513-
return client_event

src/a2a/client/client.py

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33

44
from abc import ABC, abstractmethod
5-
from collections.abc import AsyncIterator, Callable, Coroutine, MutableMapping
5+
from collections.abc import AsyncIterator, Callable, MutableMapping
66
from types import TracebackType
77
from typing import Any
88

@@ -77,13 +77,6 @@ class ClientConfig:
7777
"""Push notification configurations to use for every request."""
7878

7979

80-
ClientEvent = tuple[StreamResponse, Task | None]
81-
82-
# Alias for an event consuming callback. It takes either a (task, update) pair
83-
# or a message as well as the agent card for the agent this came from.
84-
Consumer = Callable[[ClientEvent, AgentCard], Coroutine[None, Any, Any]]
85-
86-
8780
class ClientCallContext(BaseModel):
8881
"""A context passed with each client call, allowing for call-specific.
8982
@@ -106,16 +99,13 @@ class Client(ABC):
10699

107100
def __init__(
108101
self,
109-
consumers: list[Consumer] | None = None,
110102
interceptors: list[ClientCallInterceptor] | None = None,
111103
):
112-
"""Initializes the client with consumers and interceptors.
104+
"""Initializes the client with interceptors.
113105
114106
Args:
115-
consumers: A list of callables to process events from the agent.
116107
interceptors: A list of interceptors to process requests and responses.
117108
"""
118-
self._consumers = consumers or []
119109
self._interceptors = interceptors or []
120110

121111
async def __aenter__(self) -> Self:
@@ -137,14 +127,12 @@ async def send_message(
137127
request: SendMessageRequest,
138128
*,
139129
context: ClientCallContext | None = None,
140-
) -> AsyncIterator[ClientEvent]:
130+
) -> AsyncIterator[StreamResponse]:
141131
"""Sends a message to the server.
142132
143133
This will automatically use the streaming or non-streaming approach
144134
as supported by the server and the client config. Client will
145-
aggregate update events and return an iterator of (`Task`,`Update`)
146-
pairs, or a `Message`. Client will also send these values to any
147-
configured `Consumer`s in the client.
135+
aggregate update events and return an iterator of `StreamResponse`.
148136
"""
149137
return
150138
yield
@@ -218,7 +206,7 @@ async def subscribe(
218206
request: SubscribeToTaskRequest,
219207
*,
220208
context: ClientCallContext | None = None,
221-
) -> AsyncIterator[ClientEvent]:
209+
) -> AsyncIterator[StreamResponse]:
222210
"""Resubscribes to a task's event stream."""
223211
return
224212
yield
@@ -233,23 +221,10 @@ async def get_extended_agent_card(
233221
) -> AgentCard:
234222
"""Retrieves the agent's card."""
235223

236-
async def add_event_consumer(self, consumer: Consumer) -> None:
237-
"""Attaches additional consumers to the `Client`."""
238-
self._consumers.append(consumer)
239-
240224
async def add_interceptor(self, interceptor: ClientCallInterceptor) -> None:
241225
"""Attaches additional interceptors to the `Client`."""
242226
self._interceptors.append(interceptor)
243227

244-
async def consume(
245-
self,
246-
event: ClientEvent,
247-
card: AgentCard,
248-
) -> None:
249-
"""Processes the event via all the registered `Consumer`s."""
250-
for c in self._consumers:
251-
await c(event, card)
252-
253228
@abstractmethod
254229
async def close(self) -> None:
255230
"""Closes the client and releases any underlying resources."""

src/a2a/client/client_factory.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from a2a.client.base_client import BaseClient
1313
from a2a.client.card_resolver import A2ACardResolver
14-
from a2a.client.client import Client, ClientConfig, Consumer
14+
from a2a.client.client import Client, ClientConfig
1515
from a2a.client.transports.base import ClientTransport
1616
from a2a.client.transports.jsonrpc import JsonRpcTransport
1717
from a2a.client.transports.rest import RestTransport
@@ -63,12 +63,11 @@ class ClientFactory:
6363
6464
.. code-block:: python
6565
66-
factory = ClientFactory(config, consumers)
66+
factory = ClientFactory(config)
6767
# Optionally register custom client implementations
6868
factory.register('my_customer_transport', NewCustomTransportClient)
69-
# Then with an agent card make a client with additional consumers and
70-
# interceptors
71-
client = factory.create(card, additional_consumers, interceptors)
69+
# Then with an agent card make a client with additional interceptors
70+
client = factory.create(card, interceptors)
7271
7372
Now the client can be used consistently regardless of the transport. This
7473
aligns the client configuration with the server's capabilities.
@@ -77,17 +76,12 @@ class ClientFactory:
7776
def __init__(
7877
self,
7978
config: ClientConfig,
80-
consumers: list[Consumer] | None = None,
8179
):
82-
if consumers is None:
83-
consumers = []
84-
8580
client = config.httpx_client or httpx.AsyncClient()
8681
client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT)
8782
config.httpx_client = client
8883

8984
self._config = config
90-
self._consumers = consumers
9185
self._registry: dict[str, TransportProducer] = {}
9286
self._register_defaults(config.supported_protocol_bindings)
9387

@@ -263,7 +257,6 @@ async def connect( # noqa: PLR0913
263257
cls,
264258
agent: str | AgentCard,
265259
client_config: ClientConfig | None = None,
266-
consumers: list[Consumer] | None = None,
267260
interceptors: list[ClientCallInterceptor] | None = None,
268261
relative_card_path: str | None = None,
269262
resolver_http_kwargs: dict[str, Any] | None = None,
@@ -286,7 +279,7 @@ async def connect( # noqa: PLR0913
286279
Args:
287280
agent: The base URL of the agent, or the AgentCard to connect to.
288281
client_config: The ClientConfig to use when connecting to the agent.
289-
consumers: A list of `Consumer` methods to pass responses to.
282+
290283
interceptors: A list of interceptors to use for each request. These
291284
are used for things like attaching credentials or http headers
292285
to all outbound requests.
@@ -325,7 +318,7 @@ async def connect( # noqa: PLR0913
325318
factory = cls(client_config)
326319
for label, generator in (extra_transports or {}).items():
327320
factory.register(label, generator)
328-
return factory.create(card, consumers, interceptors)
321+
return factory.create(card, interceptors)
329322

330323
def register(self, label: str, generator: TransportProducer) -> None:
331324
"""Register a new transport producer for a given transport label."""
@@ -334,14 +327,12 @@ def register(self, label: str, generator: TransportProducer) -> None:
334327
def create(
335328
self,
336329
card: AgentCard,
337-
consumers: list[Consumer] | None = None,
338330
interceptors: list[ClientCallInterceptor] | None = None,
339331
) -> Client:
340332
"""Create a new `Client` for the provided `AgentCard`.
341333
342334
Args:
343335
card: An `AgentCard` defining the characteristics of the agent.
344-
consumers: A list of `Consumer` methods to pass responses to.
345336
interceptors: A list of interceptors to use for each request. These
346337
are used for things like attaching credentials or http headers
347338
to all outbound requests.
@@ -381,10 +372,6 @@ def create(
381372
if transport_protocol not in self._registry:
382373
raise ValueError(f'no client available for {transport_protocol}')
383374

384-
all_consumers = self._consumers.copy()
385-
if consumers:
386-
all_consumers.extend(consumers)
387-
388375
transport = self._registry[transport_protocol](
389376
card, selected_interface.url, self._config
390377
)
@@ -398,7 +385,6 @@ def create(
398385
card,
399386
self._config,
400387
transport,
401-
all_consumers,
402388
interceptors or [],
403389
)
404390

0 commit comments

Comments
 (0)