Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 117 additions & 1 deletion skops/io/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sklearn.tree._tree import Tree

from ._audit import Node, get_tree
from ._general import TypeNode, unsupported_get_state
from ._general import ObjectNode, TypeNode, object_get_state, unsupported_get_state
from ._protocol import PROTOCOL
from ._utils import LoadContext, SaveContext, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException
Expand Down Expand Up @@ -97,9 +97,94 @@
LossFunction = None


# Registry of sklearn-internal helper types (loss/link objects, histogram
# gradient boosting internals) that skops knows how to (de)serialize.
SKLEARN_INTERNAL_OBJECTS: set[type] = set()
# Maps a type to the fully qualified name it should be saved under, when that
# differs from the default ``get_module(type_) + "." + type_.__name__``.
SKLEARN_TYPE_NAME_OVERRIDES: dict[type, str] = {}

try:
    # Link functions live in private sklearn API; guard the import since they
    # may move or be absent depending on the installed sklearn version.
    from sklearn._loss.link import (
        HalfLogitLink,
        IdentityLink,
        Interval,
        LogitLink,
        LogLink,
        MultinomialLogit,
    )

    SKLEARN_INTERNAL_OBJECTS |= {
        HalfLogitLink,
        IdentityLink,
        Interval,
        LogLink,
        LogitLink,
        MultinomialLogit,
    }
except ImportError:
    pass

try:
    # Loss objects used by (Hist)GradientBoosting estimators; also private API.
    from sklearn._loss.loss import (
        AbsoluteError,
        ExponentialLoss,
        HalfBinomialLoss,
        HalfGammaLoss,
        HalfMultinomialLoss,
        HalfPoissonLoss,
        HalfSquaredError,
        HuberLoss,
        PinballLoss,
    )

    SKLEARN_INTERNAL_OBJECTS |= {
        AbsoluteError,
        ExponentialLoss,
        HalfBinomialLoss,
        HalfGammaLoss,
        HalfMultinomialLoss,
        HalfPoissonLoss,
        HalfSquaredError,
        HuberLoss,
        PinballLoss,
    }
except ImportError:
    pass

# CyHalfMultinomialLoss is only usable when its guarded import earlier in this
# module succeeded.  Check the bound value, not mere presence of the name: a
# bare ``"..." in globals()`` test would also pass for a ``CyHalfMultinomialLoss
# = None`` fallback binding (the usual guarded-import idiom, cf. ``LossFunction
# = None`` above) and leak ``None`` into the registry.
if globals().get("CyHalfMultinomialLoss") is not None:
    SKLEARN_INTERNAL_OBJECTS.add(CyHalfMultinomialLoss)
    # NOTE(review): the override presumably pins this Cython class to its
    # canonical import path because the default module lookup reports
    # something else — confirm against get_module().
    SKLEARN_TYPE_NAME_OVERRIDES[CyHalfMultinomialLoss] = (
        "sklearn._loss._loss.CyHalfMultinomialLoss"
    )

try:
    # Histogram-gradient-boosting internals referenced by fitted estimators.
    from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
    from sklearn.ensemble._hist_gradient_boosting.predictor import TreePredictor

    SKLEARN_INTERNAL_OBJECTS |= {_BinMapper, TreePredictor}
except ImportError:
    pass


# Estimator types skops refuses to serialize; each is registered with
# unsupported_get_state so dumping raises instead of silently failing.
UNSUPPORTED_TYPES = {Birch}


def get_sklearn_internal_type_name(type_: type) -> str:
    """Return the fully qualified name under which *type_* is serialized.

    An explicit entry in ``SKLEARN_TYPE_NAME_OVERRIDES`` takes precedence;
    otherwise the name is derived from the type's module and class name.
    """
    override = SKLEARN_TYPE_NAME_OVERRIDES.get(type_)
    if override is not None:
        return override
    return f"{get_module(type_)}.{type_.__name__}"


# Qualified names of all registered internal types; these are trusted by
# default when loading.
TRUSTED_SKLEARN_INTERNAL_TYPE_NAMES = [
    get_sklearn_internal_type_name(type_) for type_ in SKLEARN_INTERNAL_OBJECTS
]

# Sanity check at import time: a name outside the sklearn namespace would
# silently widen the default trust list, so fail loudly instead.
if any(
    not type_name.startswith("sklearn.")
    for type_name in TRUSTED_SKLEARN_INTERNAL_TYPE_NAMES
):
    raise RuntimeError(
        "All trusted sklearn internal type names must start with 'sklearn.'."
    )


def reduce_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
# This method is for objects for which we have to use the __reduce__
# method to get the state.
Expand Down Expand Up @@ -260,6 +345,33 @@ def __init__(
)


def sklearn_internal_object_get_state(
    obj: Any, save_context: SaveContext
) -> dict[str, Any]:
    """Build the generic object state for an sklearn-internal instance.

    The state's module/class fields are overwritten with the canonical
    qualified name so loading resolves the type from its stable sklearn
    location, and the loader is pinned to ``SklearnInternalObjectNode``.
    """
    qualified_name = get_sklearn_internal_type_name(type(obj))
    module_name, _, class_name = qualified_name.rpartition(".")

    state = object_get_state(obj, save_context)
    state["__module__"] = module_name
    state["__class__"] = class_name
    state["__loader__"] = "SklearnInternalObjectNode"
    return state


class SklearnInternalObjectNode(ObjectNode):
    """Node for sklearn-internal helper objects (losses, links, etc.).

    Behaves like ``ObjectNode`` but trusts the vetted sklearn-internal type
    names by default, so these objects load without user intervention.
    """

    def __init__(
        self,
        state: dict[str, Any],
        load_context: LoadContext,
        trusted: Optional[Sequence[str]] = None,
    ) -> None:
        super().__init__(state, load_context, trusted)
        # Recompute trust after the parent set it so that, when the caller
        # passes trusted=None, the sklearn-internal allowlist is used as the
        # default instead of ObjectNode's.
        self.trusted = self._get_trusted(
            trusted,
            default=TRUSTED_SKLEARN_INTERNAL_TYPE_NAMES,
        )


# TODO: remove once support for sklearn<1.2 is dropped.
def _DictWithDeprecatedKeys_get_state(
obj: Any, save_context: SaveContext
Expand Down Expand Up @@ -321,12 +433,16 @@ def _construct(self):
# CyLossFunction is None when the guarded import failed; only register its
# handler when the type is actually available.
if CyLossFunction is not None:
    GET_STATE_DISPATCH_FUNCTIONS.append((CyLossFunction, loss_get_state))

# Route every registered sklearn-internal type through the dedicated
# get_state so its schema entry carries the canonical sklearn names.
for type_ in SKLEARN_INTERNAL_OBJECTS:
    GET_STATE_DISPATCH_FUNCTIONS.append((type_, sklearn_internal_object_get_state))

# Unsupported types get a handler that raises on dump.
for type_ in UNSUPPORTED_TYPES:
    GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state))

# Maps (loader name, protocol version) to the Node class that reconstructs
# the instance when loading.
NODE_TYPE_MAPPING: dict[tuple[str, int], Any] = {
    ("LossNode", PROTOCOL): LossNode,
    ("SklearnInternalObjectNode", PROTOCOL): SklearnInternalObjectNode,
    ("TreeNode", PROTOCOL): TreeNode,
}

Expand Down
141 changes: 141 additions & 0 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
from sklearn.datasets import load_sample_images, make_classification, make_regression
from sklearn.decomposition import SparseCoder
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.ensemble import (
GradientBoostingClassifier,
GradientBoostingRegressor,
HistGradientBoostingClassifier,
HistGradientBoostingRegressor,
)
from sklearn.exceptions import SkipTestWarning
from sklearn.experimental import enable_halving_search_cv # noqa
from sklearn.feature_extraction.text import TfidfVectorizer
Expand Down Expand Up @@ -438,6 +444,141 @@ def test_can_trust_types(type_):
assert len(untrusted_types) == 0


@pytest.mark.parametrize(
    ("estimator", "problem_type"),
    [
        pytest.param(
            GradientBoostingClassifier(loss="log_loss", n_estimators=5),
            "multiclass",
            id="GradientBoostingClassifier-log_loss-multiclass",
        ),
        pytest.param(
            GradientBoostingClassifier(loss="exponential", n_estimators=5),
            "binary",
            id="GradientBoostingClassifier-exponential",
        ),
        pytest.param(
            GradientBoostingRegressor(loss="squared_error", n_estimators=5),
            "regression",
            id="GradientBoostingRegressor-squared_error",
        ),
        pytest.param(
            GradientBoostingRegressor(loss="absolute_error", n_estimators=5),
            "regression",
            id="GradientBoostingRegressor-absolute_error",
        ),
        pytest.param(
            GradientBoostingRegressor(loss="huber", n_estimators=5),
            "regression",
            id="GradientBoostingRegressor-huber",
        ),
        pytest.param(
            GradientBoostingRegressor(loss="quantile", n_estimators=5, alpha=0.8),
            "regression",
            id="GradientBoostingRegressor-quantile",
        ),
        pytest.param(
            HistGradientBoostingClassifier(loss="log_loss", max_iter=5),
            "binary",
            id="HistGradientBoostingClassifier-log_loss",
        ),
        pytest.param(
            HistGradientBoostingRegressor(loss="gamma", max_iter=5),
            "positive_regression",
            id="HistGradientBoostingRegressor-gamma",
        ),
        pytest.param(
            HistGradientBoostingRegressor(loss="poisson", max_iter=5),
            "positive_regression",
            id="HistGradientBoostingRegressor-poisson",
        ),
    ],
)
def test_gradient_boosting_estimators_have_no_untrusted_types(estimator, problem_type):
    """Fitted gradient boosting estimators round-trip without untrusted types."""
    set_random_state(estimator, random_state=0)

    # Build a dataset matching the estimator's problem type.
    if problem_type == "multiclass":
        X, y = make_classification(
            n_samples=140,
            n_features=N_FEATURES,
            n_classes=3,
            n_informative=8,
            n_clusters_per_class=1,
            random_state=0,
        )
    elif problem_type == "binary":
        X, y = make_classification(
            n_samples=N_SAMPLES,
            n_features=N_FEATURES,
            n_classes=2,
            n_informative=5,
            random_state=0,
        )
    else:
        X, y = make_regression(
            n_samples=N_SAMPLES,
            n_features=N_FEATURES,
            random_state=0,
        )
        if problem_type == "positive_regression":
            # gamma/poisson losses require strictly positive targets.
            y = np.abs(y) + 1

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", module="sklearn")
        estimator.fit(X, y)

    dumped = dumps(estimator)

    # Everything inside the fitted estimator must be trusted by default.
    assert get_untrusted_types(data=dumped) == []

    loaded = loads(dumped)
    assert_method_outputs_equal(estimator, loaded, X)


def test_cyhalfmultinomialloss_is_serialized_under_sklearn_module():
    """CyHalfMultinomialLoss appears in the schema under its sklearn path."""
    estimator = GradientBoostingClassifier(loss="log_loss", n_estimators=5)
    set_random_state(estimator, random_state=0)
    # A 3-class problem makes the estimator use the multinomial loss.
    X, y = make_classification(
        n_samples=140,
        n_features=N_FEATURES,
        n_classes=3,
        n_informative=8,
        n_clusters_per_class=1,
        random_state=0,
    )

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", module="sklearn")
        estimator.fit(X, y)

    dumped = dumps(estimator)
    with ZipFile(io.BytesIO(dumped), "r") as zip_file:
        schema = json.loads(zip_file.read("schema.json"))

    # Iteratively traverse the nested schema, collecting every dict that
    # records the CyHalfMultinomialLoss class.
    matches = []
    pending = [schema]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            if node.get("__class__") == "CyHalfMultinomialLoss":
                matches.append(node)
            pending.extend(node.values())
        elif isinstance(node, list):
            pending.extend(node)

    assert len(matches) == 1
    assert matches[0]["__module__"] == "sklearn._loss._loss"
    assert matches[0]["__loader__"] == "SklearnInternalObjectNode"


@pytest.mark.parametrize(
"estimator", _unsupported_estimators(), ids=_get_check_estimator_ids
)
Expand Down