Skip to content
Closed
181 changes: 175 additions & 6 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707."""

import datetime
from collections import defaultdict
import re
from collections.abc import Sequence
from typing import Any, cast

from sqlalchemy import text
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection

from routers.types import integer_range_regex
from schemas.datasets.openml import Feature


Expand Down Expand Up @@ -149,9 +150,13 @@ async def get_feature_ontologies(
),
parameters={"dataset_id": dataset_id},
)
ontologies: dict[int, list[str]] = defaultdict(list)
for row in rows.mappings():
ontologies[row["index"]].append(row["value"])
ontologies: dict[int, list[str]] = {}
for mapping in rows.mappings():
index = int(mapping["index"])
value = str(mapping["value"])
if index not in ontologies:
ontologies[index] = []
ontologies[index].append(value)
return ontologies


Expand All @@ -175,6 +180,30 @@ async def get_feature_values(
return [row.value for row in rows]


async def get_feature_values_bulk(
    dataset_id: int,
    connection: AsyncConnection,
) -> dict[int, list[str]]:
    """Fetch the stored values for all features of a dataset in one query.

    Args:
        dataset_id: The `did` of the dataset whose feature values to fetch.
        connection: Open async connection to the database.

    Returns:
        A mapping from feature index to the list of that feature's values,
        in the order the rows are returned by the database.
    """
    rows = await connection.execute(
        text(
            """
            SELECT `index`, `value`
            FROM data_feature_value
            WHERE `did` = :dataset_id
            """,
        ),
        parameters={"dataset_id": dataset_id},
    )
    values: dict[int, list[str]] = {}
    for mapping in rows.mappings():
        # Group values by feature index; setdefault replaces the explicit
        # "if index not in values" membership check.
        values.setdefault(int(mapping["index"]), []).append(str(mapping["value"]))
    return values


async def update_status(
dataset_id: int,
status: str,
Expand Down Expand Up @@ -208,3 +237,143 @@ async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection
),
parameters={"data": dataset_id},
)


def _get_quality_filter(quality: str, range_str: str | None, param_name: str) -> str:
    """Build an SQL fragment restricting datasets by a data-quality range.

    Args:
        quality: Name of the quality in the `data_quality` table
            (e.g. "NumberOfInstances") — used only in the error message here;
            the actual value is bound via `:quality_name_{param_name}`.
        range_str: A single integer ("10") or a range ("10..20"), as matched
            by `integer_range_regex`; None/empty disables the filter.
        param_name: Prefix for the bind-parameter placeholders; the matching
            values are produced by `_get_range_params` with the same prefix.

    Returns:
        An " AND d.`did` IN (...)" SQL fragment, or "" when `range_str`
        is falsy.

    Raises:
        ValueError: If `range_str` does not match `integer_range_regex`.
    """
    if not range_str:
        return ""
    if not (match := re.match(integer_range_regex, range_str)):
        msg = f"Invalid range format for {quality}: {range_str}"
        raise ValueError(msg)
    _start, end = match.groups()
    # A non-empty second group means an explicit upper bound was supplied.
    if end:
        value = f"`value` BETWEEN :{param_name}_start AND :{param_name}_end"
    else:
        value = f"`value` = :{param_name}_start"

    # Only placeholder *names* are interpolated into the SQL text; the
    # values themselves are bound separately, hence the S608 suppression.
    return f""" AND
    d.`did` IN (
        SELECT `data`
        FROM data_quality
        WHERE `quality` = :quality_name_{param_name} AND {value}
    )
    """  # noqa: S608


def _get_range_params(
    quality_name: str,
    range_str: str | None,
    param_prefix: str,
) -> dict[str, Any]:
    """Produce the bind-parameter values for a quality-range SQL fragment.

    Companion to `_get_quality_filter`: for the same `param_prefix` it
    supplies the values for the placeholders that filter references.
    Returns an empty dict when `range_str` is falsy or does not match
    `integer_range_regex`.
    """
    if not range_str:
        return {}
    parsed = re.match(integer_range_regex, range_str)
    if parsed is None:
        return {}
    lower, upper = parsed.groups()
    bound_values: dict[str, Any] = {}
    bound_values[f"quality_name_{param_prefix}"] = quality_name
    bound_values[f"{param_prefix}_start"] = int(lower)
    if upper:
        # The second group still carries the ".." separator — strip it
        # before converting to an integer. (Assumes the regex captures the
        # separator in the group; matches the `int(end[2:])` of the original.)
        bound_values[f"{param_prefix}_end"] = int(upper[2:])
    return bound_values


async def list_datasets(  # noqa: PLR0913
    *,
    limit: int,
    offset: int,
    data_name: str | None = None,
    data_version: str | None = None,
    tag: str | None = None,
    data_ids: list[int] | None = None,
    uploader: int | None = None,
    number_instances: str | None = None,
    number_features: str | None = None,
    number_classes: str | None = None,
    number_missing_values: str | None = None,
    statuses: list[str],
    user_id: int | None = None,
    is_admin: bool = False,
    connection: AsyncConnection,
) -> Sequence[Row]:
    """List datasets matching the given filters.

    The `number_*` arguments accept a single integer ("10") or an inclusive
    range ("10..20"); an invalid format raises ValueError (via
    `_get_quality_filter`).

    Returns:
        Rows of (did, name, version, format, file_id, status), where status
        is the dataset's most recent status, defaulting to 'in_preparation'
        when none is recorded.
    """
    # Subquery selecting each dataset's most recent status row.
    current_status = """
        SELECT ds1.`did`, ds1.`status`
        FROM dataset_status AS ds1
        WHERE ds1.`status_date`=(
            SELECT MAX(ds2.`status_date`)
            FROM dataset_status as ds2
            WHERE ds1.`did`=ds2.`did`
        )
    """

    # Visibility: admins see everything; a known user additionally sees
    # their own datasets. `user_id` is bound as a parameter rather than
    # interpolated into the SQL string (the original f-string interpolation
    # built SQL from a raw value).
    if is_admin:
        visible_to_user = "TRUE"
    elif user_id:
        visible_to_user = "(`visibility`='public' OR `uploader`=:user_id)"
    else:
        visible_to_user = "`visibility`='public'"

    # Optional filters contribute an SQL fragment each (empty when unused).
    # Their values are always present in `parameters`; unreferenced bind
    # names are tolerated here (the original code relied on this as well).
    where_name = "AND `name`=:data_name" if data_name else ""
    where_version = "AND `version`=:data_version" if data_version else ""
    where_uploader = "AND `uploader`=:uploader" if uploader else ""
    where_data_id = "AND d.`did` IN :data_ids" if data_ids else ""

    matching_tag = (
        """
        AND d.`did` IN (
            SELECT `id`
            FROM dataset_tag as dt
            WHERE dt.`tag`=:tag
        )
        """
        if tag
        else ""
    )

    q_params: dict[str, Any] = {}
    q_filters: list[str] = []

    # (quality name in data_quality, requested range, bind-parameter prefix)
    for quality, range_str, prefix in [
        ("NumberOfInstances", number_instances, "instances"),
        ("NumberOfFeatures", number_features, "features"),
        ("NumberOfClasses", number_classes, "classes"),
        ("NumberOfMissingValues", number_missing_values, "missing_vals"),
    ]:
        q_filters.append(_get_quality_filter(quality, range_str, prefix))
        q_params.update(_get_range_params(quality, range_str, prefix))

    instances_filter, features_filter, classes_filter, missing_values_filter = q_filters

    sql = text(
        f"""
        SELECT d.`did`, d.`name`, d.`version`, d.`format`, d.`file_id`,
               IFNULL(cs.`status`, 'in_preparation') AS status
        FROM dataset AS d
        LEFT JOIN ({current_status}) AS cs ON d.`did`=cs.`did`
        WHERE {visible_to_user} {where_name} {where_version} {where_uploader}
        {where_data_id} {matching_tag} {instances_filter} {features_filter}
        {classes_filter} {missing_values_filter}
        AND IFNULL(cs.`status`, 'in_preparation') IN :statuses
        LIMIT :limit OFFSET :offset
        """,  # noqa: S608
    )

    parameters = {
        "data_name": data_name,
        "data_version": data_version,
        "uploader": uploader,
        "user_id": user_id,
        "tag": tag,
        "statuses": statuses,
        "limit": limit,
        "offset": offset,
        **q_params,
    }
    if data_ids:
        parameters["data_ids"] = data_ids

    # NOTE(review): list-valued params are also bound on the statement when
    # data_ids is set — presumably for IN-list expansion; confirm against
    # SQLAlchemy's expanding-bindparam behavior.
    result = await connection.execute(
        sql.bindparams(statuses=statuses, data_ids=data_ids) if data_ids else sql,
        parameters=parameters,
    )
    return cast("Sequence[Row]", result.all())
27 changes: 26 additions & 1 deletion src/database/tasks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Row, text
from sqlalchemy import Row, RowMapping, text
from sqlalchemy.ext.asyncio import AsyncConnection

# Tables that `get_lookup_data` may query. The table name is interpolated
# into an SQL string, so this allow-list doubles as an injection guard.
# TODO(review): confirm the allow-list is complete — per `fill_template`'s
# docstring, LOOKUP statements may need to reference any table that is an
# input in `task_inputs`.
ALLOWED_LOOKUP_TABLES = ["estimation_procedure", "evaluation_measure", "task_type", "dataset"]

# Primary-key column per lookup table; tables absent here use `id`.
PK_MAPPING = {
    "task_type": "ttid",
    "dataset": "did",
}


async def get(id_: int, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
Expand Down Expand Up @@ -115,3 +121,22 @@ async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]:
)
tag_rows = rows.all()
return [row.tag for row in tag_rows]


async def get_lookup_data(table: str, id_: int, expdb: AsyncConnection) -> RowMapping | None:
    """Fetch a single row from a lookup table by its primary key.

    Args:
        table: Name of the lookup table; must be in ALLOWED_LOOKUP_TABLES.
        id_: Primary-key value of the row to fetch.
        expdb: Open async connection to the experiment database.

    Returns:
        The row as a mapping, or None when no row matches.

    Raises:
        ValueError: If `table` is not an allowed lookup table. This guard
            makes the f-string interpolation below safe.
    """
    if table not in ALLOWED_LOOKUP_TABLES:
        msg = f"Table {table} is not allowed for lookup."
        raise ValueError(msg)

    # Most lookup tables use `id` as primary key; task_type/dataset differ.
    pk = PK_MAPPING.get(table, "id")
    # Backtick-quote the interpolated identifiers, consistent with the rest
    # of the file's SQL. Safe: `table` is validated above, `pk` comes from
    # a fixed mapping.
    result = await expdb.execute(
        text(
            f"""
            SELECT *
            FROM `{table}`
            WHERE `{pk}` = :id_
            """,  # noqa: S608
        ),
        parameters={"id_": id_},
    )
    return result.mappings().one_or_none()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Loading