# --- new file: src/groundlight/edge/__init__.py ---

from .config import (
    DEFAULT,
    DISABLED,
    EDGE_ANSWERS_WITH_ESCALATION,
    NO_CLOUD,
    DetectorConfig,
    DetectorsConfig,
    EdgeEndpointConfig,
    GlobalConfig,
    InferenceConfig,
)

__all__ = [
    "DEFAULT",
    "DISABLED",
    "EDGE_ANSWERS_WITH_ESCALATION",
    "NO_CLOUD",
    "DetectorsConfig",
    "DetectorConfig",
    "EdgeEndpointConfig",
    "GlobalConfig",
    "InferenceConfig",
]

# --- new file: src/groundlight/edge/config.py ---

from typing import Any, Optional, Union

from model import Detector
from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator
from typing_extensions import Self


class GlobalConfig(BaseModel):
    """Global (detector-independent) settings for an edge endpoint."""

    refresh_rate: float = Field(
        default=60.0,
        description="The interval (in seconds) at which the inference server checks for a new model binary update.",
    )
    confident_audit_rate: float = Field(
        default=1e-5,  # A detector running at 1 FPS = ~100,000 IQ/day, so 1e-5 is ~1 confident IQ/day audited
        description="The probability that any given confident prediction will be sent to the cloud for auditing.",
    )


class InferenceConfig(BaseModel):
    """
    Configuration for edge inference on a specific detector.
    """

    # Keep shared presets immutable (DEFAULT/NO_CLOUD/etc.) so one mutation cannot globally change behavior.
    model_config = ConfigDict(frozen=True)

    # `exclude=True`: the name is the dict key in DetectorsConfig.edge_inference_configs,
    # so it is not serialized again inside the config payload itself.
    name: str = Field(..., exclude=True, description="A unique name for this inference config preset.")
    enabled: bool = Field(
        default=True, description="Whether the edge endpoint should accept image queries for this detector."
    )
    api_token: Optional[str] = Field(
        default=None, description="API token used to fetch the inference model for this detector."
    )
    always_return_edge_prediction: bool = Field(
        default=False,
        description=(
            "Indicates if the edge-endpoint should always provide edge ML predictions, regardless of confidence. "
            "When this setting is true, whether or not the edge-endpoint should escalate low-confidence predictions "
            "to the cloud is determined by `disable_cloud_escalation`."
        ),
    )
    disable_cloud_escalation: bool = Field(
        default=False,
        description=(
            "Never escalate ImageQueries from the edge-endpoint to the cloud. "
            "Requires `always_return_edge_prediction=True`."
        ),
    )
    min_time_between_escalations: float = Field(
        default=2.0,
        description=(
            "The minimum time (in seconds) to wait between cloud escalations for a given detector. "
            "Cannot be less than 0.0. "
            "Only applies when `always_return_edge_prediction=True` and `disable_cloud_escalation=False`."
        ),
    )

    @model_validator(mode="after")
    def validate_configuration(self) -> Self:
        """Reject flag combinations that have no sensible runtime meaning.

        Raises:
            ValueError: if `disable_cloud_escalation` is set without
                `always_return_edge_prediction`, or the escalation interval is negative.
        """
        if self.disable_cloud_escalation and not self.always_return_edge_prediction:
            raise ValueError(
                "The `disable_cloud_escalation` flag is only valid when `always_return_edge_prediction` is set to True."
            )
        if self.min_time_between_escalations < 0.0:
            raise ValueError("`min_time_between_escalations` cannot be less than 0.0.")
        return self


class DetectorConfig(BaseModel):
    """
    Configuration for a specific detector.
    """

    detector_id: str = Field(..., description="Detector ID")
    # References an InferenceConfig by its key in DetectorsConfig.edge_inference_configs.
    edge_inference_config: str = Field(..., description="Config for edge inference.")


class DetectorsConfig(BaseModel):
    """
    Detector and inference-config mappings for edge inference.
    """

    edge_inference_configs: dict[str, InferenceConfig] = Field(default_factory=dict)
    detectors: list[DetectorConfig] = Field(default_factory=list)

    @model_validator(mode="after")
    # Fix: add the `-> Self` return annotation for consistency with
    # InferenceConfig.validate_configuration (same pydantic hook, same contract).
    def validate_inference_configs(self) -> Self:
        """
        Validates detector config state.
        Raises ValueError if dict keys mismatch InferenceConfig.name, detector IDs are duplicated,
        or any detector references an undefined inference config.
        """
        for name, config in self.edge_inference_configs.items():
            if name != config.name:
                raise ValueError(f"Edge inference config key '{name}' must match InferenceConfig.name '{config.name}'.")

        # Collect every duplicated ID (not just the first) so the error is actionable in one pass.
        seen_detector_ids = set()
        duplicate_detector_ids = set()
        for detector_config in self.detectors:
            detector_id = detector_config.detector_id
            if detector_id in seen_detector_ids:
                duplicate_detector_ids.add(detector_id)
            else:
                seen_detector_ids.add(detector_id)
        if duplicate_detector_ids:
            duplicates = ", ".join(sorted(duplicate_detector_ids))
            raise ValueError(f"Duplicate detector IDs are not allowed: {duplicates}.")

        for detector_config in self.detectors:
            if detector_config.edge_inference_config not in self.edge_inference_configs:
                raise ValueError(f"Edge inference config '{detector_config.edge_inference_config}' not defined.")
        return self

    def add_detector(self, detector: Union[str, Detector], edge_inference_config: InferenceConfig) -> None:
        """Add a detector with the given inference config. Accepts detector ID or Detector object.

        Raises:
            ValueError: if the detector ID is already registered, or a *different*
                inference config with the same name is already registered.
        """
        detector_id = detector.id if isinstance(detector, Detector) else detector
        if any(existing.detector_id == detector_id for existing in self.detectors):
            raise ValueError(f"A detector with ID '{detector_id}' already exists.")

        # Re-registering an equal config under the same name is a harmless no-op;
        # only a *conflicting* config with the same name is an error.
        existing = self.edge_inference_configs.get(edge_inference_config.name)
        if existing is None:
            self.edge_inference_configs[edge_inference_config.name] = edge_inference_config
        elif existing != edge_inference_config:
            raise ValueError(
                f"A different inference config named '{edge_inference_config.name}' is already registered."
            )

        self.detectors.append(DetectorConfig(detector_id=detector_id, edge_inference_config=edge_inference_config.name))

    def to_payload(self) -> dict[str, Any]:
        """Return flattened detector payload used by edge-endpoint config HTTP APIs."""
        return {
            "edge_inference_configs": {
                name: config.model_dump() for name, config in self.edge_inference_configs.items()
            },
            "detectors": [detector.model_dump() for detector in self.detectors],
        }


class EdgeEndpointConfig(BaseModel):
    """
    Top-level edge endpoint configuration.
    """

    global_config: GlobalConfig = Field(default_factory=GlobalConfig)
    detectors_config: DetectorsConfig = Field(default_factory=DetectorsConfig)

    @property
    def edge_inference_configs(self) -> dict[str, InferenceConfig]:
        """Convenience accessor for detector inference config map."""
        return self.detectors_config.edge_inference_configs

    @property
    def detectors(self) -> list[DetectorConfig]:
        """Convenience accessor for detector assignments."""
        return self.detectors_config.detectors

    @model_serializer(mode="plain")
    def serialize(self):
        """Serialize to the flattened shape expected by edge-endpoint configs."""
        return {
            "global_config": self.global_config.model_dump(),
            **self.detectors_config.to_payload(),
        }

    def add_detector(self, detector: Union[str, Detector], edge_inference_config: InferenceConfig) -> None:
        """Add a detector with the given inference config. Accepts detector ID or Detector object."""
        self.detectors_config.add_detector(detector, edge_inference_config)


# Preset inference configs matching the standard edge-endpoint defaults.
DEFAULT = InferenceConfig(name="default")
EDGE_ANSWERS_WITH_ESCALATION = InferenceConfig(
    name="edge_answers_with_escalation",
    always_return_edge_prediction=True,
    min_time_between_escalations=2.0,
)
NO_CLOUD = InferenceConfig(
    name="no_cloud",
    always_return_edge_prediction=True,
    disable_cloud_escalation=True,
)
DISABLED = InferenceConfig(name="disabled", enabled=False)

# --- new file: test/unit/test_edge_config.py ---

from datetime import datetime, timezone

import pytest
from groundlight.edge import (
    DEFAULT,
    DISABLED,
    EDGE_ANSWERS_WITH_ESCALATION,
    NO_CLOUD,
    DetectorsConfig,
    EdgeEndpointConfig,
    GlobalConfig,
    InferenceConfig,
)
from model import Detector, DetectorTypeEnum

CUSTOM_REFRESH_RATE = 10.0
CUSTOM_AUDIT_RATE = 0.0


def _make_detector(detector_id: str) -> Detector:
    """Build a minimal binary-mode Detector object for tests."""
    return Detector(
        id=detector_id,
        type=DetectorTypeEnum.detector,
        created_at=datetime.now(timezone.utc),
        name="test detector",
        query="Is there a dog?",
        group_name="default",
        metadata=None,
        mode="BINARY",
        mode_configuration=None,
    )


def _custom_config() -> InferenceConfig:
    """A non-preset config used to exercise name-collision handling."""
    return InferenceConfig(
        name="custom_config",
        always_return_edge_prediction=True,
        min_time_between_escalations=0.5,
    )


def test_edge_endpoint_config_is_not_subclass_of_detectors_config():
    assert not issubclass(EdgeEndpointConfig, DetectorsConfig)


def test_add_detector_allows_equivalent_named_inference_config():
    # Two equal configs sharing a name should coexist: the second registration is a no-op.
    cfg = DetectorsConfig()
    for det_id in ("det_1", "det_2"):
        cfg.add_detector(det_id, _custom_config())

    assert len(cfg.detectors) == 2  # noqa: PLR2004
    assert list(cfg.edge_inference_configs.keys()) == ["custom_config"]


def test_add_detector_rejects_different_named_inference_config():
    cfg = DetectorsConfig()
    cfg.add_detector("det_1", InferenceConfig(name="custom_config"))

    with pytest.raises(ValueError, match="different inference config named 'custom_config'"):
        cfg.add_detector(
            "det_2",
            InferenceConfig(name="custom_config", always_return_edge_prediction=True),
        )


def test_add_detector_rejects_duplicate_detector_id():
    cfg = DetectorsConfig()
    cfg.add_detector("det_1", DEFAULT)

    with pytest.raises(ValueError, match="already exists"):
        cfg.add_detector("det_1", DEFAULT)


def test_constructor_rejects_duplicate_detector_ids():
    duplicated = {"detector_id": "det_1", "edge_inference_config": "default"}
    with pytest.raises(ValueError, match="Duplicate detector IDs"):
        DetectorsConfig(
            edge_inference_configs={"default": DEFAULT},
            detectors=[dict(duplicated), dict(duplicated)],
        )


def test_constructor_rejects_mismatched_inference_config_key_and_name():
    with pytest.raises(ValueError, match="must match InferenceConfig.name"):
        DetectorsConfig(
            edge_inference_configs={"default": InferenceConfig(name="not_default")},
            detectors=[],
        )


def test_constructor_accepts_matching_inference_config_key_and_name():
    cfg = DetectorsConfig(
        edge_inference_configs={"default": InferenceConfig(name="default")},
        detectors=[{"detector_id": "det_1", "edge_inference_config": "default"}],
    )

    assert list(cfg.edge_inference_configs.keys()) == ["default"]
    assert [entry.detector_id for entry in cfg.detectors] == ["det_1"]


def test_constructor_rejects_undefined_inference_config_reference():
    with pytest.raises(ValueError, match="not defined"):
        DetectorsConfig(
            edge_inference_configs={},
            detectors=[{"detector_id": "det_1", "edge_inference_config": "does_not_exist"}],
        )


def test_edge_endpoint_config_add_detector_delegates_to_detectors_logic():
    cfg = EdgeEndpointConfig()
    for det_id, preset in [("det_1", NO_CLOUD), ("det_2", EDGE_ANSWERS_WITH_ESCALATION), ("det_3", DEFAULT)]:
        cfg.add_detector(det_id, preset)

    assert [entry.detector_id for entry in cfg.detectors] == ["det_1", "det_2", "det_3"]
    assert set(cfg.edge_inference_configs.keys()) == {"no_cloud", "edge_answers_with_escalation", "default"}


def test_add_detector_accepts_detector_object():
    cfg = EdgeEndpointConfig()
    cfg.add_detector(_make_detector("det_1"), DEFAULT)

    assert [entry.detector_id for entry in cfg.detectors] == ["det_1"]


def test_disabled_preset_can_be_used():
    cfg = EdgeEndpointConfig()
    cfg.add_detector("det_1", DISABLED)

    assert [entry.edge_inference_config for entry in cfg.detectors] == ["disabled"]
    assert cfg.edge_inference_configs["disabled"] == DISABLED


def test_detectors_config_to_payload_shape():
    cfg = DetectorsConfig()
    cfg.add_detector("det_1", DEFAULT)
    cfg.add_detector("det_2", NO_CLOUD)

    payload = cfg.to_payload()

    assert len(payload["detectors"]) == 2  # noqa: PLR2004
    assert set(payload["edge_inference_configs"].keys()) == {"default", "no_cloud"}


def test_model_dump_shape_for_edge_endpoint_config():
    cfg = EdgeEndpointConfig(
        global_config=GlobalConfig(refresh_rate=CUSTOM_REFRESH_RATE, confident_audit_rate=CUSTOM_AUDIT_RATE)
    )
    for det_id, preset in [("det_1", DEFAULT), ("det_2", EDGE_ANSWERS_WITH_ESCALATION), ("det_3", NO_CLOUD)]:
        cfg.add_detector(det_id, preset)

    payload = cfg.model_dump()

    assert payload["global_config"]["refresh_rate"] == CUSTOM_REFRESH_RATE
    assert payload["global_config"]["confident_audit_rate"] == CUSTOM_AUDIT_RATE
    assert len(payload["detectors"]) == 3  # noqa: PLR2004
    assert set(payload["edge_inference_configs"].keys()) == {"default", "edge_answers_with_escalation", "no_cloud"}


def test_inference_config_validation_errors():
    with pytest.raises(ValueError, match="disable_cloud_escalation"):
        InferenceConfig(name="bad", disable_cloud_escalation=True)

    with pytest.raises(ValueError, match="cannot be less than 0.0"):
        InferenceConfig(
            name="bad_escalation_interval",
            always_return_edge_prediction=True,
            min_time_between_escalations=-1.0,
        )