-
Notifications
You must be signed in to change notification settings - Fork 422
Expand file tree
/
Copy pathfake_vertex_client.py
More file actions
143 lines (120 loc) · 4.7 KB
/
fake_vertex_client.py
File metadata and controls
143 lines (120 loc) · 4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Fake Vertex AI Client implementations for testing."""
import copy
from google.genai import errors as genai_errors
from vertexai import types as vertexai_types
class FakeAgentEnginesA2aTasksEventsClient:
    """In-memory fake of the A2A task events sub-client.

    Appended events are applied to the task held in the parent client's
    ``tasks`` dict; nothing is sent over the network.
    """

    def __init__(self, parent_client):
        # The parent owns the ``tasks`` dict this client reads and writes.
        self.parent_client = parent_client

    async def append(
        self, name: str, task_events: list[vertexai_types.TaskEvent]
    ) -> None:
        """Apply ``task_events`` in order to the stored task ``name``.

        Args:
            name: Key of the task in the parent client's ``tasks`` dict.
            task_events: Events whose ``event_data`` may carry state, status
                details, metadata and/or output (artifact) changes.

        Raises:
            genai_errors.APIError: With code 404 when no task is stored
                under ``name``.
        """
        stored = self.parent_client.tasks.get(name)
        if not stored:
            raise genai_errors.APIError(
                code=404,
                response_json={
                    'error': {
                        'status': 'NOT_FOUND',
                        'message': 'Task not found',
                    }
                },
            )

        # Work on a copy and publish it only once every event has been
        # applied, so a failure midway leaves the stored task untouched.
        task = copy.deepcopy(stored)
        if not getattr(task, 'next_event_sequence_number', None):
            task.next_event_sequence_number = 0

        for event in task_events:
            self._apply_event(task, event.event_data)
            # Every appended event consumes a sequence number, including
            # events that turn out to change nothing on the task.
            task.next_event_sequence_number += 1

        self.parent_client.tasks[name] = task

    @staticmethod
    def _apply_event(task, data) -> None:
        """Apply the changes carried by one event's data to ``task``."""
        state_change = getattr(data, 'state_change', None)
        if state_change:
            task.state = getattr(state_change, 'new_state', task.state)

        status_details_change = getattr(data, 'status_details_change', None)
        if status_details_change:
            task.status_details = getattr(
                status_details_change,
                'new_task_status',
                getattr(task, 'status_details', None),
            )

        metadata_change = getattr(data, 'metadata_change', None)
        if metadata_change:
            task.metadata = getattr(
                metadata_change, 'new_metadata', task.metadata
            )

        output_change = getattr(data, 'output_change', None)
        if output_change:
            artifact_change = getattr(
                output_change, 'task_artifact_change', None
            )
            if artifact_change:
                FakeAgentEnginesA2aTasksEventsClient._apply_artifact_change(
                    task, artifact_change
                )

    @staticmethod
    def _apply_artifact_change(task, change) -> None:
        """Apply deletions, then additions, then updates to task artifacts."""
        if not getattr(task, 'output', None):
            task.output = vertexai_types.TaskOutput()

        artifacts = list(getattr(task.output, 'artifacts', None) or [])

        deleted_ids = getattr(change, 'deleted_artifact_ids', []) or []
        if deleted_ids:
            artifacts = [
                a for a in artifacts if a.artifact_id not in deleted_ids
            ]

        added = getattr(change, 'added_artifacts', []) or []
        if added:
            artifacts.extend(added)

        updated = getattr(change, 'updated_artifacts', []) or []
        if updated:
            # Replace in place by id; updates for ids that are not present
            # are ignored.
            replacements = {a.artifact_id: a for a in updated}
            artifacts = [
                replacements.get(a.artifact_id, a) for a in artifacts
            ]

        try:
            # Prefer mutating the existing container in place.
            del task.output.artifacts[:]
            task.output.artifacts.extend(artifacts)
        except Exception:
            # Deliberate best-effort fallback: the container does not
            # support in-place mutation (e.g. it is None), so rebind it.
            task.output.artifacts = artifacts
class FakeAgentEnginesA2aTasksClient:
    """In-memory fake of the A2A tasks client, keyed by full task name."""

    def __init__(self):
        # Stored tasks, keyed by '<name>/a2aTasks/<a2a_task_id>'.
        self.tasks: dict[str, vertexai_types.A2aTask] = {}
        self.events = FakeAgentEnginesA2aTasksEventsClient(self)

    async def create(
        self,
        name: str,
        a2a_task_id: str,
        config: vertexai_types.CreateAgentEngineTaskConfig,
    ) -> vertexai_types.A2aTask:
        """Create and store a task in the SUBMITTED state.

        Args:
            name: Parent resource name; the task is stored under
                ``f'{name}/a2aTasks/{a2a_task_id}'``.
            a2a_task_id: Identifier appended to ``name`` to form the key.
            config: Supplies the new task's ``context_id``, ``metadata``
                and ``output``.

        Returns:
            The newly created task. The fake keeps its own separate copy.
        """
        full_name = f'{name}/a2aTasks/{a2a_task_id}'
        task = vertexai_types.A2aTask(
            name=full_name,
            context_id=config.context_id,
            metadata=config.metadata,
            output=config.output,
            state=vertexai_types.State.SUBMITTED,
        )
        task.next_event_sequence_number = 1
        # Store a private copy so later mutation of the returned task (or of
        # objects shared with ``config``) cannot alter the stored state;
        # ``get`` hands out deep copies for the same reason.
        self.tasks[full_name] = copy.deepcopy(task)
        return task

    async def get(self, name: str) -> vertexai_types.A2aTask:
        """Return a deep copy of the stored task ``name``.

        Raises:
            genai_errors.APIError: With code 404 when ``name`` is not stored.
        """
        if name not in self.tasks:
            raise genai_errors.APIError(
                code=404,
                response_json={
                    'error': {
                        'status': 'NOT_FOUND',
                        'message': 'Task not found',
                    }
                },
            )
        return copy.deepcopy(self.tasks[name])
class FakeAgentEnginesClient:
    """Fake agent-engines client exposing an ``a2a_tasks`` sub-client."""

    def __init__(self) -> None:
        tasks_client = FakeAgentEnginesA2aTasksClient()
        self.a2a_tasks = tasks_client
class FakeAioClient:
    """Fake async client exposing an ``agent_engines`` sub-client."""

    def __init__(self) -> None:
        engines_client = FakeAgentEnginesClient()
        self.agent_engines = engines_client
class FakeVertexClient:
    """Top-level fake client; the async surface is reachable via ``aio``."""

    def __init__(self) -> None:
        aio_client = FakeAioClient()
        self.aio = aio_client