diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt new file mode 100644 index 000000000..abf7a6f71 --- /dev/null +++ b/.github/actions/spelling/expect.txt @@ -0,0 +1 @@ +datapart diff --git a/src/a2a/contrib/tasks/vertex_task_converter.py b/src/a2a/contrib/tasks/vertex_task_converter.py index 5015211c7..16820a55f 100644 --- a/src/a2a/contrib/tasks/vertex_task_converter.py +++ b/src/a2a/contrib/tasks/vertex_task_converter.py @@ -11,13 +11,18 @@ import base64 import json +from dataclasses import dataclass +from typing import Any + from a2a.types import ( Artifact, DataPart, FilePart, FileWithBytes, FileWithUri, + Message, Part, + Role, Task, TaskState, TaskStatus, @@ -25,6 +30,16 @@ ) +_ORIGINAL_METADATA_KEY = 'originalMetadata' +_EXTENSIONS_KEY = 'extensions' +_REFERENCE_TASK_IDS_KEY = 'referenceTaskIds' +_PART_METADATA_KEY = 'partMetadata' +_METADATA_VERSION_KEY = '__vertex_compat_v' +_METADATA_VERSION_NUMBER = 1.0 + +_DATA_PART_MIME_TYPE = 'application/x-a2a-datapart' + + _TO_SDK_TASK_STATE = { vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown, vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted, @@ -52,6 +67,55 @@ def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState: ) +def to_stored_metadata( + original_metadata: dict[str, Any] | None, + extensions: list[str] | None, + reference_task_ids: list[str] | None, + parts: list[Part], +) -> dict[str, Any]: + """Packs original metadata, extensions, and part types/metadata into a storage dictionary.""" + metadata: dict[str, Any] = {_METADATA_VERSION_KEY: _METADATA_VERSION_NUMBER} + if original_metadata: + metadata[_ORIGINAL_METADATA_KEY] = original_metadata + if extensions: + metadata[_EXTENSIONS_KEY] = extensions + if reference_task_ids: + metadata[_REFERENCE_TASK_IDS_KEY] = reference_task_ids + + metadata[_PART_METADATA_KEY] = [part.root.metadata for part in parts] + + return metadata + + +@dataclass +class _UnpackedMetadata: + original_metadata: dict[str, Any] | None = None + extensions: list[str] | None = None + reference_task_ids: list[str] | None = None + part_metadata: list[dict[str, Any] | None] | None = None + + +def to_sdk_metadata( + stored_metadata: dict[str, Any] | None, +) -> _UnpackedMetadata: + """Unpacks metadata, extensions, and part types/metadata from a storage dictionary.""" + if not stored_metadata: + return _UnpackedMetadata() + + version = stored_metadata.get(_METADATA_VERSION_KEY) + if version is None: + return _UnpackedMetadata(original_metadata=stored_metadata) + if version > _METADATA_VERSION_NUMBER: + raise ValueError(f'Unsupported metadata version: {version}') + + return _UnpackedMetadata( + original_metadata=stored_metadata.get(_ORIGINAL_METADATA_KEY), + extensions=stored_metadata.get(_EXTENSIONS_KEY), + reference_task_ids=stored_metadata.get(_REFERENCE_TASK_IDS_KEY), + part_metadata=stored_metadata.get(_PART_METADATA_KEY), + ) + + def to_stored_part(part: Part) -> genai_types.Part: """Converts a SDK Part to a proto Part.""" if isinstance(part.root, TextPart): @@ -60,7 +124,7 @@ def to_stored_part(part: Part) -> genai_types.Part: data_bytes = json.dumps(part.root.data).encode('utf-8') return genai_types.Part( inline_data=genai_types.Blob( - mime_type='application/json', data=data_bytes + mime_type=_DATA_PART_MIME_TYPE, data=data_bytes ) ) if isinstance(part.root, FilePart): @@ -82,20 +146,31 @@ def to_stored_part(part: Part) -> genai_types.Part: raise ValueError(f'Unsupported part type: {type(part.root)}') -def to_sdk_part(stored_part: genai_types.Part) -> Part: +def to_sdk_part( + stored_part: genai_types.Part, + part_metadata: dict[str, Any] | None = None, +) -> Part: """Converts a proto Part to a SDK Part.""" if stored_part.text: - return Part(root=TextPart(text=stored_part.text)) + return Part( + root=TextPart(text=stored_part.text, metadata=part_metadata) + ) if stored_part.inline_data: + mime_type = stored_part.inline_data.mime_type + if mime_type == _DATA_PART_MIME_TYPE: + data_dict = json.loads(stored_part.inline_data.data or b'{}') + return Part(root=DataPart(data=data_dict, metadata=part_metadata)) + encoded_bytes = base64.b64encode( stored_part.inline_data.data or b'' ).decode('utf-8') return Part( root=FilePart( file=FileWithBytes( - mime_type=stored_part.inline_data.mime_type, + mime_type=mime_type, bytes=encoded_bytes, - ) + ), + metadata=part_metadata, ) ) if stored_part.file_data: @@ -103,8 +178,9 @@ def to_sdk_part(stored_part: genai_types.Part) -> Part: root=FilePart( file=FileWithUri( mime_type=stored_part.file_data.mime_type, - uri=stored_part.file_data.file_uri, - ) + uri=stored_part.file_data.file_uri or '', + ), + metadata=part_metadata, ) ) @@ -115,15 +191,83 @@ def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact: """Converts a SDK Artifact to a proto TaskArtifact.""" return vertexai_types.TaskArtifact( artifact_id=artifact.artifact_id, + display_name=artifact.name, + description=artifact.description, parts=[to_stored_part(part) for part in artifact.parts], + metadata=to_stored_metadata( + original_metadata=artifact.metadata, + extensions=artifact.extensions, + reference_task_ids=None, + parts=artifact.parts, + ), ) def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact: """Converts a proto TaskArtifact to a SDK Artifact.""" + unpacked_meta = to_sdk_metadata(stored_artifact.metadata) + part_metadata_list = unpacked_meta.part_metadata or [] + + parts = [] + for i, part in enumerate(stored_artifact.parts or []): + meta: dict[str, Any] | None = None + if i < len(part_metadata_list): + meta = part_metadata_list[i] + parts.append(to_sdk_part(part, part_metadata=meta)) + return Artifact( artifact_id=stored_artifact.artifact_id, - parts=[to_sdk_part(part) for part in stored_artifact.parts], + name=stored_artifact.display_name, + description=stored_artifact.description, + extensions=unpacked_meta.extensions, + metadata=unpacked_meta.original_metadata, + parts=parts, + ) + + +def to_stored_message( + message: Message | None, +) -> vertexai_types.TaskMessage | None: + """Converts a SDK Message to a proto Message.""" + if not message: + return None + role = message.role.value if message.role else '' + return vertexai_types.TaskMessage( + message_id=message.message_id, + role=role, + parts=[to_stored_part(part) for part in message.parts], + metadata=to_stored_metadata( + original_metadata=message.metadata, + extensions=message.extensions, + reference_task_ids=message.reference_task_ids, + parts=message.parts, + ), + ) + + +def to_sdk_message( + stored_msg: vertexai_types.TaskMessage | None, +) -> Message | None: + """Converts a proto Message to a SDK Message.""" + if not stored_msg: + return None + unpacked_meta = to_sdk_metadata(stored_msg.metadata) + part_metadata_list = unpacked_meta.part_metadata or [] + + parts = [] + for i, part in enumerate(stored_msg.parts or []): + part_metadata: dict[str, Any] | None = None + if i < len(part_metadata_list): + part_metadata = part_metadata_list[i] + parts.append(to_sdk_part(part, part_metadata=part_metadata)) + + return Message( + message_id=stored_msg.message_id, + role=Role(stored_msg.role), + extensions=unpacked_meta.extensions, + reference_task_ids=unpacked_meta.reference_task_ids, + metadata=unpacked_meta.original_metadata, + parts=parts, ) @@ -133,6 +277,11 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask: context_id=task.context_id, metadata=task.metadata, state=to_stored_task_state(task.status.state), + status_details=vertexai_types.TaskStatusDetails( + task_message=to_stored_message(task.status.message) + ) + if task.status.message + else None, output=vertexai_types.TaskOutput( artifacts=[ to_stored_artifact(artifact) @@ -144,10 +293,14 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask: def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task: """Converts a proto A2aTask to a SDK Task.""" + msg: Message | None = None + if a2a_task.status_details and a2a_task.status_details.task_message: + msg = to_sdk_message(a2a_task.status_details.task_message) + return Task( id=a2a_task.name.split('/')[-1], context_id=a2a_task.context_id, - status=TaskStatus(state=to_sdk_task_state(a2a_task.state)), + status=TaskStatus(state=to_sdk_task_state(a2a_task.state), message=msg), metadata=a2a_task.metadata or {}, artifacts=[ to_sdk_artifact(artifact) diff --git a/src/a2a/contrib/tasks/vertex_task_store.py b/src/a2a/contrib/tasks/vertex_task_store.py index 2612d6105..5ba9147f5 100644 --- a/src/a2a/contrib/tasks/vertex_task_store.py +++ b/src/a2a/contrib/tasks/vertex_task_store.py @@ -80,6 +80,32 @@ def _get_status_change_event( ) return None + def _get_status_details_change_event( + self, + previous_task: Task, + task: Task, + event_sequence_number: int, + ) -> vertexai_types.TaskEvent | None: + if task.status.message != previous_task.status.message: + status_details = ( + vertexai_types.TaskStatusDetails( + task_message=vertex_task_converter.to_stored_message( + task.status.message + ) + ) + if task.status.message + else vertexai_types.TaskStatusDetails() + ) + return vertexai_types.TaskEvent( + event_data=vertexai_types.TaskEventData( + status_details_change=vertexai_types.TaskStatusDetailsChange( + new_task_status=status_details, + ), + ), + event_sequence_number=event_sequence_number, + ) + return None + def _get_metadata_change_event( self, previous_task: Task, task: Task, event_sequence_number: int ) -> vertexai_types.TaskEvent | None: @@ -158,6 +184,13 @@ async def _update( events.append(status_event) event_sequence_number += 1 + status_details_event = self._get_status_details_change_event( + previous_task, task, event_sequence_number + ) + if status_details_event: + events.append(status_details_event) + event_sequence_number += 1 + metadata_event = self._get_metadata_change_event( previous_task, task, event_sequence_number ) diff --git a/tests/contrib/tasks/fake_vertex_client.py b/tests/contrib/tasks/fake_vertex_client.py index 86d14ede0..8a4a86903 100644 --- a/tests/contrib/tasks/fake_vertex_client.py +++ b/tests/contrib/tasks/fake_vertex_client.py @@ -36,6 +36,12 @@ async def append( data = event.event_data if getattr(data, 'state_change', None): task.state = getattr(data.state_change, 'new_state', task.state) + if getattr(data, 'status_details_change', None): + task.status_details = getattr( + data.status_details_change, + 'new_task_status', + getattr(task, 'status_details', None), + ) if getattr(data, 'metadata_change', None): task.metadata = getattr( data.metadata_change, 'new_metadata', task.metadata diff --git a/tests/contrib/tasks/test_vertex_task_converter.py b/tests/contrib/tasks/test_vertex_task_converter.py index de6ae8cd6..4c2cec9d7 100644 --- a/tests/contrib/tasks/test_vertex_task_converter.py +++ b/tests/contrib/tasks/test_vertex_task_converter.py @@ -9,11 +9,14 @@ from vertexai import types as vertexai_types from google.genai import types as genai_types from a2a.contrib.tasks.vertex_task_converter import ( + _DATA_PART_MIME_TYPE, to_sdk_artifact, + to_sdk_message, to_sdk_part, to_sdk_task, to_sdk_task_state, to_stored_artifact, + to_stored_message, to_stored_part, to_stored_task, to_stored_task_state, @@ -24,7 +27,9 @@ FilePart, FileWithBytes, FileWithUri, + Message, Part, + Role, Task, TaskState, TaskStatus, @@ -123,7 +128,7 @@ def test_to_stored_part_data() -> None: sdk_part = Part(root=DataPart(data={'key': 'value'})) stored_part = to_stored_part(sdk_part) assert stored_part.inline_data is not None - assert stored_part.inline_data.mime_type == 'application/json' + assert stored_part.inline_data.mime_type == _DATA_PART_MIME_TYPE assert stored_part.inline_data.data == b'{"key": "value"}' @@ -190,6 +195,18 @@ def test_to_sdk_part_inline_data() -> None: assert sdk_part.root.file.bytes == expected_b64 +def test_to_sdk_part_inline_data_datapart() -> None: + stored_part = genai_types.Part( + inline_data=genai_types.Blob( + mime_type=_DATA_PART_MIME_TYPE, + data=b'{"key": "val"}', + ) + ) + sdk_part = to_sdk_part(stored_part) + assert isinstance(sdk_part.root, DataPart) + assert sdk_part.root.data == {'key': 'val'} + + def test_to_sdk_part_file_data() -> None: stored_part = genai_types.Part( file_data=genai_types.FileData( @@ -313,23 +330,11 @@ def test_sdk_part_text_conversion_round_trip() -> None: def test_sdk_part_data_conversion_round_trip() -> None: - # A DataPart is converted to `inline_data` in Vertex AI, which lacks the original - # `DataPart` vs `FilePart` distinction. When reading it back from the stored - # protocol format, it becomes a `FilePart` with base64-encoded `FileWithBytes` - # and `mime_type="application/json"`. sdk_part = Part(root=DataPart(data={'key': 'value'})) stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) + round_trip_sdk_part = to_sdk_part(stored_part, part_metadata=None) - expected_b64 = base64.b64encode(b'{"key": "value"}').decode('utf-8') - assert round_trip_sdk_part == Part( - root=FilePart( - file=FileWithBytes( - bytes=expected_b64, - mime_type='application/json', - ) - ) - ) + assert round_trip_sdk_part == sdk_part def test_sdk_part_file_bytes_conversion_round_trip() -> None: @@ -361,16 +366,6 @@ def test_sdk_part_file_uri_conversion_round_trip() -> None: assert round_trip_sdk_part == sdk_part -def test_sdk_artifact_conversion_round_trip() -> None: - sdk_artifact = Artifact( - artifact_id='art-123', - parts=[Part(root=TextPart(text='part_1'))], - ) - stored_artifact = to_stored_artifact(sdk_artifact) - round_trip_sdk_artifact = to_sdk_artifact(stored_artifact) - assert round_trip_sdk_artifact == sdk_artifact - - def test_sdk_task_conversion_round_trip() -> None: sdk_task = Task( id='task-1', @@ -403,3 +398,88 @@ def test_sdk_task_conversion_round_trip() -> None: assert round_trip_sdk_task.metadata == sdk_task.metadata assert round_trip_sdk_task.artifacts == sdk_task.artifacts assert round_trip_sdk_task.history == [] + + +def test_stored_artifact_conversion_round_trip() -> None: + """Test converting an Artifact to TaskArtifact and back restores everything.""" + original_artifact = Artifact( + artifact_id='art123', + name='My cool artifact', + description='A very interesting description', + extensions=['ext1', 'ext2'], + metadata={'custom': 'value'}, + parts=[ + Part( + root=TextPart( + text='hello', metadata={'part_meta': 'hello_meta'} + ) + ), + Part(root=DataPart(data={'foo': 'bar'})), # no metadata + ], + ) + + stored = to_stored_artifact(original_artifact) + assert isinstance(stored, vertexai_types.TaskArtifact) + + # ensure it was populated correctly + assert stored.display_name == 'My cool artifact' + assert stored.description == 'A very interesting description' + assert stored.metadata['__vertex_compat_v'] == 1.0 + + restored_artifact = to_sdk_artifact(stored) + + assert restored_artifact.artifact_id == original_artifact.artifact_id + assert restored_artifact.name == original_artifact.name + assert restored_artifact.description == original_artifact.description + assert restored_artifact.extensions == original_artifact.extensions + assert restored_artifact.metadata == original_artifact.metadata + + assert len(restored_artifact.parts) == 2 + assert isinstance(restored_artifact.parts[0].root, TextPart) + assert restored_artifact.parts[0].root.text == 'hello' + assert restored_artifact.parts[0].root.metadata == { + 'part_meta': 'hello_meta' + } + + assert isinstance(restored_artifact.parts[1].root, DataPart) + assert restored_artifact.parts[1].root.data == {'foo': 'bar'} + assert restored_artifact.parts[1].root.metadata is None + + +def test_stored_message_conversion_round_trip() -> None: + """Test converting a Message to TaskMessage and back restores everything.""" + original_message = Message( + message_id='msg456', + role=Role.agent, + reference_task_ids=['tsk2', 'tsk3'], + extensions=['ext_msg'], + metadata={'msg_meta': 42}, + parts=[ + Part(root=TextPart(text='message text')), + ], + ) + + stored = to_stored_message(original_message) + assert stored is not None + assert isinstance(stored, vertexai_types.TaskMessage) + + assert stored.message_id == 'msg456' + assert stored.role == 'agent' + assert stored.metadata['__vertex_compat_v'] == 1.0 + + restored_message = to_sdk_message(stored) + assert restored_message is not None + + assert restored_message.message_id == original_message.message_id + assert restored_message.role == original_message.role + assert ( + restored_message.reference_task_ids + == original_message.reference_task_ids + ) + assert restored_message.extensions == original_message.extensions + assert restored_message.metadata == original_message.metadata + + assert len(restored_message.parts) == 1 + assert isinstance(restored_message.parts[0].root, TextPart) + assert restored_message.parts[0].root.text == 'message text' + assert restored_message.parts[0].root.metadata is None diff --git a/tests/contrib/tasks/test_vertex_task_store.py b/tests/contrib/tasks/test_vertex_task_store.py index fbcbc37f4..ed99c09bb 100644 --- a/tests/contrib/tasks/test_vertex_task_store.py +++ b/tests/contrib/tasks/test_vertex_task_store.py @@ -63,7 +63,9 @@ def backend_type(request) -> str: from a2a.contrib.tasks.vertex_task_store import VertexTaskStore from a2a.types import ( Artifact, + Message, Part, + Role, Task, TaskState, TaskStatus, @@ -504,3 +506,67 @@ async def test_metadata_field_mapping( retrieved_none = await vertex_store.get('task-metadata-test-4') assert retrieved_none is not None assert retrieved_none.metadata == {} + + +@pytest.mark.asyncio +async def test_update_task_status_details( + vertex_store: VertexTaskStore, +) -> None: + """Test updating an existing task by changing the status details (message) with part metadata.""" + task_id = 'update-test-task-status-details' + original_task = Task( + id=task_id, + context_id='session-update', + status=TaskStatus(state=TaskState.submitted), + kind='task', + metadata=None, + artifacts=[], + history=[], + ) + await vertex_store.save(original_task) + + retrieved_before_update = await vertex_store.get(task_id) + assert retrieved_before_update is not None + assert retrieved_before_update.status.message is None + + updated_task = original_task.model_copy(deep=True) + updated_task.status.state = TaskState.failed + updated_task.status.timestamp = '2023-01-02T11:00:00Z' + updated_task.status.message = Message( + message_id='msg-error-1', + role=Role.agent, + parts=[ + Part( + root=TextPart( + text='Task failed due to an unknown error', + metadata={'error_code': 'UNKNOWN', 'retryable': False}, + ) + ) + ], + ) + + await vertex_store.save(updated_task) + + retrieved_after_update = await vertex_store.get(task_id) + assert retrieved_after_update is not None + assert retrieved_after_update.status.state == TaskState.failed + assert retrieved_after_update.status.message is not None + assert retrieved_after_update.status.message.message_id == 'msg-error-1' + assert retrieved_after_update.status.message.role == Role.agent + assert len(retrieved_after_update.status.message.parts) == 1 + + assert isinstance( + retrieved_after_update.status.message.parts[0].root, TextPart + ) + text_part = retrieved_after_update.status.message.parts[0].root + assert text_part.text == 'Task failed due to an unknown error' + assert text_part.metadata == {'error_code': 'UNKNOWN', 'retryable': False} + + # Also test clearing the message + cleared_task = updated_task.model_copy(deep=True) + cleared_task.status.message = None + + await vertex_store.save(cleared_task) + retrieved_cleared = await vertex_store.get(task_id) + assert retrieved_cleared is not None + assert retrieved_cleared.status.message is None