Skip to content

Commit a1e8a42

Browse files
committed
Add tag/untag endpoints for tasks, flows, and runs
- Add POST /tasks/tag, /tasks/untag, /flows/tag, /flows/untag, /runs/tag, /runs/untag - Extract shared database helpers (database/tagging.py) to avoid SQL duplication - Extract shared router logic (core/tagging.py) with tag_entity/untag_entity helpers - Use new error classes (TagAlreadyExistsError, TagNotFoundError, TagNotOwnedError) - Use fetch_user_or_raise dependency for auth, matching existing patterns - Add database/runs.py with get, get_tags, tag, get_tag, delete_tag - Register runs router in main.py - 33 tests covering auth, tagging, untagging, duplicates, ownership
1 parent 5357b01 commit a1e8a42

12 files changed

Lines changed: 781 additions & 29 deletions

File tree

src/core/tagging.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from collections.abc import Awaitable, Callable
2+
from typing import Any
3+
4+
from sqlalchemy import Row
5+
from sqlalchemy.ext.asyncio import AsyncConnection
6+
7+
from core.errors import TagAlreadyExistsError, TagNotFoundError, TagNotOwnedError
8+
from database.users import User, UserGroup
9+
10+
11+
async def tag_entity(
12+
entity_id: int,
13+
tag: str,
14+
user: User,
15+
expdb: AsyncConnection,
16+
*,
17+
get_tags_fn: Callable[[int, AsyncConnection], Awaitable[list[str]]],
18+
tag_fn: Callable[..., Awaitable[None]],
19+
response_key: str,
20+
) -> dict[str, dict[str, Any]]:
21+
tags = await get_tags_fn(entity_id, expdb)
22+
if tag.casefold() in (t.casefold() for t in tags):
23+
msg = f"Entity {entity_id} already tagged with {tag!r}."
24+
raise TagAlreadyExistsError(msg)
25+
await tag_fn(entity_id, tag, user_id=user.user_id, expdb=expdb)
26+
tags = await get_tags_fn(entity_id, expdb)
27+
return {response_key: {"id": str(entity_id), "tag": tags}}
28+
29+
30+
async def untag_entity(
31+
entity_id: int,
32+
tag: str,
33+
user: User,
34+
expdb: AsyncConnection,
35+
*,
36+
get_tag_fn: Callable[[int, str, AsyncConnection], Awaitable[Row | None]],
37+
delete_tag_fn: Callable[[int, str, AsyncConnection], Awaitable[None]],
38+
get_tags_fn: Callable[[int, AsyncConnection], Awaitable[list[str]]],
39+
response_key: str,
40+
) -> dict[str, dict[str, Any]]:
41+
existing = await get_tag_fn(entity_id, tag, expdb)
42+
if existing is None:
43+
msg = f"Tag {tag!r} not found on entity {entity_id}."
44+
raise TagNotFoundError(msg)
45+
groups = await user.get_groups()
46+
if existing.uploader != user.user_id and UserGroup.ADMIN not in groups:
47+
msg = f"Tag {tag!r} on entity {entity_id} is not owned by you."
48+
raise TagNotOwnedError(msg)
49+
await delete_tag_fn(entity_id, tag, expdb)
50+
tags = await get_tags_fn(entity_id, expdb)
51+
return {response_key: {"id": str(entity_id), "tag": tags}}

src/database/flows.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
from sqlalchemy import Row, text
55
from sqlalchemy.ext.asyncio import AsyncConnection
66

7+
from database.tagging import insert_tag, remove_tag, select_tag, select_tags
8+
9+
_TABLE = "implementation_tag"
10+
_ID_COLUMN = "id"
11+
712

813
async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
914
rows = await expdb.execute(
@@ -23,18 +28,7 @@ async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
2328

2429

2530
async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]:
26-
rows = await expdb.execute(
27-
text(
28-
"""
29-
SELECT tag
30-
FROM implementation_tag
31-
WHERE id = :flow_id
32-
""",
33-
),
34-
parameters={"flow_id": flow_id},
35-
)
36-
tag_rows = rows.all()
37-
return [tag.tag for tag in tag_rows]
31+
return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=flow_id, expdb=expdb)
3832

3933

4034
async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
@@ -54,6 +48,20 @@ async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
5448
)
5549

5650

51+
async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None:
52+
await insert_tag(
53+
table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, user_id=user_id, expdb=expdb,
54+
)
55+
56+
57+
async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None:
58+
return await select_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)
59+
60+
61+
async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None:
62+
await remove_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)
63+
64+
5765
async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None:
5866
"""Get flow by name and external version."""
5967
row = await expdb.execute(

src/database/runs.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from sqlalchemy import Row, text
2+
from sqlalchemy.ext.asyncio import AsyncConnection
3+
4+
from database.tagging import insert_tag, remove_tag, select_tag, select_tags
5+
6+
_TABLE = "run_tag"
7+
_ID_COLUMN = "id"
8+
9+
10+
async def get(id_: int, expdb: AsyncConnection) -> Row | None:
11+
row = await expdb.execute(
12+
text(
13+
"""
14+
SELECT *
15+
FROM run
16+
WHERE `id` = :run_id
17+
""",
18+
),
19+
parameters={"run_id": id_},
20+
)
21+
return row.one_or_none()
22+
23+
24+
async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]:
25+
return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=id_, expdb=expdb)
26+
27+
28+
async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None:
29+
await insert_tag(
30+
table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, user_id=user_id, expdb=expdb,
31+
)
32+
33+
34+
async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None:
35+
return await select_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)
36+
37+
38+
async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None:
39+
await remove_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)

src/database/tagging.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from sqlalchemy import Row, text
2+
from sqlalchemy.ext.asyncio import AsyncConnection
3+
4+
5+
async def insert_tag(
6+
*,
7+
table: str,
8+
id_column: str,
9+
id_: int,
10+
tag_: str,
11+
user_id: int,
12+
expdb: AsyncConnection,
13+
) -> None:
14+
await expdb.execute(
15+
text(
16+
f"""
17+
INSERT INTO {table}(`{id_column}`, `tag`, `uploader`)
18+
VALUES (:id, :tag, :user_id)
19+
""",
20+
),
21+
parameters={"id": id_, "tag": tag_, "user_id": user_id},
22+
)
23+
24+
25+
async def select_tag(
26+
*,
27+
table: str,
28+
id_column: str,
29+
id_: int,
30+
tag_: str,
31+
expdb: AsyncConnection,
32+
) -> Row | None:
33+
result = await expdb.execute(
34+
text(
35+
f"""
36+
SELECT `{id_column}` as id, `tag`, `uploader`
37+
FROM {table}
38+
WHERE `{id_column}` = :id AND `tag` = :tag
39+
""",
40+
),
41+
parameters={"id": id_, "tag": tag_},
42+
)
43+
return result.one_or_none()
44+
45+
46+
async def remove_tag(
47+
*,
48+
table: str,
49+
id_column: str,
50+
id_: int,
51+
tag_: str,
52+
expdb: AsyncConnection,
53+
) -> None:
54+
await expdb.execute(
55+
text(
56+
f"""
57+
DELETE FROM {table}
58+
WHERE `{id_column}` = :id AND `tag` = :tag
59+
""",
60+
),
61+
parameters={"id": id_, "tag": tag_},
62+
)
63+
64+
65+
async def select_tags(
66+
*,
67+
table: str,
68+
id_column: str,
69+
id_: int,
70+
expdb: AsyncConnection,
71+
) -> list[str]:
72+
result = await expdb.execute(
73+
text(
74+
f"""
75+
SELECT `tag`
76+
FROM {table}
77+
WHERE `{id_column}` = :id
78+
""",
79+
),
80+
parameters={"id": id_},
81+
)
82+
return [row.tag for row in result.all()]

src/database/tasks.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
from sqlalchemy import Row, text
55
from sqlalchemy.ext.asyncio import AsyncConnection
66

7+
from database.tagging import insert_tag, remove_tag, select_tag, select_tags
8+
9+
_TABLE = "task_tag"
10+
_ID_COLUMN = "id"
11+
712

813
async def get(id_: int, expdb: AsyncConnection) -> Row | None:
914
row = await expdb.execute(
@@ -103,15 +108,18 @@ async def get_task_type_inout_with_template(
103108

104109

105110
async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]:
106-
rows = await expdb.execute(
107-
text(
108-
"""
109-
SELECT `tag`
110-
FROM task_tag
111-
WHERE `id` = :task_id
112-
""",
113-
),
114-
parameters={"task_id": id_},
111+
return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=id_, expdb=expdb)
112+
113+
114+
async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None:
115+
await insert_tag(
116+
table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, user_id=user_id, expdb=expdb,
115117
)
116-
tag_rows = rows.all()
117-
return [row.tag for row in tag_rows]
118+
119+
120+
async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None:
121+
return await select_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)
122+
123+
124+
async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None:
125+
await remove_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)

src/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from routers.openml.evaluations import router as evaluationmeasures_router
1616
from routers.openml.flows import router as flows_router
1717
from routers.openml.qualities import router as qualities_router
18+
from routers.openml.runs import router as runs_router
1819
from routers.openml.setups import router as setup_router
1920
from routers.openml.study import router as study_router
2021
from routers.openml.tasks import router as task_router
@@ -68,6 +69,7 @@ def create_api() -> FastAPI:
6869
app.include_router(estimationprocedure_router)
6970
app.include_router(task_router)
7071
app.include_router(flows_router)
72+
app.include_router(runs_router)
7173
app.include_router(study_router)
7274
app.include_router(setup_router)
7375
return app

src/routers/openml/flows.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,57 @@
1-
from typing import Annotated, Literal
1+
from typing import Annotated, Any, Literal
22

3-
from fastapi import APIRouter, Depends
3+
from fastapi import APIRouter, Body, Depends
44
from sqlalchemy.ext.asyncio import AsyncConnection
55

66
import database.flows
77
from core.conversions import _str_to_num
88
from core.errors import FlowNotFoundError
9-
from routers.dependencies import expdb_connection
9+
from core.tagging import tag_entity, untag_entity
10+
from database.users import User
11+
from routers.dependencies import expdb_connection, fetch_user_or_raise
12+
from routers.types import SystemString64
1013
from schemas.flows import Flow, Parameter, Subflow
1114

1215
router = APIRouter(prefix="/flows", tags=["flows"])
1316

1417

18+
@router.post(path="/tag")
19+
async def tag_flow(
20+
flow_id: Annotated[int, Body()],
21+
tag: Annotated[str, SystemString64],
22+
user: Annotated[User, Depends(fetch_user_or_raise)],
23+
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
24+
) -> dict[str, dict[str, Any]]:
25+
return await tag_entity(
26+
flow_id,
27+
tag,
28+
user,
29+
expdb,
30+
get_tags_fn=database.flows.get_tags,
31+
tag_fn=database.flows.tag,
32+
response_key="flow_tag",
33+
)
34+
35+
36+
@router.post(path="/untag")
37+
async def untag_flow(
38+
flow_id: Annotated[int, Body()],
39+
tag: Annotated[str, SystemString64],
40+
user: Annotated[User, Depends(fetch_user_or_raise)],
41+
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
42+
) -> dict[str, dict[str, Any]]:
43+
return await untag_entity(
44+
flow_id,
45+
tag,
46+
user,
47+
expdb,
48+
get_tag_fn=database.flows.get_tag,
49+
delete_tag_fn=database.flows.delete_tag,
50+
get_tags_fn=database.flows.get_tags,
51+
response_key="flow_tag",
52+
)
53+
54+
1555
@router.get("/exists/{name}/{external_version}")
1656
async def flow_exists(
1757
name: str,

src/routers/openml/runs.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import Annotated, Any
2+
3+
from fastapi import APIRouter, Body, Depends
4+
from sqlalchemy.ext.asyncio import AsyncConnection
5+
6+
import database.runs
7+
from core.tagging import tag_entity, untag_entity
8+
from database.users import User
9+
from routers.dependencies import expdb_connection, fetch_user_or_raise
10+
from routers.types import SystemString64
11+
12+
router = APIRouter(prefix="/runs", tags=["runs"])
13+
14+
15+
@router.post(path="/tag")
16+
async def tag_run(
17+
run_id: Annotated[int, Body()],
18+
tag: Annotated[str, SystemString64],
19+
user: Annotated[User, Depends(fetch_user_or_raise)],
20+
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
21+
) -> dict[str, dict[str, Any]]:
22+
return await tag_entity(
23+
run_id,
24+
tag,
25+
user,
26+
expdb,
27+
get_tags_fn=database.runs.get_tags,
28+
tag_fn=database.runs.tag,
29+
response_key="run_tag",
30+
)
31+
32+
33+
@router.post(path="/untag")
34+
async def untag_run(
35+
run_id: Annotated[int, Body()],
36+
tag: Annotated[str, SystemString64],
37+
user: Annotated[User, Depends(fetch_user_or_raise)],
38+
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
39+
) -> dict[str, dict[str, Any]]:
40+
return await untag_entity(
41+
run_id,
42+
tag,
43+
user,
44+
expdb,
45+
get_tag_fn=database.runs.get_tag,
46+
delete_tag_fn=database.runs.delete_tag,
47+
get_tags_fn=database.runs.get_tags,
48+
response_key="run_tag",
49+
)

0 commit comments

Comments
 (0)