diff --git a/src/core/errors.py b/src/core/errors.py index 3f53364a..93cfae76 100644 --- a/src/core/errors.py +++ b/src/core/errors.py @@ -4,6 +4,7 @@ See: https://www.rfc-editor.org/rfc/rfc9457.html """ +from enum import IntEnum from http import HTTPStatus from fastapi import Request @@ -89,6 +90,17 @@ def problem_detail_exception_handler( ) +# ============================================================================= +# User Error Codes +# ============================================================================= + + +class UserError(IntEnum): + NOT_FOUND = 120 + NO_ACCESS = 121 + HAS_RESOURCES = 122 + + # ============================================================================= # Dataset Errors # ============================================================================= diff --git a/src/database/users.py b/src/database/users.py index b4f2e63e..49fdf11e 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -73,3 +73,73 @@ async def get_groups(self) -> list[UserGroup]: group_ids = await get_user_groups_for(user_id=self.user_id, connection=self._database) self._groups = [UserGroup(group_id) for group_id in group_ids] return self._groups + + +async def get_user_resource_count(*, user_id: int, expdb: AsyncConnection) -> int: + """Return the total number of datasets, flows, and runs owned by the user.""" + dataset_count = ( + await expdb.execute( + text("SELECT COUNT(*) FROM dataset WHERE uploader = :user_id"), + parameters={"user_id": user_id}, + ) + ).scalar() or 0 + flow_count = ( + await expdb.execute( + text("SELECT COUNT(*) FROM implementation WHERE uploader = :user_id"), + parameters={"user_id": user_id}, + ) + ).scalar() or 0 + run_count = ( + await expdb.execute( + text("SELECT COUNT(*) FROM run WHERE uploader = :user_id"), + parameters={"user_id": user_id}, + ) + ).scalar() or 0 + + study_count = ( + await expdb.execute( + text("SELECT COUNT(*) FROM study WHERE creator = :user_id"), + parameters={"user_id": user_id}, + ) + ).scalar() or 0 + task_study_count = ( + await expdb.execute( + text("SELECT COUNT(*) FROM task_study WHERE uploader = :user_id"), + parameters={"user_id": user_id}, + ) + ).scalar() or 0 + run_study_count = ( + await expdb.execute( + text("SELECT COUNT(*) FROM run_study WHERE uploader = :user_id"), + parameters={"user_id": user_id}, + ) + ).scalar() or 0 + dataset_tag_count = ( + await expdb.execute( + text("SELECT COUNT(*) FROM dataset_tag WHERE uploader = :user_id"), + parameters={"user_id": user_id}, + ) + ).scalar() or 0 + + return int( + dataset_count + + flow_count + + run_count + + study_count + + task_study_count + + run_study_count + + dataset_tag_count, + ) + + +async def delete_user(*, user_id: int, connection: AsyncConnection) -> None: + """Remove the user and their group memberships from the user database.""" + async with connection.begin_nested(): + await connection.execute( + text("DELETE FROM users_groups WHERE user_id = :user_id"), + parameters={"user_id": user_id}, + ) + await connection.execute( + text("DELETE FROM users WHERE id = :user_id"), + parameters={"user_id": user_id}, + ) diff --git a/src/main.py b/src/main.py index 76a52ad3..e3079672 100644 --- a/src/main.py +++ b/src/main.py @@ -19,6 +19,7 @@ from routers.openml.study import router as study_router from routers.openml.tasks import router as task_router from routers.openml.tasktype import router as ttype_router +from routers.openml.users import router as users_router @asynccontextmanager @@ -69,6 +70,7 @@ def create_api() -> FastAPI: app.include_router(task_router) app.include_router(flows_router) app.include_router(study_router) + app.include_router(users_router) app.include_router(setup_router) return app diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index 41254863..dbc467a9 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -7,29 +7,38 @@ from core.conversions import _str_to_num from core.errors import FlowNotFoundError from routers.dependencies import expdb_connection -from schemas.flows import Flow, Parameter, Subflow +from schemas.flows import Flow, FlowExistsBody, Parameter, Subflow router = APIRouter(prefix="/flows", tags=["flows"]) -@router.get("/exists/{name}/{external_version}") +@router.post("/exists") async def flow_exists( - name: str, - external_version: str, + body: FlowExistsBody, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["flow_id"], int]: """Check if a Flow with the name and version exists, if so, return the flow id.""" flow = await database.flows.get_by_name( - name=name, - external_version=external_version, + name=body.name, + external_version=body.external_version, expdb=expdb, ) if flow is None: - msg = f"Flow with name {name} and external version {external_version} not found." + msg = f"Flow with name {body.name} and external version {body.external_version} not found." raise FlowNotFoundError(msg) return {"flow_id": flow.id} +@router.get("/exists/{name}/{external_version}", deprecated=True) +async def flow_exists_get( + name: str, + external_version: str, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[Literal["flow_id"], int]: + """Use POST /flows/exists instead.""" + return await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb) + + @router.get("/{flow_id}") async def get_flow( flow_id: int, diff --git a/src/routers/openml/users.py b/src/routers/openml/users.py new file mode 100644 index 00000000..8a82c44b --- /dev/null +++ b/src/routers/openml/users.py @@ -0,0 +1,106 @@ +"""User account endpoints for the OpenML API.""" + +import uuid +from http import HTTPStatus +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection + +from core.errors import UserError +from database.users import User, UserGroup, delete_user, get_user_resource_count +from routers.dependencies import expdb_connection, fetch_user, userdb_connection + +router = APIRouter(prefix="/users", tags=["users"]) + + +@router.delete( + "/{user_id}", + summary="Delete a user account", + description=( + "Deletes the account of the specified user. " + "Only the account owner or an admin may perform this action. " + "Deletion is blocked if the user has uploaded any owned resources." + ), +) +async def delete_account( + user_id: int, + caller: Annotated[User | None, Depends(fetch_user)] = None, + user_db: Annotated[AsyncConnection, Depends(userdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, +) -> dict[str, Any]: + """Delete a user account if authorized and no owned resources exist.""" + if caller is None: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail={"code": str(int(UserError.NO_ACCESS)), "message": "Authentication required"}, + ) + + groups = await caller.get_groups() + is_admin = UserGroup.ADMIN in groups + is_self = caller.user_id == user_id + + if not is_admin and not is_self: + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, + detail={"code": str(int(UserError.NO_ACCESS)), "message": "No access granted"}, + ) + + original_result = await user_db.execute( + text("SELECT session_hash FROM users WHERE id = :id FOR UPDATE"), + parameters={"id": user_id}, + ) + original = original_result.fetchone() + + if original is None: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail={"code": str(int(UserError.NOT_FOUND)), "message": "User not found"}, + ) + + # Invalidate session while delete flow is in-progress. + original_session_hash = original[0] + temp_lock_hash = uuid.uuid4().hex + await user_db.execute( + text("UPDATE users SET session_hash = :lock_hash WHERE id = :id"), + parameters={"lock_hash": temp_lock_hash, "id": user_id}, + ) + # Persist lock hash before cross-database checks so other connections + # cannot keep authenticating with the old session hash. + await user_db.commit() + + deletion_successful = False + try: + resource_count = await get_user_resource_count(user_id=user_id, expdb=expdb) + if resource_count > 0: + raise HTTPException( + status_code=HTTPStatus.CONFLICT, + detail={ + "code": str(int(UserError.HAS_RESOURCES)), + "message": ( + f"User has {resource_count} resource(s). " + "Remove or transfer resources before deleting the account." + ), + }, + ) + + await delete_user(user_id=user_id, connection=user_db) + await user_db.commit() + deletion_successful = True + return {"user_id": user_id, "deleted": True} + finally: + if not deletion_successful: + # Restore only if we still hold our lock value. + await user_db.execute( + text( + "UPDATE users SET session_hash = :hash " + "WHERE id = :id AND session_hash = :lock_hash", + ), + parameters={ + "hash": original_session_hash, + "id": user_id, + "lock_hash": temp_lock_hash, + }, + ) + await user_db.commit() diff --git a/src/schemas/flows.py b/src/schemas/flows.py index a6cd479c..50e2491c 100644 --- a/src/schemas/flows.py +++ b/src/schemas/flows.py @@ -6,6 +6,11 @@ from pydantic import BaseModel, ConfigDict, Field +class FlowExistsBody(BaseModel): + name: str + external_version: str + + class Parameter(BaseModel): name: str default_value: Any diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py index 400ec4c0..0b5dc46d 100644 --- a/tests/routers/openml/flows_test.py +++ b/tests/routers/openml/flows_test.py @@ -8,6 +8,7 @@ from core.errors import FlowNotFoundError from routers.openml.flows import flow_exists +from schemas.flows import FlowExistsBody from tests.conftest import Flow @@ -28,7 +29,7 @@ async def test_flow_exists_calls_db_correctly( "database.flows.get_by_name", new_callable=mocker.AsyncMock, ) - await flow_exists(name, external_version, expdb_test) + await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb_test) mocked_db.assert_called_once_with( name=name, external_version=external_version, @@ -51,29 +52,42 @@ async def test_flow_exists_processes_found( new_callable=mocker.AsyncMock, return_value=fake_flow, ) - response = await flow_exists("name", "external_version", expdb_test) + response = await flow_exists( + FlowExistsBody(name="name", external_version="external_version"), + expdb_test, + ) assert response == {"flow_id": fake_flow.id} async def test_flow_exists_handles_flow_not_found( mocker: MockerFixture, expdb_test: AsyncConnection ) -> None: - mocker.patch("database.flows.get_by_name", return_value=None) + mocker.patch( + "database.flows.get_by_name", + new_callable=mocker.AsyncMock, + return_value=None, + ) with pytest.raises(FlowNotFoundError) as error: - await flow_exists("foo", "bar", expdb_test) + await flow_exists(FlowExistsBody(name="foo", external_version="bar"), expdb_test) assert error.value.status_code == HTTPStatus.NOT_FOUND assert error.value.uri == FlowNotFoundError.uri async def test_flow_exists(flow: Flow, py_api: httpx.AsyncClient) -> None: - response = await py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}") + response = await py_api.post( + "/flows/exists", + json={"name": flow.name, "external_version": flow.external_version}, + ) assert response.status_code == HTTPStatus.OK assert response.json() == {"flow_id": flow.id} async def test_flow_exists_not_exists(py_api: httpx.AsyncClient) -> None: name, version = "foo", "bar" - response = await py_api.get(f"/flows/exists/{name}/{version}") + response = await py_api.post( + "/flows/exists", + json={"name": name, "external_version": version}, + ) assert response.status_code == HTTPStatus.NOT_FOUND assert response.headers["content-type"] == "application/problem+json" error = response.json() @@ -82,6 +96,22 @@ async def test_flow_exists_not_exists(py_api: httpx.AsyncClient) -> None: assert version in error["detail"] +async def test_flow_exists_get_alias(flow: Flow, py_api: httpx.AsyncClient) -> None: + """Test the deprecated GET wrapper for backward compatibility.""" + response = await py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}") + assert response.status_code == HTTPStatus.OK + assert response.json() == {"flow_id": flow.id} + + +async def test_flow_exists_get_alias_not_exists(py_api: httpx.AsyncClient) -> None: + """Test the deprecated GET wrapper returns 404 for non-existent flows.""" + response = await py_api.get("/flows/exists/foo/bar") + assert response.status_code == HTTPStatus.NOT_FOUND + assert response.headers["content-type"] == "application/problem+json" + error = response.json() + assert error["type"] == FlowNotFoundError.uri + + async def test_get_flow_no_subflow(py_api: httpx.AsyncClient) -> None: response = await py_api.get("/flows/1") assert response.status_code == HTTPStatus.OK diff --git a/tests/routers/openml/migration/flows_migration_test.py b/tests/routers/openml/migration/flows_migration_test.py index 38d11e8c..f442478f 100644 --- a/tests/routers/openml/migration/flows_migration_test.py +++ b/tests/routers/openml/migration/flows_migration_test.py @@ -19,10 +19,9 @@ async def test_flow_exists_not( py_api: httpx.AsyncClient, php_api: httpx.AsyncClient, ) -> None: - path = "exists/foo/bar" py_response, php_response = await asyncio.gather( - py_api.get(f"/flows/{path}"), - php_api.get(f"/flow/{path}"), + py_api.post("/flows/exists", json={"name": "foo", "external_version": "bar"}), + php_api.get("/flow/exists/foo/bar"), ) assert py_response.status_code == HTTPStatus.NOT_FOUND @@ -43,10 +42,15 @@ async def test_flow_exists( py_api: httpx.AsyncClient, php_api: httpx.AsyncClient, ) -> None: - path = f"exists/{persisted_flow.name}/{persisted_flow.external_version}" py_response, php_response = await asyncio.gather( - py_api.get(f"/flows/{path}"), - php_api.get(f"/flow/{path}"), + py_api.post( + "/flows/exists", + json={ + "name": persisted_flow.name, + "external_version": persisted_flow.external_version, + }, + ), + php_api.get(f"/flow/exists/{persisted_flow.name}/{persisted_flow.external_version}"), ) assert py_response.status_code == php_response.status_code, php_response.content diff --git a/tests/routers/openml/users_test.py b/tests/routers/openml/users_test.py index 7250a115..8cc34d45 100644 --- a/tests/routers/openml/users_test.py +++ b/tests/routers/openml/users_test.py @@ -1,4 +1,8 @@ +from http import HTTPStatus + +import httpx import pytest +from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncConnection from database.users import User @@ -25,3 +29,187 @@ async def test_fetch_user_invalid_key_returns_none(user_test: AsyncConnection) - assert await fetch_user(api_key=None, user_data=user_test) is None invalid_key = "f" * 32 assert await fetch_user(api_key=invalid_key, user_data=user_test) is None + + +@pytest.mark.mut +async def test_delete_user_self(py_api: httpx.AsyncClient, user_test: AsyncConnection) -> None: + """A user without resources can delete their own account.""" + await user_test.execute( + text( + "INSERT INTO users (session_hash, email, first_name, last_name, password)" + " VALUES ('aaaabbbbccccddddaaaabbbbccccdddd', 'del@test.com', 'Del', 'User', 'x')", + ), + ) + result = await user_test.execute(text("SELECT LAST_INSERT_ID()")) + (new_id,) = result.one() + + await user_test.execute( + text("INSERT INTO users_groups (user_id, group_id) VALUES (:id, 2)"), + parameters={"id": new_id}, + ) + + response = await py_api.delete(f"/users/{new_id}?api_key=aaaabbbbccccddddaaaabbbbccccdddd") + assert response.status_code == HTTPStatus.OK + assert response.json() == {"user_id": new_id, "deleted": True} + + user_count = ( + await user_test.execute( + text("SELECT COUNT(*) FROM users WHERE id = :id"), + parameters={"id": new_id}, + ) + ).scalar() + group_count = ( + await user_test.execute( + text("SELECT COUNT(*) FROM users_groups WHERE user_id = :id"), + parameters={"id": new_id}, + ) + ).scalar() + assert user_count == 0 + assert group_count == 0 + + +@pytest.mark.mut +async def test_delete_user_as_admin(py_api: httpx.AsyncClient, user_test: AsyncConnection) -> None: + """An admin can delete any user without resources.""" + await user_test.execute( + text( + "INSERT INTO users (session_hash, email, first_name, last_name, password)" + " VALUES ('eeeeffffaaaabbbbeeeeffffaaaabbbb', 'del2@test.com', 'Del2', 'User', 'x')", + ), + ) + result = await user_test.execute(text("SELECT LAST_INSERT_ID()")) + (new_id,) = result.one() + + await user_test.execute( + text("INSERT INTO users_groups (user_id, group_id) VALUES (:id, 2)"), + parameters={"id": new_id}, + ) + + response = await py_api.delete(f"/users/{new_id}?api_key={ApiKey.ADMIN}") + assert response.status_code == HTTPStatus.OK + assert response.json() == {"user_id": new_id, "deleted": True} + + user_count = ( + await user_test.execute( + text("SELECT COUNT(*) FROM users WHERE id = :id"), + parameters={"id": new_id}, + ) + ).scalar() + group_count = ( + await user_test.execute( + text("SELECT COUNT(*) FROM users_groups WHERE user_id = :id"), + parameters={"id": new_id}, + ) + ).scalar() + assert user_count == 0 + assert group_count == 0 + + +async def test_delete_user_no_auth(py_api: httpx.AsyncClient) -> None: + """No API key -> 401.""" + response = await py_api.delete("/users/2") + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +async def test_delete_user_not_owner(py_api: httpx.AsyncClient) -> None: + """A non-owner non-admin user cannot delete someone else's account -> 403.""" + response = await py_api.delete(f"/users/3229?api_key={ApiKey.SOME_USER}") + assert response.status_code == HTTPStatus.FORBIDDEN + + +async def test_delete_user_not_found(py_api: httpx.AsyncClient) -> None: + """Deleting a non-existent user -> 404.""" + response = await py_api.delete(f"/users/99999999?api_key={ApiKey.ADMIN}") + assert response.status_code == HTTPStatus.NOT_FOUND + assert response.json()["detail"]["code"] == "120" + + +async def test_delete_user_has_resources( + py_api: httpx.AsyncClient, user_test: AsyncConnection +) -> None: + """A user with resources (datasets, flows, runs) gets a 409 Conflict.""" + target_id = 16 + response = await py_api.delete(f"/users/{target_id}?api_key={ApiKey.DATASET_130_OWNER}") + + assert response.status_code == HTTPStatus.CONFLICT + assert response.json()["detail"]["code"] == "122" + assert "resource(s)" in response.json()["detail"]["message"] + + user_count = ( + await user_test.execute( + text("SELECT COUNT(*) FROM users WHERE id = :id"), + parameters={"id": target_id}, + ) + ).scalar() + session_hash = ( + await user_test.execute( + text("SELECT session_hash FROM users WHERE id = :id"), + parameters={"id": target_id}, + ) + ).scalar() + assert user_count == 1 + assert session_hash == ApiKey.DATASET_130_OWNER + + +@pytest.mark.mut +@pytest.mark.parametrize( + "insert_sql", + [ + "INSERT INTO dataset (uploader, name, format) VALUES (:id, 'x', 'ARFF')", + ( + "INSERT INTO implementation (uploader, fullname, name, version, " + "external_version, uploadDate) VALUES (:id, 'x', 'x', 1, '1', '2024-01-01')" + ), + "INSERT INTO run (uploader, task_id, setup) VALUES (:id, 1, 1)", + "INSERT INTO study (creator, name, main_entity_type) VALUES (:id, 'x', 'run')", + "INSERT INTO task_study (uploader, study_id, task_id) VALUES (:id, 14, 1)", + "INSERT INTO run_study (uploader, study_id, run_id) VALUES (:id, 14, 1)", + "INSERT INTO dataset_tag (uploader, id, tag) VALUES (:id, 1, 'x')", + ], + ids=[ + "dataset", + "implementation", + "run", + "study", + "task_study", + "run_study", + "dataset_tag", + ], +) +async def test_delete_user_has_resources_parametrized( + py_api: httpx.AsyncClient, + user_test: AsyncConnection, + expdb_test: AsyncConnection, + insert_sql: str, +) -> None: + """Verify that possessing any tracked resource blocks deletion.""" + await user_test.execute( + text( + "INSERT INTO users (session_hash, email, first_name, last_name, password)" + " VALUES ('eeeeffffccccddddaaaabbbbccccdddd', 'res@test.com', 'Del', 'User', 'x')", + ), + ) + result = await user_test.execute(text("SELECT LAST_INSERT_ID()")) + (new_id,) = result.one() + + # Keep inserts inside rollback-scoped transaction used by the test harness. + async with expdb_test.begin_nested(): + await expdb_test.execute(text("SET FOREIGN_KEY_CHECKS=0")) + try: + await expdb_test.execute(text(insert_sql), parameters={"id": new_id}) + finally: + await expdb_test.execute(text("SET FOREIGN_KEY_CHECKS=1")) + + response = await py_api.delete(f"/users/{new_id}?api_key=eeeeffffccccddddaaaabbbbccccdddd") + + assert response.status_code == HTTPStatus.CONFLICT + assert response.json()["detail"]["code"] == "122" + assert "resource(s)" in response.json()["detail"]["message"] + + user_count = ( + await user_test.execute( + text("SELECT COUNT(*) FROM users WHERE id = :id"), + parameters={"id": new_id}, + ) + ).scalar() + assert user_count == 1