Skip to content
1 change: 1 addition & 0 deletions .github/actions/spelling/expect.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
datapart
171 changes: 162 additions & 9 deletions src/a2a/contrib/tasks/vertex_task_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,35 @@
import base64
import json

from dataclasses import dataclass
from typing import Any

from a2a.types import (
Artifact,
DataPart,
FilePart,
FileWithBytes,
FileWithUri,
Message,
Part,
Role,
Task,
TaskState,
TaskStatus,
TextPart,
)


_ORIGINAL_METADATA_KEY = 'originalMetadata'
_EXTENSIONS_KEY = 'extensions'
_REFERENCE_TASK_IDS_KEY = 'referenceTaskIds'
_PART_METADATA_KEY = 'partMetadata'
_METADATA_VERSION_KEY = '__vertex_compat_v'
_METADATA_VERSION_NUMBER = 1.0

_DATA_PART_MIME_TYPE = 'application/x-a2a-datapart'


_TO_SDK_TASK_STATE = {
vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown,
vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted,
Expand Down Expand Up @@ -52,6 +67,55 @@ def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState:
)


def to_stored_metadata(
original_metadata: dict[str, Any] | None,
extensions: list[str] | None,
reference_task_ids: list[str] | None,
parts: list[Part],
) -> dict[str, Any]:
"""Packs original metadata, extensions, and part types/metadata into a storage dictionary."""
metadata: dict[str, Any] = {_METADATA_VERSION_KEY: _METADATA_VERSION_NUMBER}
if original_metadata:
metadata[_ORIGINAL_METADATA_KEY] = original_metadata
if extensions:
metadata[_EXTENSIONS_KEY] = extensions
if reference_task_ids:
metadata[_REFERENCE_TASK_IDS_KEY] = reference_task_ids

metadata[_PART_METADATA_KEY] = [part.root.metadata for part in parts]

return metadata


@dataclass
class _UnpackedMetadata:
original_metadata: dict[str, Any] | None = None
extensions: list[str] | None = None
reference_task_ids: list[str] | None = None
part_metadata: list[dict[str, Any] | None] | None = None


def to_sdk_metadata(
stored_metadata: dict[str, Any] | None,
) -> _UnpackedMetadata:
"""Unpacks metadata, extensions, and part types/metadata from a storage dictionary."""
if not stored_metadata:
return _UnpackedMetadata()

version = stored_metadata.get(_METADATA_VERSION_KEY)
if version is None:
return _UnpackedMetadata(original_metadata=stored_metadata)
if version > _METADATA_VERSION_NUMBER:
raise ValueError(f'Unsupported metadata version: {version}')

return _UnpackedMetadata(
original_metadata=stored_metadata.get(_ORIGINAL_METADATA_KEY),
extensions=stored_metadata.get(_EXTENSIONS_KEY),
reference_task_ids=stored_metadata.get(_REFERENCE_TASK_IDS_KEY),
part_metadata=stored_metadata.get(_PART_METADATA_KEY),
)


def to_stored_part(part: Part) -> genai_types.Part:
"""Converts a SDK Part to a proto Part."""
if isinstance(part.root, TextPart):
Expand All @@ -60,7 +124,7 @@ def to_stored_part(part: Part) -> genai_types.Part:
data_bytes = json.dumps(part.root.data).encode('utf-8')
return genai_types.Part(
inline_data=genai_types.Blob(
mime_type='application/json', data=data_bytes
mime_type=_DATA_PART_MIME_TYPE, data=data_bytes
)
)
if isinstance(part.root, FilePart):
Expand All @@ -82,29 +146,41 @@ def to_stored_part(part: Part) -> genai_types.Part:
raise ValueError(f'Unsupported part type: {type(part.root)}')


def to_sdk_part(stored_part: genai_types.Part) -> Part:
def to_sdk_part(
stored_part: genai_types.Part,
part_metadata: dict[str, Any] | None = None,
) -> Part:
"""Converts a proto Part to a SDK Part."""
if stored_part.text:
return Part(root=TextPart(text=stored_part.text))
return Part(
root=TextPart(text=stored_part.text, metadata=part_metadata)
)
if stored_part.inline_data:
mime_type = stored_part.inline_data.mime_type
if mime_type == _DATA_PART_MIME_TYPE:
data_dict = json.loads(stored_part.inline_data.data or b'{}')
return Part(root=DataPart(data=data_dict, metadata=part_metadata))

encoded_bytes = base64.b64encode(
stored_part.inline_data.data or b''
).decode('utf-8')
return Part(
root=FilePart(
file=FileWithBytes(
mime_type=stored_part.inline_data.mime_type,
mime_type=mime_type,
bytes=encoded_bytes,
)
),
metadata=part_metadata,
)
)
if stored_part.file_data:
return Part(
root=FilePart(
file=FileWithUri(
mime_type=stored_part.file_data.mime_type,
uri=stored_part.file_data.file_uri,
)
uri=stored_part.file_data.file_uri or '',
),
metadata=part_metadata,
)
)

Expand All @@ -115,15 +191,83 @@ def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact:
"""Converts a SDK Artifact to a proto TaskArtifact."""
return vertexai_types.TaskArtifact(
artifact_id=artifact.artifact_id,
display_name=artifact.name,
description=artifact.description,
parts=[to_stored_part(part) for part in artifact.parts],
metadata=to_stored_metadata(
original_metadata=artifact.metadata,
extensions=artifact.extensions,
reference_task_ids=None,
parts=artifact.parts,
),
)


def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact:
"""Converts a proto TaskArtifact to a SDK Artifact."""
unpacked_meta = to_sdk_metadata(stored_artifact.metadata)
part_metadata_list = unpacked_meta.part_metadata or []

parts = []
for i, part in enumerate(stored_artifact.parts or []):
meta: dict[str, Any] | None = None
if i < len(part_metadata_list):
meta = part_metadata_list[i]
parts.append(to_sdk_part(part, part_metadata=meta))

return Artifact(
artifact_id=stored_artifact.artifact_id,
parts=[to_sdk_part(part) for part in stored_artifact.parts],
name=stored_artifact.display_name,
description=stored_artifact.description,
extensions=unpacked_meta.extensions,
metadata=unpacked_meta.original_metadata,
parts=parts,
)


def to_stored_message(
message: Message | None,
) -> vertexai_types.TaskMessage | None:
"""Converts a SDK Message to a proto Message."""
if not message:
return None
role = message.role.value if message.role else ''
return vertexai_types.TaskMessage(
message_id=message.message_id,
role=role,
parts=[to_stored_part(part) for part in message.parts],
metadata=to_stored_metadata(
original_metadata=message.metadata,
extensions=message.extensions,
reference_task_ids=message.reference_task_ids,
parts=message.parts,
),
)


def to_sdk_message(
stored_msg: vertexai_types.TaskMessage | None,
) -> Message | None:
"""Converts a proto Message to a SDK Message."""
if not stored_msg:
return None
unpacked_meta = to_sdk_metadata(stored_msg.metadata)
part_metadata_list = unpacked_meta.part_metadata or []

parts = []
for i, part in enumerate(stored_msg.parts or []):
part_metadata: dict[str, Any] | None = None
if i < len(part_metadata_list):
part_metadata = part_metadata_list[i]
parts.append(to_sdk_part(part, part_metadata=part_metadata))

return Message(
message_id=stored_msg.message_id,
role=Role(stored_msg.role),
extensions=unpacked_meta.extensions,
reference_task_ids=unpacked_meta.reference_task_ids,
metadata=unpacked_meta.original_metadata,
parts=parts,
)


Expand All @@ -133,6 +277,11 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask:
context_id=task.context_id,
metadata=task.metadata,
state=to_stored_task_state(task.status.state),
status_details=vertexai_types.TaskStatusDetails(
task_message=to_stored_message(task.status.message)
)
if task.status.message
else None,
output=vertexai_types.TaskOutput(
artifacts=[
to_stored_artifact(artifact)
Expand All @@ -144,10 +293,14 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask:

def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task:
"""Converts a proto A2aTask to a SDK Task."""
msg: Message | None = None
if a2a_task.status_details and a2a_task.status_details.task_message:
msg = to_sdk_message(a2a_task.status_details.task_message)

return Task(
id=a2a_task.name.split('/')[-1],
context_id=a2a_task.context_id,
status=TaskStatus(state=to_sdk_task_state(a2a_task.state)),
status=TaskStatus(state=to_sdk_task_state(a2a_task.state), message=msg),
metadata=a2a_task.metadata or {},
artifacts=[
to_sdk_artifact(artifact)
Expand Down
33 changes: 33 additions & 0 deletions src/a2a/contrib/tasks/vertex_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,32 @@ def _get_status_change_event(
)
return None

def _get_status_details_change_event(
self,
previous_task: Task,
task: Task,
event_sequence_number: int,
) -> vertexai_types.TaskEvent | None:
if task.status.message != previous_task.status.message:
status_details = (
vertexai_types.TaskStatusDetails(
task_message=vertex_task_converter.to_stored_message(
task.status.message
)
)
if task.status.message
else vertexai_types.TaskStatusDetails()
)
return vertexai_types.TaskEvent(
event_data=vertexai_types.TaskEventData(
status_details_change=vertexai_types.TaskStatusDetailsChange(
new_task_status=status_details,
),
),
event_sequence_number=event_sequence_number,
)
return None

def _get_metadata_change_event(
self, previous_task: Task, task: Task, event_sequence_number: int
) -> vertexai_types.TaskEvent | None:
Expand Down Expand Up @@ -158,6 +184,13 @@ async def _update(
events.append(status_event)
event_sequence_number += 1

status_details_event = self._get_status_details_change_event(
previous_task, task, event_sequence_number
)
if status_details_event:
events.append(status_details_event)
event_sequence_number += 1

metadata_event = self._get_metadata_change_event(
previous_task, task, event_sequence_number
)
Expand Down
6 changes: 6 additions & 0 deletions tests/contrib/tasks/fake_vertex_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ async def append(
data = event.event_data
if getattr(data, 'state_change', None):
task.state = getattr(data.state_change, 'new_state', task.state)
if getattr(data, 'status_details_change', None):
task.status_details = getattr(
data.status_details_change,
'new_task_status',
getattr(task, 'status_details', None),
)
if getattr(data, 'metadata_change', None):
task.metadata = getattr(
data.metadata_change, 'new_metadata', task.metadata
Expand Down
Loading
Loading