diff --git a/sagemaker-train/src/sagemaker/train/utils.py b/sagemaker-train/src/sagemaker/train/utils.py index 0abd7596b5..74f4a66695 100644 --- a/sagemaker-train/src/sagemaker/train/utils.py +++ b/sagemaker-train/src/sagemaker/train/utils.py @@ -168,22 +168,26 @@ def convert_unassigned_to_none(instance) -> Any: return instance -def safe_serialize(data): +def safe_serialize(data) -> "str | PipelineVariable": """Serialize the data without wrapping strings in quotes. This function handles the following cases: - 1. If `data` is a string, it returns the string as-is without wrapping in quotes. - 2. If `data` is of type `PipelineVariable`, it returns the json representation of the PipelineVariable - 3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns - the JSON-encoded string using `json.dumps()`. - 4. If `data` cannot be serialized (e.g., a custom object), it returns the string - representation of the data using `str(data)`. + 1. If ``data`` is a string, it returns the string as-is without wrapping in quotes. + 2. If ``data`` is of type :class:`~sagemaker.core.workflow.parameters.PipelineVariable`, + it returns the object directly so that pipeline serialization can handle it + downstream. Callers should be aware that the return value may be a + ``PipelineVariable`` rather than a plain ``str``. + 3. If ``data`` is serializable (e.g., a dictionary, list, int, float), it returns + the JSON-encoded string using ``json.dumps()``. + 4. If ``data`` cannot be serialized (e.g., a custom object), it returns the string + representation of the data using ``str(data)``. Args: data (Any): The data to serialize. Returns: - str: The serialized JSON-compatible string or the string representation of the input. + str | PipelineVariable: The serialized JSON-compatible string, the string + representation of the input, or the original ``PipelineVariable`` object. """ if isinstance(data, str): return data @@ -192,7 +196,14 @@ def safe_serialize(data): try: return json.dumps(data) except TypeError: - return str(data) + try: + return str(data) + except TypeError: + # PipelineVariable.__str__ raises TypeError by design. + # If the isinstance check above didn't catch it (e.g. import + # path mismatch or reload issues), return the object directly + # so pipeline serialization can handle it downstream. + return data def _run_clone_command_silent(repo_url, dest_dir): diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index 3fd34fa47b..ce75c133f4 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -26,7 +26,7 @@ from sagemaker.core.helper.session_helper import Session from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar -from sagemaker.core.workflow.parameters import ParameterString +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger from sagemaker.train.model_trainer import ModelTrainer, Mode from sagemaker.train.configs import ( Compute, @@ -176,3 +176,54 @@ def test_training_image_rejects_invalid_type(self): stopping_condition=DEFAULT_STOPPING, output_data_config=DEFAULT_OUTPUT, ) + + +class TestModelTrainerHyperparametersPipelineVariable: + """Test that ModelTrainer correctly preserves PipelineVariable objects in hyperparameters.""" + + def test_hyperparameters_preserves_pipeline_variable_integer(self): + """ParameterInteger in hyperparameters should be preserved in ModelTrainer.""" + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"max_depth": max_depth}, + ) + assert trainer.hyperparameters["max_depth"] is max_depth + + def test_hyperparameters_preserves_pipeline_variable_string(self): + """ParameterString in hyperparameters should be preserved in ModelTrainer.""" + optimizer = ParameterString(name="Optimizer", default_value="sgd") + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"optimizer": optimizer}, + ) + assert trainer.hyperparameters["optimizer"] is optimizer + + def test_hyperparameters_preserves_mixed_pipeline_and_regular_values(self): + """Mixed PipelineVariable and regular values should all be preserved.""" + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={ + "max_depth": max_depth, + "eta": 0.1, + "objective": "binary:logistic", + }, + ) + # PipelineVariable should be preserved as-is + assert trainer.hyperparameters["max_depth"] is max_depth + # Regular values should also be preserved + assert trainer.hyperparameters["eta"] == 0.1 + assert trainer.hyperparameters["objective"] == "binary:logistic" diff --git a/sagemaker-train/tests/unit/train/test_utils.py b/sagemaker-train/tests/unit/train/test_utils.py new file mode 100644 index 0000000000..1c72ee73b6 --- /dev/null +++ b/sagemaker-train/tests/unit/train/test_utils.py @@ -0,0 +1,96 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for sagemaker.train.utils – specifically safe_serialize.""" +from __future__ import absolute_import + +import json + +from sagemaker.train.utils import safe_serialize +from sagemaker.core.workflow.parameters import ( + ParameterInteger, + ParameterString, + ParameterFloat, +) + + +# --------------------------------------------------------------------------- +# PipelineVariable inputs – should be returned as-is (identity) +# --------------------------------------------------------------------------- + +def test_safe_serialize_with_pipeline_variable_integer_returns_object_directly(): + param = ParameterInteger(name="MaxDepth", default_value=5) + result = safe_serialize(param) + assert result is param + + +def test_safe_serialize_with_pipeline_variable_string_returns_object_directly(): + param = ParameterString(name="Optimizer", default_value="sgd") + result = safe_serialize(param) + assert result is param + + +def test_safe_serialize_with_pipeline_variable_float_returns_object_directly(): + param = ParameterFloat(name="LearningRate", default_value=0.01) + result = safe_serialize(param) + assert result is param + + +# --------------------------------------------------------------------------- +# Regular / primitive inputs +# --------------------------------------------------------------------------- + +def test_safe_serialize_with_string_returns_string_as_is(): + assert safe_serialize("hello") == "hello" + assert safe_serialize("12345") == "12345" + + +def test_safe_serialize_with_json_like_string_returns_as_is(): + """A string that looks like JSON should be returned as-is, not double-serialized.""" + json_str = '{"key": "value"}' + assert safe_serialize(json_str) == json_str + + +def test_safe_serialize_with_int_returns_json_string(): + assert safe_serialize(5) == "5" + assert safe_serialize(0) == "0" + + +def test_safe_serialize_with_dict_returns_json_string(): + data = {"key": "value", "num": 1} + assert safe_serialize(data) == json.dumps(data) + + +def test_safe_serialize_with_bool_returns_json_string(): + assert safe_serialize(True) == "true" + assert safe_serialize(False) == "false" + + +def test_safe_serialize_with_custom_object_returns_str(): + class CustomObject: + def __str__(self): + return "CustomObject" + + obj = CustomObject() + assert safe_serialize(obj) == "CustomObject" + + +def test_safe_serialize_with_none_returns_json_null(): + assert safe_serialize(None) == "null" + + +def test_safe_serialize_with_list_returns_json_string(): + assert safe_serialize([1, 2, 3]) == "[1, 2, 3]" + + +def test_safe_serialize_with_empty_string(): + assert safe_serialize("") == ""