Skip to content

Commit b1e42c8

Browse files
committed
refactor: introduce _UnpackedMetadata
1 parent ed7f6e4 commit b1e42c8

2 files changed

Lines changed: 47 additions & 30 deletions

File tree

src/a2a/contrib/tasks/vertex_task_converter.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import base64
1212
import json
1313

14+
from dataclasses import dataclass
1415
from typing import Any
1516

1617
from a2a.types import (
@@ -92,22 +93,33 @@ def to_stored_metadata(
9293
return metadata
9394

9495

95-
def to_sdk_metadata(stored_metadata: dict[str, Any] | None) -> dict[str, Any]:
96+
@dataclass
97+
class _UnpackedMetadata:
98+
original_metadata: dict[str, Any] | None = None
99+
extensions: list[str] | None = None
100+
reference_task_ids: list[str] | None = None
101+
part_metadata: list[dict[str, Any] | None] | None = None
102+
part_types: list[str] | None = None
103+
104+
105+
def to_sdk_metadata(
106+
stored_metadata: dict[str, Any] | None,
107+
) -> _UnpackedMetadata:
96108
"""Unpacks metadata, extensions, and part types/metadata from a storage dictionary."""
97109
if not stored_metadata:
98-
return {}
110+
return _UnpackedMetadata()
99111

100112
version = stored_metadata.get(_METADATA_VERSION_KEY)
101113
if version is None:
102-
return {'original_metadata': stored_metadata}
103-
104-
return {
105-
'original_metadata': stored_metadata.get(_ORIGINAL_METADATA_KEY),
106-
'extensions': stored_metadata.get(_EXTENSIONS_KEY),
107-
'reference_tasks': stored_metadata.get(_REFERENCE_TASK_IDS_KEY),
108-
'part_metadata': stored_metadata.get(_PART_METADATA_KEY),
109-
'part_types': stored_metadata.get(_PART_TYPES_KEY),
110-
}
114+
return _UnpackedMetadata(original_metadata=stored_metadata)
115+
116+
return _UnpackedMetadata(
117+
original_metadata=stored_metadata.get(_ORIGINAL_METADATA_KEY),
118+
extensions=stored_metadata.get(_EXTENSIONS_KEY),
119+
reference_task_ids=stored_metadata.get(_REFERENCE_TASK_IDS_KEY),
120+
part_metadata=stored_metadata.get(_PART_METADATA_KEY),
121+
part_types=stored_metadata.get(_PART_TYPES_KEY),
122+
)
111123

112124

113125
def to_stored_part(part: Part) -> genai_types.Part:
@@ -173,7 +185,7 @@ def to_sdk_part(
173185
root=FilePart(
174186
file=FileWithUri(
175187
mime_type=stored_part.file_data.mime_type,
176-
uri=stored_part.file_data.file_uri,
188+
uri=stored_part.file_data.file_uri or '',
177189
),
178190
metadata=part_metadata,
179191
)
@@ -201,14 +213,14 @@ def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact:
201213
def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact:
202214
"""Converts a proto TaskArtifact to a SDK Artifact."""
203215
unpacked_meta = to_sdk_metadata(stored_artifact.metadata)
204-
part_metadatas = unpacked_meta.get('part_metadata') or []
205-
part_types = unpacked_meta.get('part_types') or []
216+
part_metadata_list = unpacked_meta.part_metadata or []
217+
part_types = unpacked_meta.part_types or []
206218

207219
parts = []
208220
for i, part in enumerate(stored_artifact.parts or []):
209221
meta: dict[str, Any] | None = None
210-
if i < len(part_metadatas):
211-
meta = part_metadatas[i]
222+
if i < len(part_metadata_list):
223+
meta = part_metadata_list[i]
212224
ptype = ''
213225
if i < len(part_types):
214226
ptype = part_types[i]
@@ -218,8 +230,8 @@ def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact:
218230
artifact_id=stored_artifact.artifact_id,
219231
name=stored_artifact.display_name,
220232
description=stored_artifact.description,
221-
extensions=unpacked_meta.get('extensions'),
222-
metadata=unpacked_meta.get('original_metadata'),
233+
extensions=unpacked_meta.extensions,
234+
metadata=unpacked_meta.original_metadata,
223235
parts=parts,
224236
)
225237

@@ -251,25 +263,27 @@ def to_sdk_message(
251263
if not stored_msg:
252264
return None
253265
unpacked_meta = to_sdk_metadata(stored_msg.metadata)
254-
part_metadatas = unpacked_meta.get('part_metadata') or []
255-
part_types = unpacked_meta.get('part_types') or []
266+
part_metadata_list = unpacked_meta.part_metadata or []
267+
part_types = unpacked_meta.part_types or []
256268

257269
parts = []
258270
for i, part in enumerate(stored_msg.parts or []):
259-
meta: dict[str, Any] | None = None
260-
if i < len(part_metadatas):
261-
meta = part_metadatas[i]
262-
ptype = ''
271+
part_metadata: dict[str, Any] | None = None
272+
if i < len(part_metadata_list):
273+
part_metadata = part_metadata_list[i]
274+
part_type = ''
263275
if i < len(part_types):
264-
ptype = part_types[i]
265-
parts.append(to_sdk_part(part, part_metadata=meta, part_type=ptype))
276+
part_type = part_types[i]
277+
parts.append(
278+
to_sdk_part(part, part_metadata=part_metadata, part_type=part_type)
279+
)
266280

267281
return Message(
268282
message_id=stored_msg.message_id,
269283
role=Role(stored_msg.role),
270-
extensions=unpacked_meta.get('extensions'),
271-
reference_task_ids=unpacked_meta.get('reference_tasks'),
272-
metadata=unpacked_meta.get('original_metadata'),
284+
extensions=unpacked_meta.extensions,
285+
reference_task_ids=unpacked_meta.reference_task_ids,
286+
metadata=unpacked_meta.original_metadata,
273287
parts=parts,
274288
)
275289

tests/contrib/tasks/test_vertex_task_converter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
TextPart,
3636
)
3737

38+
3839
def test_to_sdk_task_state() -> None:
3940
assert (
4041
to_sdk_task_state(vertexai_types.A2aTaskState.STATE_UNSPECIFIED)
@@ -318,7 +319,9 @@ def test_sdk_part_text_conversion_round_trip() -> None:
318319
def test_sdk_part_data_conversion_round_trip() -> None:
319320
sdk_part = Part(root=DataPart(data={'key': 'value'}))
320321
stored_part = to_stored_part(sdk_part)
321-
round_trip_sdk_part = to_sdk_part(stored_part, part_metadata=None, part_type='data')
322+
round_trip_sdk_part = to_sdk_part(
323+
stored_part, part_metadata=None, part_type='data'
324+
)
322325

323326
assert round_trip_sdk_part == sdk_part
324327

0 commit comments

Comments
 (0)