Skip to content

Commit 5820bfd

Browse files
committed
Added intial group support.
1 parent ef097a5 commit 5820bfd

9 files changed

Lines changed: 499 additions & 238 deletions

File tree

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ repos:
3131
args:
3232
- "check"
3333
- "--fix"
34-
- "."
34+
- "taskiq_pipelines"
35+
- "tests"
3536

3637
- id: mypy
3738
name: Validate types with MyPy

poetry.lock

Lines changed: 253 additions & 235 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

taskiq_pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from taskiq_pipelines.exceptions import AbortPipeline, PipelineError
44
from taskiq_pipelines.middleware import PipelineMiddleware
55
from taskiq_pipelines.pipeliner import Pipeline
6+
from taskiq_pipelines.task_group import Group
67

78
__all__ = [
89
"AbortPipeline",
10+
"Group",
911
"Pipeline",
1012
"PipelineError",
1113
"PipelineMiddleware",

taskiq_pipelines/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22

33
CURRENT_STEP = "_pipe_current_step"
44
PIPELINE_DATA = "_pipe_data"
5+
PARENT_TASK_ID = "_parent_task_id"
56

67
EMPTY_PARAM_NAME: Literal[-1] = -1

taskiq_pipelines/pipeliner.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
from taskiq_pipelines.constants import CURRENT_STEP, EMPTY_PARAM_NAME, PIPELINE_DATA
2222
from taskiq_pipelines.steps import FilterStep, MapperStep, SequentialStep, parse_step
23+
from taskiq_pipelines.steps.group import GroupStep
24+
from taskiq_pipelines.task_group import Group
2325

2426
_ReturnType = TypeVar("_ReturnType")
2527
_FuncParams = ParamSpec("_FuncParams")
@@ -325,6 +327,31 @@ def filter(
325327
)
326328
return self
327329

330+
def group(
331+
self: "Pipeline[_FuncParams, _ReturnType]",
332+
group: Group[_T2],
333+
) -> "Pipeline[_FuncParams, _T2]":
334+
"""
335+
Add group task execution step.
336+
337+
This step will run all tasks in parallel
338+
and will wait for all of them to finish.
339+
340+
Results of all tasks will be returned as an iterable
341+
where each item is a result of the task in the group
342+
with the same order.
343+
344+
:param group: group to execute.
345+
"""
346+
self.steps.append(
347+
DumpedStep(
348+
step_type=GroupStep._step_name,
349+
step_data=group.to_step().model_dump(),
350+
task_id="",
351+
),
352+
)
353+
return self # type: ignore
354+
328355
def dumpb(self) -> bytes:
329356
"""
330357
Dumps current pipeline as string.

taskiq_pipelines/steps/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from taskiq_pipelines.abc import AbstractStep
77
from taskiq_pipelines.steps.filter import FilterStep
8+
from taskiq_pipelines.steps.group import GroupStep
89
from taskiq_pipelines.steps.mapper import MapperStep
910
from taskiq_pipelines.steps.sequential import SequentialStep
1011

@@ -21,6 +22,7 @@ def parse_step(step_type: str, step_data: Dict[str, Any]) -> AbstractStep:
2122

2223
__all__ = [
2324
"FilterStep",
25+
"GroupStep",
2426
"MapperStep",
2527
"SequentialStep",
2628
]

taskiq_pipelines/steps/group.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from typing import Any, Dict, List, Optional
2+
3+
import pydantic
4+
from taskiq import (
5+
AsyncBroker,
6+
Context,
7+
TaskiqDepends,
8+
TaskiqMessage,
9+
TaskiqResult,
10+
async_shared_broker,
11+
)
12+
13+
from taskiq_pipelines.abc import AbstractStep
14+
from taskiq_pipelines.constants import CURRENT_STEP, PIPELINE_DATA
15+
from taskiq_pipelines.steps.mapper import wait_tasks
16+
17+
18+
@async_shared_broker.task(task_name="taskiq_pipelines.shared.wait_group_tasks")
19+
async def wait_group_tasks(
20+
task_ids: List[str],
21+
check_interval: float,
22+
skip_errors: bool = True,
23+
context: Context = TaskiqDepends(),
24+
) -> tuple[Any, ...]:
25+
"""Waits for subtasks to complete."""
26+
res = await wait_tasks(
27+
task_ids,
28+
check_interval=check_interval,
29+
skip_errors=False,
30+
none_if_errors=skip_errors,
31+
context=context,
32+
)
33+
return tuple(res)
34+
35+
36+
class GroupStepItem(pydantic.BaseModel):
37+
"""Item of a group step."""
38+
39+
task_name: str
40+
labels: Dict[str, Any]
41+
labels_types: Optional[Dict[str, int]] = None
42+
args: List[Any]
43+
kwargs: Dict[str, Any]
44+
45+
def from_message(self, message: TaskiqMessage) -> None:
46+
"""
47+
Parse labels and kwargs from message.
48+
49+
:param message: message to parse.
50+
"""
51+
self.labels = message.labels
52+
self.labels_types = message.labels_types
53+
self.args = message.args
54+
self.kwargs = message.kwargs
55+
56+
def to_message(self, task_id: str) -> TaskiqMessage:
57+
"""
58+
Convert this item to message.
59+
60+
:return: message
61+
"""
62+
return TaskiqMessage(
63+
task_id=task_id,
64+
task_name=self.task_name,
65+
labels=self.labels,
66+
labels_types=self.labels_types,
67+
args=self.args,
68+
kwargs=self.kwargs,
69+
)
70+
71+
72+
class GroupStep(pydantic.BaseModel, AbstractStep, step_name="group"):
73+
"""Step that maps iterables."""
74+
75+
tasks: list[GroupStepItem]
76+
skip_errors: bool
77+
check_interval: float
78+
79+
async def act(
80+
self,
81+
broker: AsyncBroker,
82+
step_number: int,
83+
parent_task_id: str,
84+
task_id: str,
85+
pipe_data: str,
86+
result: "TaskiqResult[Any]",
87+
) -> None:
88+
"""
89+
Execute group action.
90+
91+
This steps creates many small tasks
92+
and one waiter task.
93+
94+
The waiter task awaits for all small tasks to complete,
95+
and then assembles the final result.
96+
"""
97+
ids: List[str] = []
98+
for task in self.tasks:
99+
subtask_id = broker.id_generator()
100+
ids.append(subtask_id)
101+
await broker.kick(broker.formatter.dumps(task.to_message(subtask_id)))
102+
103+
await (
104+
wait_group_tasks.kicker()
105+
.with_broker(broker)
106+
.with_task_id(task_id)
107+
.with_labels(
108+
**{CURRENT_STEP: step_number, PIPELINE_DATA: pipe_data}, # type: ignore
109+
)
110+
.kiq(
111+
task_ids=ids,
112+
skip_errors=True,
113+
check_interval=self.check_interval,
114+
)
115+
)

taskiq_pipelines/steps/mapper.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from logging import getLogger
23
from typing import Any, Dict, Iterable, List, Optional, Union
34

45
import pydantic
@@ -16,12 +17,15 @@
1617
from taskiq_pipelines.constants import CURRENT_STEP, PIPELINE_DATA
1718
from taskiq_pipelines.exceptions import AbortPipeline, MappingError
1819

20+
logger = getLogger("taskiq_pipelines")
21+
1922

2023
@async_shared_broker.task(task_name="taskiq_pipelines.shared.wait_tasks")
21-
async def wait_tasks(
24+
async def wait_tasks( # noqa: C901
2225
task_ids: List[str],
2326
check_interval: float,
2427
skip_errors: bool = True,
28+
none_if_errors: bool = False,
2529
context: Context = TaskiqDepends(),
2630
) -> List[Any]:
2731
"""
@@ -53,12 +57,16 @@ async def wait_tasks(
5357
if tasks_set:
5458
await asyncio.sleep(check_interval)
5559

56-
results = []
60+
results: List[Any] = []
5761
for task_id in ordered_ids:
5862
result = await context.broker.result_backend.get_result(task_id)
63+
logger.warning("Found error: %s", result.error)
5964
if result.is_err:
6065
if skip_errors:
6166
continue
67+
if none_if_errors:
68+
results.append(None)
69+
continue
6270
err_cause = None
6371
if isinstance(result.error, BaseException):
6472
err_cause = result.error
@@ -137,6 +145,7 @@ async def act(
137145
sub_task_ids,
138146
check_interval=self.check_interval,
139147
skip_errors=self.skip_errors,
148+
none_if_errors=False,
140149
)
141150
)
142151

taskiq_pipelines/task_group.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from types import CoroutineType
2+
from typing import Any, Coroutine, Generic, Tuple, Union, overload
3+
4+
from taskiq import AsyncTaskiqDecoratedTask
5+
from taskiq.kicker import AsyncKicker
6+
from typing_extensions import ParamSpec, TypeVar, TypeVarTuple, Unpack
7+
8+
from taskiq_pipelines.steps.group import GroupStep, GroupStepItem
9+
10+
_Tups = TypeVarTuple("_Tups")
11+
_T = TypeVar("_T")
12+
_TVal = TypeVar("_TVal")
13+
_Params = ParamSpec("_Params")
14+
15+
16+
class Group(Generic[_T]):
17+
"""
18+
Group of tasks.
19+
20+
This class gathers multiple tasks together.
21+
They will run in parallel
22+
23+
:param skip_errors: If True, errors in one task will not affect others.
24+
"""
25+
26+
def __init__(
27+
self: "Group[Tuple[()]]",
28+
skip_errors: bool = False,
29+
check_interval: float = 0.1,
30+
) -> None:
31+
self.tasks: Tuple[GroupStepItem, ...] = ()
32+
self.skip_errors = skip_errors
33+
self.check_interval = check_interval
34+
35+
@overload
36+
def add(
37+
self: "Group[Tuple[Unpack[_Tups]]]",
38+
task: Union[
39+
AsyncKicker[_Params, Coroutine[Any, Any, _TVal]],
40+
AsyncKicker[_Params, "CoroutineType[Any, Any, _TVal]"],
41+
AsyncTaskiqDecoratedTask[_Params, Coroutine[Any, Any, _TVal]],
42+
AsyncTaskiqDecoratedTask[_Params, "CoroutineType[Any, Any, _TVal]"],
43+
],
44+
*args: _Params.args,
45+
**kwargs: _Params.kwargs,
46+
) -> "Group[Tuple[Unpack[_Tups], _TVal]]": ...
47+
48+
@overload
49+
def add(
50+
self: "Group[Tuple[Unpack[_Tups]]]",
51+
task: Union[
52+
AsyncKicker[_Params, _TVal],
53+
AsyncTaskiqDecoratedTask[_Params, _TVal],
54+
],
55+
*args: _Params.args,
56+
**kwargs: _Params.kwargs,
57+
) -> "Group[Tuple[Unpack[_Tups], _TVal]]": ...
58+
59+
def add(
60+
self: "Group[Any]",
61+
task: Union[AsyncKicker[_Params, Any], AsyncTaskiqDecoratedTask[_Params, Any]],
62+
*args: _Params.args,
63+
**kwargs: _Params.kwargs,
64+
) -> "Any":
65+
"""Add task to a group."""
66+
kicker = task.kicker() if isinstance(task, AsyncTaskiqDecoratedTask) else task
67+
message = kicker._prepare_message(*args, **kwargs)
68+
self.tasks = (
69+
*self.tasks,
70+
GroupStepItem(
71+
task_name=message.task_name,
72+
labels=message.labels,
73+
labels_types=message.labels_types,
74+
args=message.args,
75+
kwargs=message.kwargs,
76+
),
77+
)
78+
return self
79+
80+
def to_step(self) -> GroupStep:
81+
"""Convert group definition to a step."""
82+
return GroupStep(
83+
tasks=list(self.tasks),
84+
skip_errors=self.skip_errors,
85+
check_interval=self.check_interval,
86+
)

0 commit comments

Comments
 (0)