|
1 | | -"""Tests for vertex_task_converter mappings.""" |
2 | | - |
3 | 1 | import base64 |
4 | 2 |
|
5 | 3 | import pytest |
|
8 | 6 | pytest.importorskip( |
9 | 7 | 'vertexai', reason='Vertex Task Converter tests require vertexai' |
10 | 8 | ) |
11 | | -from google.genai import types as genai_types |
12 | 9 | from vertexai import types as vertexai_types |
13 | | - |
| 10 | +from google.genai import types as genai_types |
14 | 11 | from a2a.contrib.tasks.vertex_task_converter import ( |
15 | 12 | to_sdk_artifact, |
16 | | - to_sdk_message, |
17 | 13 | to_sdk_part, |
18 | 14 | to_sdk_task, |
19 | 15 | to_sdk_task_state, |
20 | 16 | to_stored_artifact, |
21 | | - to_stored_message, |
22 | 17 | to_stored_part, |
23 | 18 | to_stored_task, |
24 | 19 | to_stored_task_state, |
|
29 | 24 | FilePart, |
30 | 25 | FileWithBytes, |
31 | 26 | FileWithUri, |
32 | | - Message, |
33 | 27 | Part, |
34 | | - Role, |
35 | 28 | Task, |
36 | 29 | TaskState, |
37 | 30 | TaskStatus, |
38 | 31 | TextPart, |
39 | 32 | ) |
40 | 33 |
|
41 | 34 |
|
42 | | -def test_artifact_conversion_symmetry() -> None: |
43 | | - """Test converting an Artifact to TaskArtifact and back restores everything.""" |
44 | | - original_artifact = Artifact( |
45 | | - artifact_id='art123', |
46 | | - name='My cool artifact', |
47 | | - description='A very interesting description', |
48 | | - extensions=['ext1', 'ext2'], |
49 | | - metadata={'custom': 'value'}, |
50 | | - parts=[ |
51 | | - Part( |
52 | | - root=TextPart( |
53 | | - text='hello', metadata={'part_meta': 'hello_meta'} |
54 | | - ) |
55 | | - ), |
56 | | - Part(root=DataPart(data={'foo': 'bar'})), # no metadata |
57 | | - ], |
| 35 | +def test_to_sdk_task_state() -> None: |
| 36 | + assert ( |
| 37 | + to_sdk_task_state(vertexai_types.A2aTaskState.STATE_UNSPECIFIED) |
| 38 | + == TaskState.unknown |
| 39 | + ) |
| 40 | + assert ( |
| 41 | + to_sdk_task_state(vertexai_types.A2aTaskState.SUBMITTED) |
| 42 | + == TaskState.submitted |
| 43 | + ) |
| 44 | + assert ( |
| 45 | + to_sdk_task_state(vertexai_types.A2aTaskState.WORKING) |
| 46 | + == TaskState.working |
| 47 | + ) |
| 48 | + assert ( |
| 49 | + to_sdk_task_state(vertexai_types.A2aTaskState.COMPLETED) |
| 50 | + == TaskState.completed |
| 51 | + ) |
| 52 | + assert ( |
| 53 | + to_sdk_task_state(vertexai_types.A2aTaskState.CANCELLED) |
| 54 | + == TaskState.canceled |
| 55 | + ) |
| 56 | + assert ( |
| 57 | + to_sdk_task_state(vertexai_types.A2aTaskState.FAILED) |
| 58 | + == TaskState.failed |
| 59 | + ) |
| 60 | + assert ( |
| 61 | + to_sdk_task_state(vertexai_types.A2aTaskState.REJECTED) |
| 62 | + == TaskState.rejected |
| 63 | + ) |
| 64 | + assert ( |
| 65 | + to_sdk_task_state(vertexai_types.A2aTaskState.INPUT_REQUIRED) |
| 66 | + == TaskState.input_required |
| 67 | + ) |
| 68 | + assert ( |
| 69 | + to_sdk_task_state(vertexai_types.A2aTaskState.AUTH_REQUIRED) |
| 70 | + == TaskState.auth_required |
58 | 71 | ) |
| 72 | + assert to_sdk_task_state(999) == TaskState.unknown # type: ignore |
59 | 73 |
|
60 | | - stored = to_stored_artifact(original_artifact) |
61 | | - assert isinstance(stored, vertexai_types.TaskArtifact) |
62 | | - |
63 | | - # ensure it was populated correctly |
64 | | - assert stored.display_name == 'My cool artifact' |
65 | | - assert stored.description == 'A very interesting description' |
66 | | - assert stored.metadata['__vertex_compat_v'] == 1.0 |
67 | | - |
68 | | - restored_artifact = to_sdk_artifact(stored) |
69 | | - |
70 | | - assert restored_artifact.artifact_id == original_artifact.artifact_id |
71 | | - assert restored_artifact.name == original_artifact.name |
72 | | - assert restored_artifact.description == original_artifact.description |
73 | | - assert restored_artifact.extensions == original_artifact.extensions |
74 | | - assert restored_artifact.metadata == original_artifact.metadata |
75 | | - |
76 | | - assert len(restored_artifact.parts) == 2 |
77 | | - assert isinstance(restored_artifact.parts[0].root, TextPart) |
78 | | - assert restored_artifact.parts[0].root.text == 'hello' |
79 | | - assert restored_artifact.parts[0].root.metadata == { |
80 | | - 'part_meta': 'hello_meta' |
81 | | - } |
82 | | - |
83 | | - assert isinstance(restored_artifact.parts[1].root, DataPart) |
84 | | - assert restored_artifact.parts[1].root.data == {'foo': 'bar'} |
85 | | - assert restored_artifact.parts[1].root.metadata is None |
86 | | - |
87 | | - |
88 | | -def test_message_conversion_symmetry() -> None: |
89 | | - """Test converting a Message to TaskMessage and back restores everything.""" |
90 | | - original_message = Message( |
91 | | - message_id='msg456', |
92 | | - role=Role.agent, |
93 | | - context_id='ctx1', |
94 | | - task_id='tsk1', |
95 | | - reference_task_ids=['tsk2', 'tsk3'], |
96 | | - extensions=['ext_msg'], |
97 | | - metadata={'msg_meta': 42}, |
98 | | - parts=[ |
99 | | - Part(root=TextPart(text='message text')), |
100 | | - ], |
| 74 | + |
| 75 | +def test_to_stored_task_state() -> None: |
| 76 | + assert ( |
| 77 | + to_stored_task_state(TaskState.unknown) |
| 78 | + == vertexai_types.A2aTaskState.STATE_UNSPECIFIED |
| 79 | + ) |
| 80 | + assert ( |
| 81 | + to_stored_task_state(TaskState.submitted) |
| 82 | + == vertexai_types.A2aTaskState.SUBMITTED |
| 83 | + ) |
| 84 | + assert ( |
| 85 | + to_stored_task_state(TaskState.working) |
| 86 | + == vertexai_types.A2aTaskState.WORKING |
| 87 | + ) |
| 88 | + assert ( |
| 89 | + to_stored_task_state(TaskState.completed) |
| 90 | + == vertexai_types.A2aTaskState.COMPLETED |
| 91 | + ) |
| 92 | + assert ( |
| 93 | + to_stored_task_state(TaskState.canceled) |
| 94 | + == vertexai_types.A2aTaskState.CANCELLED |
| 95 | + ) |
| 96 | + assert ( |
| 97 | + to_stored_task_state(TaskState.failed) |
| 98 | + == vertexai_types.A2aTaskState.FAILED |
| 99 | + ) |
| 100 | + assert ( |
| 101 | + to_stored_task_state(TaskState.rejected) |
| 102 | + == vertexai_types.A2aTaskState.REJECTED |
| 103 | + ) |
| 104 | + assert ( |
| 105 | + to_stored_task_state(TaskState.input_required) |
| 106 | + == vertexai_types.A2aTaskState.INPUT_REQUIRED |
| 107 | + ) |
| 108 | + assert ( |
| 109 | + to_stored_task_state(TaskState.auth_required) |
| 110 | + == vertexai_types.A2aTaskState.AUTH_REQUIRED |
101 | 111 | ) |
102 | 112 |
|
103 | | - stored = to_stored_message(original_message) |
104 | | - assert stored is not None |
105 | | - assert isinstance(stored, vertexai_types.TaskMessage) |
106 | 113 |
|
107 | | - assert stored.message_id == 'msg456' |
108 | | - assert stored.role == 'agent' |
109 | | - assert stored.metadata['__vertex_compat_v'] == 1.0 |
| 114 | +def test_to_stored_part_text() -> None: |
| 115 | + sdk_part = Part(root=TextPart(text='hello world')) |
| 116 | + stored_part = to_stored_part(sdk_part) |
| 117 | + assert stored_part.text == 'hello world' |
| 118 | + assert not stored_part.inline_data |
| 119 | + assert not stored_part.file_data |
110 | 120 |
|
111 | | - restored_message = to_sdk_message(stored) |
112 | | - assert restored_message is not None |
113 | 121 |
|
114 | | - assert restored_message.message_id == original_message.message_id |
115 | | - assert restored_message.role == original_message.role |
116 | | - # context_id and task_id are not serialized via Message metadata in Go implementation but via Task, |
117 | | - # but reference_task_ids and extensions ARE part of Message metadata. |
118 | | - assert ( |
119 | | - restored_message.reference_task_ids |
120 | | - == original_message.reference_task_ids |
121 | | - ) |
122 | | - assert restored_message.extensions == original_message.extensions |
123 | | - assert restored_message.metadata == original_message.metadata |
| 122 | +def test_to_stored_part_data() -> None: |
| 123 | + sdk_part = Part(root=DataPart(data={'key': 'value'})) |
| 124 | + stored_part = to_stored_part(sdk_part) |
| 125 | + assert stored_part.inline_data is not None |
| 126 | + assert stored_part.inline_data.mime_type == 'application/json' |
| 127 | + assert stored_part.inline_data.data == b'{"key": "value"}' |
124 | 128 |
|
125 | | - assert len(restored_message.parts) == 1 |
126 | | - assert isinstance(restored_message.parts[0].root, TextPart) |
127 | | - assert restored_message.parts[0].root.text == 'message text' |
128 | | - assert restored_message.parts[0].root.metadata is None |
129 | 129 |
|
| 130 | +def test_to_stored_part_file_bytes() -> None: |
| 131 | + encoded_b64 = base64.b64encode(b'test data').decode('utf-8') |
| 132 | + sdk_part = Part( |
| 133 | + root=FilePart( |
| 134 | + file=FileWithBytes( |
| 135 | + bytes=encoded_b64, |
| 136 | + mime_type='text/plain', |
| 137 | + ) |
| 138 | + ) |
| 139 | + ) |
| 140 | + stored_part = to_stored_part(sdk_part) |
| 141 | + assert stored_part.inline_data is not None |
| 142 | + assert stored_part.inline_data.mime_type == 'text/plain' |
| 143 | + assert stored_part.inline_data.data == b'test data' |
130 | 144 |
|
131 | | -def test_to_stored_part_unsupported() -> None: |
132 | | - part = Part.model_construct( |
133 | | - root=Task( # type: ignore[arg-type] |
134 | | - id='invalid-part', |
135 | | - context_id='ctx', |
136 | | - status=TaskStatus(state=TaskState.submitted), |
137 | | - history=[], |
| 145 | + |
| 146 | +def test_to_stored_part_file_uri() -> None: |
| 147 | + sdk_part = Part( |
| 148 | + root=FilePart( |
| 149 | + file=FileWithUri( |
| 150 | + uri='gs://test-bucket/file.txt', |
| 151 | + mime_type='text/plain', |
| 152 | + ) |
138 | 153 | ) |
139 | 154 | ) |
| 155 | + stored_part = to_stored_part(sdk_part) |
| 156 | + assert stored_part.file_data is not None |
| 157 | + assert stored_part.file_data.mime_type == 'text/plain' |
| 158 | + assert stored_part.file_data.file_uri == 'gs://test-bucket/file.txt' |
| 159 | + |
| 160 | + |
| 161 | +def test_to_stored_part_unsupported() -> None: |
| 162 | + class BadPart: |
| 163 | + pass |
| 164 | + |
| 165 | + part = Part(root=TextPart(text='t')) |
| 166 | + part.root = BadPart() # type: ignore |
140 | 167 | with pytest.raises(ValueError, match='Unsupported part type'): |
141 | 168 | to_stored_part(part) |
142 | 169 |
|
|
0 commit comments