1111import base64
1212import json
1313
14+ from dataclasses import dataclass
1415from typing import Any
1516
1617from a2a .types import (
@@ -92,22 +93,33 @@ def to_stored_metadata(
9293 return metadata
9394
9495
95- def to_sdk_metadata (stored_metadata : dict [str , Any ] | None ) -> dict [str , Any ]:
96+ @dataclass
97+ class _UnpackedMetadata :
98+ original_metadata : dict [str , Any ] | None = None
99+ extensions : list [str ] | None = None
100+ reference_task_ids : list [str ] | None = None
101+ part_metadata : list [dict [str , Any ] | None ] | None = None
102+ part_types : list [str ] | None = None
103+
104+
105+ def to_sdk_metadata (
106+ stored_metadata : dict [str , Any ] | None ,
107+ ) -> _UnpackedMetadata :
96108 """Unpacks metadata, extensions, and part types/metadata from a storage dictionary."""
97109 if not stored_metadata :
98- return {}
110+ return _UnpackedMetadata ()
99111
100112 version = stored_metadata .get (_METADATA_VERSION_KEY )
101113 if version is None :
102- return { ' original_metadata' : stored_metadata }
103-
104- return {
105- ' original_metadata' : stored_metadata .get (_ORIGINAL_METADATA_KEY ),
106- ' extensions' : stored_metadata .get (_EXTENSIONS_KEY ),
107- 'reference_tasks' : stored_metadata .get (_REFERENCE_TASK_IDS_KEY ),
108- ' part_metadata' : stored_metadata .get (_PART_METADATA_KEY ),
109- ' part_types' : stored_metadata .get (_PART_TYPES_KEY ),
110- }
114+ return _UnpackedMetadata ( original_metadata = stored_metadata )
115+
116+ return _UnpackedMetadata (
117+ original_metadata = stored_metadata .get (_ORIGINAL_METADATA_KEY ),
118+ extensions = stored_metadata .get (_EXTENSIONS_KEY ),
119+ reference_task_ids = stored_metadata .get (_REFERENCE_TASK_IDS_KEY ),
120+ part_metadata = stored_metadata .get (_PART_METADATA_KEY ),
121+ part_types = stored_metadata .get (_PART_TYPES_KEY ),
122+ )
111123
112124
113125def to_stored_part (part : Part ) -> genai_types .Part :
@@ -173,7 +185,7 @@ def to_sdk_part(
173185 root = FilePart (
174186 file = FileWithUri (
175187 mime_type = stored_part .file_data .mime_type ,
176- uri = stored_part .file_data .file_uri ,
188+ uri = stored_part .file_data .file_uri or '' ,
177189 ),
178190 metadata = part_metadata ,
179191 )
@@ -201,14 +213,14 @@ def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact:
201213def to_sdk_artifact (stored_artifact : vertexai_types .TaskArtifact ) -> Artifact :
202214 """Converts a proto TaskArtifact to a SDK Artifact."""
203215 unpacked_meta = to_sdk_metadata (stored_artifact .metadata )
204- part_metadatas = unpacked_meta .get ( ' part_metadata' ) or []
205- part_types = unpacked_meta .get ( ' part_types' ) or []
216+ part_metadata_list = unpacked_meta .part_metadata or []
217+ part_types = unpacked_meta .part_types or []
206218
207219 parts = []
208220 for i , part in enumerate (stored_artifact .parts or []):
209221 meta : dict [str , Any ] | None = None
210- if i < len (part_metadatas ):
211- meta = part_metadatas [i ]
222+ if i < len (part_metadata_list ):
223+ meta = part_metadata_list [i ]
212224 ptype = ''
213225 if i < len (part_types ):
214226 ptype = part_types [i ]
@@ -218,8 +230,8 @@ def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact:
218230 artifact_id = stored_artifact .artifact_id ,
219231 name = stored_artifact .display_name ,
220232 description = stored_artifact .description ,
221- extensions = unpacked_meta .get ( ' extensions' ) ,
222- metadata = unpacked_meta .get ( ' original_metadata' ) ,
233+ extensions = unpacked_meta .extensions ,
234+ metadata = unpacked_meta .original_metadata ,
223235 parts = parts ,
224236 )
225237
@@ -251,25 +263,27 @@ def to_sdk_message(
251263 if not stored_msg :
252264 return None
253265 unpacked_meta = to_sdk_metadata (stored_msg .metadata )
254- part_metadatas = unpacked_meta .get ( ' part_metadata' ) or []
255- part_types = unpacked_meta .get ( ' part_types' ) or []
266+ part_metadata_list = unpacked_meta .part_metadata or []
267+ part_types = unpacked_meta .part_types or []
256268
257269 parts = []
258270 for i , part in enumerate (stored_msg .parts or []):
259- meta : dict [str , Any ] | None = None
260- if i < len (part_metadatas ):
261- meta = part_metadatas [i ]
262- ptype = ''
271+ part_metadata : dict [str , Any ] | None = None
272+ if i < len (part_metadata_list ):
273+ part_metadata = part_metadata_list [i ]
274+ part_type = ''
263275 if i < len (part_types ):
264- ptype = part_types [i ]
265- parts .append (to_sdk_part (part , part_metadata = meta , part_type = ptype ))
276+ part_type = part_types [i ]
277+ parts .append (
278+ to_sdk_part (part , part_metadata = part_metadata , part_type = part_type )
279+ )
266280
267281 return Message (
268282 message_id = stored_msg .message_id ,
269283 role = Role (stored_msg .role ),
270- extensions = unpacked_meta .get ( ' extensions' ) ,
271- reference_task_ids = unpacked_meta .get ( 'reference_tasks' ) ,
272- metadata = unpacked_meta .get ( ' original_metadata' ) ,
284+ extensions = unpacked_meta .extensions ,
285+ reference_task_ids = unpacked_meta .reference_task_ids ,
286+ metadata = unpacked_meta .original_metadata ,
273287 parts = parts ,
274288 )
275289
0 commit comments