diff --git a/src/database/datasets.py b/src/database/datasets.py
index 4e76dcf9..26eb33d8 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,6 +1,7 @@
 """Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707."""
 import datetime
 
+from collections import defaultdict
 from sqlalchemy import text
 from sqlalchemy.engine import Row
 
@@ -134,6 +135,26 @@ async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Fea
     return [Feature(**row, nominal_values=None) for row in rows]
 
 
+async def get_feature_ontologies(
+    dataset_id: int,
+    connection: AsyncConnection,
+) -> dict[int, list[str]]:
+    rows = await connection.execute(
+        text(
+            """
+            SELECT `index`, `value`
+            FROM data_feature_description
+            WHERE `did` = :dataset_id AND `description_type` = 'ontology'
+            """,
+        ),
+        parameters={"dataset_id": dataset_id},
+    )
+    ontologies: dict[int, list[str]] = defaultdict(list)
+    for row in rows.mappings():
+        ontologies[row["index"]].append(row["value"])
+    return ontologies
+
+
 async def get_feature_values(
     dataset_id: int,
     *,
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index d86ed848..164efd7d 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -294,6 +294,10 @@ async def get_dataset_features(
     assert expdb is not None  # noqa: S101
     await _get_dataset_raise_otherwise(dataset_id, user, expdb)
     features = await database.datasets.get_features(dataset_id, expdb)
+    ontologies = await database.datasets.get_feature_ontologies(dataset_id, expdb)
+    for feature in features:
+        feature.ontology = ontologies.get(feature.index)
+
     for feature in [f for f in features if f.data_type == FeatureType.NOMINAL]:
         feature.nominal_values = await database.datasets.get_feature_values(
             dataset_id,
diff --git a/src/schemas/datasets/openml.py b/src/schemas/datasets/openml.py
index b1c23a73..767cbb81 100644
--- a/src/schemas/datasets/openml.py
+++ b/src/schemas/datasets/openml.py
@@ -40,6 +40,7 @@ class Feature(BaseModel):
     index: int
     name: str
     data_type: FeatureType
+    ontology: list[str] | None = None
     is_target: bool
     is_ignore: bool
     is_row_identifier: bool
diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py
index 5ff6fe86..75f30863 100644
--- a/tests/routers/openml/migration/datasets_migration_test.py
+++ b/tests/routers/openml/migration/datasets_migration_test.py
@@ -259,6 +259,8 @@ async def test_datasets_feature_is_identical(
             values = feature.pop(key)
             # The old API returns a str if there is only a single element
             feature["nominal_value"] = values if len(values) > 1 else values[0]
+        elif key == "ontology":
+            del feature[key]  # Added back in with follow up PR #262
         else:
             # The old API formats bool as string in lower-case
            feature[key] = str(value) if not isinstance(value, bool) else str(value).lower()