Skip to content
23 changes: 16 additions & 7 deletions src/routers/openml/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,38 @@
from core.conversions import _str_to_num
from core.errors import FlowNotFoundError
from routers.dependencies import expdb_connection
from schemas.flows import Flow, Parameter, Subflow
from schemas.flows import Flow, FlowExistsBody, Parameter, Subflow

router = APIRouter(prefix="/flows", tags=["flows"])


@router.get("/exists/{name}/{external_version}")
@router.post("/exists")
async def flow_exists(
name: str,
external_version: str,
body: FlowExistsBody,
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
"""Check if a Flow with the name and version exists, if so, return the flow id."""
flow = await database.flows.get_by_name(
name=name,
external_version=external_version,
name=body.name,
external_version=body.external_version,
expdb=expdb,
)
if flow is None:
msg = f"Flow with name {name} and external version {external_version} not found."
msg = f"Flow with name {body.name} and external version {body.external_version} not found."
raise FlowNotFoundError(msg)
return {"flow_id": flow.id}


@router.get("/exists/{name}/{external_version}", deprecated=True)
async def flow_exists_get(
name: str,
external_version: str,
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
"""Use POST /flows/exists instead."""
return await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb)


@router.get("/{flow_id}")
async def get_flow(
flow_id: int,
Expand Down
5 changes: 5 additions & 0 deletions src/schemas/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from pydantic import BaseModel, ConfigDict, Field


class FlowExistsBody(BaseModel):
name: str = Field(min_length=1, max_length=1024)
external_version: str = Field(min_length=1, max_length=128)


class Parameter(BaseModel):
name: str
default_value: Any
Expand Down
39 changes: 33 additions & 6 deletions tests/routers/openml/flows_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from core.errors import FlowNotFoundError
from routers.openml.flows import flow_exists
from schemas.flows import FlowExistsBody
from tests.conftest import Flow


Expand All @@ -28,7 +29,7 @@ async def test_flow_exists_calls_db_correctly(
"database.flows.get_by_name",
new_callable=mocker.AsyncMock,
)
await flow_exists(name, external_version, expdb_test)
await flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb_test)
mocked_db.assert_called_once_with(
name=name,
external_version=external_version,
Expand All @@ -51,29 +52,37 @@ async def test_flow_exists_processes_found(
new_callable=mocker.AsyncMock,
return_value=fake_flow,
)
response = await flow_exists("name", "external_version", expdb_test)
response = await flow_exists(
FlowExistsBody(name="name", external_version="external_version"), expdb_test
)
assert response == {"flow_id": fake_flow.id}


async def test_flow_exists_handles_flow_not_found(
mocker: MockerFixture, expdb_test: AsyncConnection
) -> None:
mocker.patch("database.flows.get_by_name", return_value=None)
mocker.patch(
"database.flows.get_by_name",
new_callable=mocker.AsyncMock,
return_value=None,
)
with pytest.raises(FlowNotFoundError) as error:
await flow_exists("foo", "bar", expdb_test)
await flow_exists(FlowExistsBody(name="foo", external_version="bar"), expdb_test)
assert error.value.status_code == HTTPStatus.NOT_FOUND
assert error.value.uri == FlowNotFoundError.uri


async def test_flow_exists(flow: Flow, py_api: httpx.AsyncClient) -> None:
response = await py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}")
response = await py_api.post(
"/flows/exists", json={"name": flow.name, "external_version": flow.external_version}
)
assert response.status_code == HTTPStatus.OK
assert response.json() == {"flow_id": flow.id}


async def test_flow_exists_not_exists(py_api: httpx.AsyncClient) -> None:
name, version = "foo", "bar"
response = await py_api.get(f"/flows/exists/{name}/{version}")
response = await py_api.post("/flows/exists", json={"name": name, "external_version": version})
assert response.status_code == HTTPStatus.NOT_FOUND
assert response.headers["content-type"] == "application/problem+json"
error = response.json()
Expand All @@ -82,6 +91,24 @@ async def test_flow_exists_not_exists(py_api: httpx.AsyncClient) -> None:
assert version in error["detail"]


@pytest.mark.parametrize(
("name", "external_version"),
[
("", "v1"),
("some-flow", ""),
],
)
async def test_flow_exists_rejects_empty_fields(
py_api: httpx.AsyncClient,
name: str,
external_version: str,
) -> None:
response = await py_api.post(
"/flows/exists", json={"name": name, "external_version": external_version}
)
assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY


async def test_get_flow_no_subflow(py_api: httpx.AsyncClient) -> None:
response = await py_api.get("/flows/1")
assert response.status_code == HTTPStatus.OK
Expand Down
10 changes: 8 additions & 2 deletions tests/routers/openml/migration/flows_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def test_flow_exists_not(
) -> None:
path = "exists/foo/bar"
py_response, php_response = await asyncio.gather(
py_api.get(f"/flows/{path}"),
py_api.post("/flows/exists", json={"name": "foo", "external_version": "bar"}),
php_api.get(f"/flow/{path}"),
)

Expand All @@ -45,7 +45,13 @@ async def test_flow_exists(
) -> None:
path = f"exists/{persisted_flow.name}/{persisted_flow.external_version}"
py_response, php_response = await asyncio.gather(
py_api.get(f"/flows/{path}"),
py_api.post(
"/flows/exists",
json={
"name": persisted_flow.name,
"external_version": persisted_flow.external_version,
},
),
php_api.get(f"/flow/{path}"),
)

Expand Down