Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/groundlight/edge/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from .config import (
DEFAULT,
DISABLED,
EDGE_ANSWERS_WITH_ESCALATION,
NO_CLOUD,
DetectorConfig,
DetectorsConfig,
EdgeEndpointConfig,
GlobalConfig,
InferenceConfig,
)

__all__ = [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, this is a little redundant since it includes all objects that could be imported

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ChatGPT says: "good point. `__all__` is kept intentionally to make the public API explicit/stable for this new module (and to control wildcard-import surface), but it can be removed if repo convention prefers omitting it."

I'm not really sure which way is best...

"DEFAULT",
"DISABLED",
"EDGE_ANSWERS_WITH_ESCALATION",
"NO_CLOUD",
"DetectorsConfig",
"DetectorConfig",
"EdgeEndpointConfig",
"GlobalConfig",
"InferenceConfig",
]
183 changes: 183 additions & 0 deletions src/groundlight/edge/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from typing import Any, Optional, Union

from model import Detector
from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator
from typing_extensions import Self


class GlobalConfig(BaseModel):
    """
    Global (non-detector-specific) settings for an edge endpoint.
    """

    # How often the inference server polls for updated model binaries.
    refresh_rate: float = Field(
        default=60.0,
        description="The interval (in seconds) at which the inference server checks for a new model binary update.",
    )
    # Sampling probability for cloud audits of confident edge answers.
    confident_audit_rate: float = Field(
        default=1e-5,  # A detector running at 1 FPS = ~100,000 IQ/day, so 1e-5 is ~1 confident IQ/day audited
        description="The probability that any given confident prediction will be sent to the cloud for auditing.",
    )


class InferenceConfig(BaseModel):
    """
    Configuration for edge inference on a specific detector.

    Instances are frozen (immutable) so the module-level presets
    (DEFAULT, NO_CLOUD, ...) can be shared safely between detectors.
    """

    # Keep shared presets immutable (DEFAULT/NO_CLOUD/etc.) so one mutation cannot globally change behavior.
    model_config = ConfigDict(frozen=True)

    # `name` is excluded from serialization: it is used as the dict key in
    # DetectorsConfig.edge_inference_configs rather than as payload data.
    name: str = Field(..., exclude=True, description="A unique name for this inference config preset.")
    enabled: bool = Field(
        default=True, description="Whether the edge endpoint should accept image queries for this detector."
    )
    api_token: Optional[str] = Field(
        default=None, description="API token used to fetch the inference model for this detector."
    )
    always_return_edge_prediction: bool = Field(
        default=False,
        description=(
            "Indicates if the edge-endpoint should always provide edge ML predictions, regardless of confidence. "
            "When this setting is true, whether or not the edge-endpoint should escalate low-confidence predictions "
            "to the cloud is determined by `disable_cloud_escalation`."
        ),
    )
    disable_cloud_escalation: bool = Field(
        default=False,
        description=(
            "Never escalate ImageQueries from the edge-endpoint to the cloud. "
            "Requires `always_return_edge_prediction=True`."
        ),
    )
    min_time_between_escalations: float = Field(
        default=2.0,
        description=(
            "The minimum time (in seconds) to wait between cloud escalations for a given detector. "
            "Cannot be less than 0.0. "
            "Only applies when `always_return_edge_prediction=True` and `disable_cloud_escalation=False`."
        ),
    )

    @model_validator(mode="after")
    def validate_configuration(self) -> Self:
        """Enforce cross-field constraints; raises ValueError on violation."""
        # disable_cloud_escalation only makes sense when the edge always answers.
        if self.disable_cloud_escalation and not self.always_return_edge_prediction:
            raise ValueError(
                "The `disable_cloud_escalation` flag is only valid when `always_return_edge_prediction` is set to True."
            )
        if self.min_time_between_escalations < 0.0:
            raise ValueError("`min_time_between_escalations` cannot be less than 0.0.")
        return self


class DetectorConfig(BaseModel):
    """
    Configuration for a specific detector.

    Binds a detector ID to the name of an InferenceConfig defined in
    DetectorsConfig.edge_inference_configs.
    """

    detector_id: str = Field(..., description="Detector ID")
    # Name (dict key) of the inference config this detector uses, not the config itself.
    edge_inference_config: str = Field(..., description="Config for edge inference.")


class DetectorsConfig(BaseModel):
    """
    Detector and inference-config mappings for edge inference.
    """

    edge_inference_configs: dict[str, InferenceConfig] = Field(default_factory=dict)
    detectors: list[DetectorConfig] = Field(default_factory=list)

    @model_validator(mode="after")
    def validate_inference_configs(self):
        """
        Validate overall config consistency.

        Raises ValueError if a dict key differs from its InferenceConfig.name,
        if a detector ID appears more than once, or if any detector references
        an inference config that is not defined.
        """
        # Each mapping key must agree with the name embedded in its config.
        for key, cfg in self.edge_inference_configs.items():
            if key != cfg.name:
                raise ValueError(f"Edge inference config key '{key}' must match InferenceConfig.name '{cfg.name}'.")

        # Collect every detector ID that occurs more than once.
        all_ids = [entry.detector_id for entry in self.detectors]
        repeated = {detector_id for detector_id in all_ids if all_ids.count(detector_id) > 1}
        if repeated:
            duplicates = ", ".join(sorted(repeated))
            raise ValueError(f"Duplicate detector IDs are not allowed: {duplicates}.")

        # Every detector must point at a defined inference config.
        for entry in self.detectors:
            if entry.edge_inference_config not in self.edge_inference_configs:
                raise ValueError(f"Edge inference config '{entry.edge_inference_config}' not defined.")
        return self

    def add_detector(self, detector: Union[str, Detector], edge_inference_config: InferenceConfig) -> None:
        """Add a detector with the given inference config. Accepts detector ID or Detector object."""
        if isinstance(detector, Detector):
            detector_id = detector.id
        else:
            detector_id = detector

        if detector_id in {entry.detector_id for entry in self.detectors}:
            raise ValueError(f"A detector with ID '{detector_id}' already exists.")

        registered = self.edge_inference_configs.get(edge_inference_config.name)
        if registered is not None and registered != edge_inference_config:
            raise ValueError(
                f"A different inference config named '{edge_inference_config.name}' is already registered."
            )
        # Register the config under its name (no-op when an equal config is already present).
        self.edge_inference_configs.setdefault(edge_inference_config.name, edge_inference_config)

        self.detectors.append(DetectorConfig(detector_id=detector_id, edge_inference_config=edge_inference_config.name))

    def to_payload(self) -> dict[str, Any]:
        """Return flattened detector payload used by edge-endpoint config HTTP APIs."""
        inference_payload = {name: cfg.model_dump() for name, cfg in self.edge_inference_configs.items()}
        detector_payload = [entry.model_dump() for entry in self.detectors]
        return {"edge_inference_configs": inference_payload, "detectors": detector_payload}


class EdgeEndpointConfig(BaseModel):
    """
    Top-level edge endpoint configuration: global settings plus detector mappings.
    """

    global_config: GlobalConfig = Field(default_factory=GlobalConfig)
    detectors_config: DetectorsConfig = Field(default_factory=DetectorsConfig)

    @property
    def edge_inference_configs(self) -> dict[str, InferenceConfig]:
        """Shortcut to the nested detectors_config's inference config map."""
        return self.detectors_config.edge_inference_configs

    @property
    def detectors(self) -> list[DetectorConfig]:
        """Shortcut to the nested detectors_config's detector assignments."""
        return self.detectors_config.detectors

    @model_serializer(mode="plain")
    def serialize(self):
        """Serialize to the flattened shape expected by edge-endpoint configs."""
        # global_config stays nested; detector data is lifted to the top level.
        payload: dict[str, Any] = {"global_config": self.global_config.model_dump()}
        payload.update(self.detectors_config.to_payload())
        return payload

    def add_detector(self, detector: Union[str, Detector], edge_inference_config: InferenceConfig) -> None:
        """Add a detector with the given inference config. Accepts detector ID or Detector object."""
        self.detectors_config.add_detector(detector, edge_inference_config)


# Preset inference configs matching the standard edge-endpoint defaults.
# Cloud-first behavior: escalate anything below the confidence threshold.
DEFAULT = InferenceConfig(name="default")
# Always answer at the edge, but still escalate (rate-limited) for cloud review.
EDGE_ANSWERS_WITH_ESCALATION = InferenceConfig(
    name="edge_answers_with_escalation",
    always_return_edge_prediction=True,
    min_time_between_escalations=2.0,
)
# Always answer at the edge and never contact the cloud.
NO_CLOUD = InferenceConfig(
    name="no_cloud",
    always_return_edge_prediction=True,
    disable_cloud_escalation=True,
)
# Reject image queries for the detector entirely.
DISABLED = InferenceConfig(name="disabled", enabled=False)
178 changes: 178 additions & 0 deletions test/unit/test_edge_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from datetime import datetime, timezone

import pytest
from groundlight.edge import (
DEFAULT,
DISABLED,
EDGE_ANSWERS_WITH_ESCALATION,
NO_CLOUD,
DetectorsConfig,
EdgeEndpointConfig,
GlobalConfig,
InferenceConfig,
)
from model import Detector, DetectorTypeEnum

# Non-default GlobalConfig values used to verify fields round-trip through model_dump().
CUSTOM_REFRESH_RATE = 10.0
CUSTOM_AUDIT_RATE = 0.0


def _make_detector(detector_id: str) -> Detector:
    """Build a minimal valid Detector for tests, varying only the ID."""
    fields = {
        "id": detector_id,
        "type": DetectorTypeEnum.detector,
        "created_at": datetime.now(timezone.utc),
        "name": "test detector",
        "query": "Is there a dog?",
        "group_name": "default",
        "metadata": None,
        "mode": "BINARY",
        "mode_configuration": None,
    }
    return Detector(**fields)


def test_edge_endpoint_config_is_not_subclass_of_detectors_config():
    """EdgeEndpointConfig composes DetectorsConfig rather than inheriting from it."""
    assert not issubclass(EdgeEndpointConfig, DetectorsConfig)


def test_add_detector_allows_equivalent_named_inference_config():
    """Two equal configs sharing a name may be registered for different detectors."""
    detectors_config = DetectorsConfig()
    for det_id in ("det_1", "det_2"):
        detectors_config.add_detector(
            det_id,
            InferenceConfig(
                name="custom_config",
                always_return_edge_prediction=True,
                min_time_between_escalations=0.5,
            ),
        )

    assert len(detectors_config.detectors) == 2  # noqa: PLR2004
    assert list(detectors_config.edge_inference_configs.keys()) == ["custom_config"]


def test_add_detector_rejects_different_named_inference_config():
    """Registering a non-equal config under an already-used name raises ValueError."""
    detectors_config = DetectorsConfig()
    detectors_config.add_detector("det_1", InferenceConfig(name="custom_config"))

    with pytest.raises(ValueError, match="different inference config named 'custom_config'"):
        detectors_config.add_detector(
            "det_2",
            InferenceConfig(name="custom_config", always_return_edge_prediction=True),
        )


def test_add_detector_rejects_duplicate_detector_id():
    """Adding the same detector ID twice raises ValueError."""
    detectors_config = DetectorsConfig()
    detectors_config.add_detector("det_1", DEFAULT)

    with pytest.raises(ValueError, match="already exists"):
        detectors_config.add_detector("det_1", DEFAULT)


def test_constructor_rejects_duplicate_detector_ids():
    """Model validation rejects duplicate detector IDs supplied at construction time."""
    with pytest.raises(ValueError, match="Duplicate detector IDs"):
        DetectorsConfig(
            edge_inference_configs={"default": DEFAULT},
            detectors=[
                {"detector_id": "det_1", "edge_inference_config": "default"},
                {"detector_id": "det_1", "edge_inference_config": "default"},
            ],
        )


def test_constructor_rejects_mismatched_inference_config_key_and_name():
    """A dict key that differs from its InferenceConfig.name fails validation."""
    with pytest.raises(ValueError, match="must match InferenceConfig.name"):
        DetectorsConfig(
            edge_inference_configs={"default": InferenceConfig(name="not_default")},
            detectors=[],
        )


def test_constructor_accepts_matching_inference_config_key_and_name():
    """Construction succeeds when dict keys agree with their configs' names."""
    config = DetectorsConfig(
        edge_inference_configs={"default": InferenceConfig(name="default")},
        detectors=[{"detector_id": "det_1", "edge_inference_config": "default"}],
    )

    assert list(config.edge_inference_configs.keys()) == ["default"]
    assert [detector.detector_id for detector in config.detectors] == ["det_1"]


def test_constructor_rejects_undefined_inference_config_reference():
    """A detector referencing a config name that was never defined fails validation."""
    with pytest.raises(ValueError, match="not defined"):
        DetectorsConfig(
            edge_inference_configs={},
            detectors=[{"detector_id": "det_1", "edge_inference_config": "does_not_exist"}],
        )


def test_edge_endpoint_config_add_detector_delegates_to_detectors_logic():
    """add_detector on the top-level config behaves like DetectorsConfig.add_detector."""
    config = EdgeEndpointConfig()
    assignments = [("det_1", NO_CLOUD), ("det_2", EDGE_ANSWERS_WITH_ESCALATION), ("det_3", DEFAULT)]
    for detector_id, preset in assignments:
        config.add_detector(detector_id, preset)

    assert [detector.detector_id for detector in config.detectors] == ["det_1", "det_2", "det_3"]
    assert set(config.edge_inference_configs.keys()) == {"no_cloud", "edge_answers_with_escalation", "default"}


def test_add_detector_accepts_detector_object():
    """add_detector extracts the ID from a full Detector object."""
    config = EdgeEndpointConfig()
    config.add_detector(_make_detector("det_1"), DEFAULT)

    assert [detector.detector_id for detector in config.detectors] == ["det_1"]


def test_disabled_preset_can_be_used():
    """The DISABLED preset registers like any other inference config."""
    config = EdgeEndpointConfig()
    config.add_detector("det_1", DISABLED)

    assert [detector.edge_inference_config for detector in config.detectors] == ["disabled"]
    assert config.edge_inference_configs["disabled"] == DISABLED


def test_detectors_config_to_payload_shape():
    """to_payload flattens configs and detectors into the HTTP API shape."""
    detectors_config = DetectorsConfig()
    for detector_id, preset in (("det_1", DEFAULT), ("det_2", NO_CLOUD)):
        detectors_config.add_detector(detector_id, preset)

    payload = detectors_config.to_payload()

    assert len(payload["detectors"]) == 2  # noqa: PLR2004
    assert set(payload["edge_inference_configs"].keys()) == {"default", "no_cloud"}


def test_model_dump_shape_for_edge_endpoint_config():
    """model_dump emits global_config nested plus flattened detector data."""
    config = EdgeEndpointConfig(
        global_config=GlobalConfig(refresh_rate=CUSTOM_REFRESH_RATE, confident_audit_rate=CUSTOM_AUDIT_RATE)
    )
    config.add_detector("det_1", DEFAULT)
    config.add_detector("det_2", EDGE_ANSWERS_WITH_ESCALATION)
    config.add_detector("det_3", NO_CLOUD)

    payload = config.model_dump()

    assert payload["global_config"]["refresh_rate"] == CUSTOM_REFRESH_RATE
    assert payload["global_config"]["confident_audit_rate"] == CUSTOM_AUDIT_RATE
    assert len(payload["detectors"]) == 3  # noqa: PLR2004
    assert set(payload["edge_inference_configs"].keys()) == {"default", "edge_answers_with_escalation", "no_cloud"}


def test_inference_config_validation_errors():
    """Cross-field InferenceConfig constraints raise at construction time."""
    # disable_cloud_escalation requires always_return_edge_prediction=True.
    with pytest.raises(ValueError, match="disable_cloud_escalation"):
        InferenceConfig(name="bad", disable_cloud_escalation=True)

    # Negative escalation intervals are rejected.
    with pytest.raises(ValueError, match="cannot be less than 0.0"):
        InferenceConfig(
            name="bad_escalation_interval",
            always_return_edge_prediction=True,
            min_time_between_escalations=-1.0,
        )