-
-
Notifications
You must be signed in to change notification settings - Fork 116
Expand file tree
/
Copy pathtest_progress_tracker.py
More file actions
122 lines (100 loc) · 3.17 KB
/
test_progress_tracker.py
File metadata and controls
122 lines (100 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Optional
import pytest
from pydantic import ValidationError
from taskiq import (
AsyncTaskiqDecoratedTask,
InMemoryBroker,
TaskiqDepends,
TaskiqMessage,
)
from taskiq.abc import AsyncBroker
from taskiq.depends.progress_tracker import ProgressTracker, TaskState
from taskiq.receiver import Receiver
def get_receiver(
    broker: Optional[AsyncBroker] = None,
    no_parse: bool = False,
    max_async_tasks: Optional[int] = None,
) -> Receiver:
    """
    Returns receiver with custom broker and args.

    :param broker: broker the receiver serves; a fresh ``InMemoryBroker``
        is created when None.
    :param no_parse: when True the receiver is built with
        ``validate_params=False``; defaults to False.
    :param max_async_tasks: forwarded unchanged to the receiver's
        ``max_async_tasks`` argument; defaults to None.
    :return: new receiver.
    """
    if broker is None:
        broker = InMemoryBroker()
    return Receiver(
        broker,
        # NOTE(review): this executor is created per call and never shut
        # down explicitly; acceptable for tests, but callers cannot reach it.
        executor=ThreadPoolExecutor(max_workers=10),
        validate_params=not no_parse,
        max_async_tasks=max_async_tasks,
    )
def get_message(
    task: AsyncTaskiqDecoratedTask[Any, Any],
    task_id: Optional[str] = None,
    *args: Any,
    labels: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> TaskiqMessage:
    """
    Build a ``TaskiqMessage`` addressed to ``task``.

    :param task: decorated task whose ``task_name`` is used and whose
        broker's ``id_generator`` supplies an id when none is given.
    :param task_id: explicit task id; when falsy (None or ""), a new id is
        generated via ``task.broker.id_generator()``.
    :param args: positional task arguments. Positional values bind to
        ``task_id`` first, so pass ``task_id`` (possibly None) explicitly
        before any task arguments.
    :param labels: message labels; an empty dict is used when None.
    :param kwargs: keyword task arguments, stored as-is on the message.
    :return: new message on the "taskiq" queue.
    """
    return TaskiqMessage(
        task_id=task_id or task.broker.id_generator(),
        task_name=task.task_name,
        labels={} if labels is None else labels,
        queue="taskiq",
        args=list(args),
        kwargs=kwargs,
    )
@pytest.mark.anyio
@pytest.mark.parametrize(
    ("state", "meta"),
    [
        (TaskState.STARTED, "hello world!"),
        ("retry", "retry error!"),
        ("custom state", {"Complex": "Value"}),
    ],
)
async def test_progress_tracker_ctx_raw(state: Any, meta: Any) -> None:
    """A task reports state/meta through an injected ProgressTracker and the
    result backend returns exactly those values afterwards."""
    broker = InMemoryBroker()

    @broker.task
    async def test_func(tracker: ProgressTracker[Any] = TaskiqDepends()) -> None:
        await tracker.set_progress(state, meta)

    task_handle = await test_func.kiq()
    outcome = await task_handle.wait_result()
    assert not outcome.is_err

    stored = await broker.result_backend.get_progress(task_handle.task_id)
    assert stored is not None
    assert stored.meta == meta
    assert stored.state == state
@pytest.mark.anyio
async def test_progress_tracker_ctx_none() -> None:
    """A task that never uses a ProgressTracker leaves no stored progress."""
    broker = InMemoryBroker()

    @broker.task
    async def test_func() -> None:
        """Do nothing; the task only needs to be kicked and awaited."""

    task_handle = await test_func.kiq()
    outcome = await task_handle.wait_result()
    assert not outcome.is_err
    assert await broker.result_backend.get_progress(task_handle.task_id) is None
@pytest.mark.anyio
@pytest.mark.parametrize(
    ("state", "meta"),
    [
        (("state", "error"), 1),
    ],
)
async def test_progress_tracker_validation_error(state: Any, meta: Any) -> None:
    """Reporting an invalid state makes the task fail with a pydantic
    ValidationError, and no progress ends up in the result backend."""
    broker = InMemoryBroker()

    @broker.task
    async def test_func(tracker: ProgressTracker[int] = TaskiqDepends()) -> None:
        await tracker.set_progress(state, meta)  # type: ignore

    task_handle = await test_func.kiq()
    outcome = await task_handle.wait_result()
    with pytest.raises(ValidationError):
        outcome.raise_for_error()

    stored = await broker.result_backend.get_progress(task_handle.task_id)
    assert stored is None