From 95a19c5de6a2fa132395ff89130f7ab085666c0a Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 27 Mar 2026 15:48:57 -0400 Subject: [PATCH 1/2] fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) --- sagemaker-train/src/sagemaker/train/utils.py | 11 ++- .../test_model_trainer_pipeline_variable.py | 61 ++++++++++++- .../tests/unit/train/test_utils.py | 90 +++++++++++++++++++ 3 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 sagemaker-train/tests/unit/train/test_utils.py diff --git a/sagemaker-train/src/sagemaker/train/utils.py b/sagemaker-train/src/sagemaker/train/utils.py index 0abd7596b5..633dab67d3 100644 --- a/sagemaker-train/src/sagemaker/train/utils.py +++ b/sagemaker-train/src/sagemaker/train/utils.py @@ -192,7 +192,16 @@ def safe_serialize(data): try: return json.dumps(data) except TypeError: - return str(data) + try: + return str(data) + except TypeError: + # PipelineVariable.__str__ raises TypeError by design. + # If the isinstance check above didn't catch it (e.g. import + # path mismatch), fall back to returning the object directly + # when it looks like a PipelineVariable (has an ``expr`` property). + if hasattr(data, "expr"): + return data + raise def _run_clone_command_silent(repo_url, dest_dir): diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index 3fd34fa47b..14a1de7ced 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -26,13 +26,14 @@ from sagemaker.core.helper.session_helper import Session from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar -from sagemaker.core.workflow.parameters import ParameterString +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat from sagemaker.train.model_trainer import ModelTrainer, Mode from sagemaker.train.configs import ( Compute, StoppingCondition, OutputDataConfig, ) +from sagemaker.core.workflow.pipeline_context import PipelineSession from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE @@ -176,3 +177,61 @@ def test_training_image_rejects_invalid_type(self): stopping_condition=DEFAULT_STOPPING, output_data_config=DEFAULT_OUTPUT, ) + + +class TestModelTrainerHyperparametersPipelineVariable: + """Test that PipelineVariable objects in hyperparameters survive safe_serialize.""" + + def test_hyperparameters_with_pipeline_variable_integer(self): + """ParameterInteger in hyperparameters should be passed through as-is.""" + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"max_depth": max_depth}, + ) + # safe_serialize should return the PipelineVariable object directly + from sagemaker.train.utils import safe_serialize + result = safe_serialize(max_depth) + assert result is max_depth + + def test_hyperparameters_with_pipeline_variable_string(self): + """ParameterString in hyperparameters should be passed through as-is.""" + optimizer = ParameterString(name="Optimizer", default_value="sgd") + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={"optimizer": optimizer}, + ) + from sagemaker.train.utils import safe_serialize + result = safe_serialize(optimizer) + assert result is optimizer + + def test_hyperparameters_with_mixed_pipeline_and_regular_values(self): + """Mixed PipelineVariable and regular values should both serialize correctly.""" + max_depth = ParameterInteger(name="MaxDepth", default_value=5) + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + hyperparameters={ + "max_depth": max_depth, + "eta": 0.1, + "objective": "binary:logistic", + }, + ) + from sagemaker.train.utils import safe_serialize + # PipelineVariable should be returned as-is + assert safe_serialize(max_depth) is max_depth + # Float should be JSON-serialized + assert safe_serialize(0.1) == "0.1" + # String should be returned as-is + assert safe_serialize("binary:logistic") == "binary:logistic" diff --git a/sagemaker-train/tests/unit/train/test_utils.py b/sagemaker-train/tests/unit/train/test_utils.py new file mode 100644 index 0000000000..e5d593dc91 --- /dev/null +++ b/sagemaker-train/tests/unit/train/test_utils.py @@ -0,0 +1,90 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for sagemaker.train.utils – specifically safe_serialize.""" +from __future__ import absolute_import + +import json + +from sagemaker.train.utils import safe_serialize +from sagemaker.core.workflow.parameters import ( + ParameterInteger, + ParameterString, + ParameterFloat, +) + + +# --------------------------------------------------------------------------- +# PipelineVariable inputs – should be returned as-is (identity) +# --------------------------------------------------------------------------- + +def test_safe_serialize_with_pipeline_variable_integer_returns_object_directly(): + param = ParameterInteger(name="MaxDepth", default_value=5) + result = safe_serialize(param) + assert result is param + + +def test_safe_serialize_with_pipeline_variable_string_returns_object_directly(): + param = ParameterString(name="Optimizer", default_value="sgd") + result = safe_serialize(param) + assert result is param + + +def test_safe_serialize_with_pipeline_variable_float_returns_object_directly(): + param = ParameterFloat(name="LearningRate", default_value=0.01) + result = safe_serialize(param) + assert result is param + + +# --------------------------------------------------------------------------- +# Regular / primitive inputs +# --------------------------------------------------------------------------- + +def test_safe_serialize_with_string_returns_string_as_is(): + assert safe_serialize("hello") == "hello" + assert safe_serialize("12345") == "12345" + + +def test_safe_serialize_with_int_returns_json_string(): + assert safe_serialize(5) == "5" + assert safe_serialize(0) == "0" + + +def test_safe_serialize_with_dict_returns_json_string(): + data = {"key": "value", "num": 1} + assert safe_serialize(data) == json.dumps(data) + + +def test_safe_serialize_with_bool_returns_json_string(): + assert safe_serialize(True) == "true" + assert safe_serialize(False) == "false" + + +def test_safe_serialize_with_custom_object_returns_str(): + class CustomObject: + def __str__(self): + return "CustomObject" + + obj = CustomObject() + assert safe_serialize(obj) == "CustomObject" + + +def test_safe_serialize_with_none_returns_json_null(): + assert safe_serialize(None) == "null" + + +def test_safe_serialize_with_list_returns_json_string(): + assert safe_serialize([1, 2, 3]) == "[1, 2, 3]" + + +def test_safe_serialize_with_empty_string(): + assert safe_serialize("") == "" From 00f479e7b35173aae1cb9f03603bef7771f5ae38 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 27 Mar 2026 15:54:10 -0400 Subject: [PATCH 2/2] fix: address review comments (iteration #1) --- sagemaker-train/src/sagemaker/train/utils.py | 28 +++++++------- .../test_model_trainer_pipeline_variable.py | 38 ++++++++----------- .../tests/unit/train/test_utils.py | 6 +++ 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/utils.py b/sagemaker-train/src/sagemaker/train/utils.py index 633dab67d3..74f4a66695 100644 --- a/sagemaker-train/src/sagemaker/train/utils.py +++ b/sagemaker-train/src/sagemaker/train/utils.py @@ -168,22 +168,26 @@ def convert_unassigned_to_none(instance) -> Any: return instance -def safe_serialize(data): +def safe_serialize(data) -> "str | PipelineVariable": """Serialize the data without wrapping strings in quotes. This function handles the following cases: - 1. If `data` is a string, it returns the string as-is without wrapping in quotes. - 2. If `data` is of type `PipelineVariable`, it returns the json representation of the PipelineVariable - 3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns - the JSON-encoded string using `json.dumps()`. - 4. If `data` cannot be serialized (e.g., a custom object), it returns the string - representation of the data using `str(data)`. + 1. If ``data`` is a string, it returns the string as-is without wrapping in quotes. + 2. If ``data`` is of type :class:`~sagemaker.core.workflow.parameters.PipelineVariable`, + it returns the object directly so that pipeline serialization can handle it + downstream. Callers should be aware that the return value may be a + ``PipelineVariable`` rather than a plain ``str``. + 3. If ``data`` is serializable (e.g., a dictionary, list, int, float), it returns + the JSON-encoded string using ``json.dumps()``. + 4. If ``data`` cannot be serialized (e.g., a custom object), it returns the string + representation of the data using ``str(data)``. Args: data (Any): The data to serialize. Returns: - str: The serialized JSON-compatible string or the string representation of the input. + str | PipelineVariable: The serialized JSON-compatible string, the string + representation of the input, or the original ``PipelineVariable`` object. """ if isinstance(data, str): return data @@ -197,11 +201,9 @@ def safe_serialize(data): except TypeError: # PipelineVariable.__str__ raises TypeError by design. # If the isinstance check above didn't catch it (e.g. import - # path mismatch), fall back to returning the object directly - # when it looks like a PipelineVariable (has an ``expr`` property). - if hasattr(data, "expr"): - return data - raise + # path mismatch or reload issues), return the object directly + # so pipeline serialization can handle it downstream. + return data def _run_clone_command_silent(repo_url, dest_dir): diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index 14a1de7ced..ce75c133f4 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -26,14 +26,13 @@ from sagemaker.core.helper.session_helper import Session from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar -from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger from sagemaker.train.model_trainer import ModelTrainer, Mode from sagemaker.train.configs import ( Compute, StoppingCondition, OutputDataConfig, ) -from sagemaker.core.workflow.pipeline_context import PipelineSession from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE @@ -180,10 +179,10 @@ def test_training_image_rejects_invalid_type(self): class TestModelTrainerHyperparametersPipelineVariable: - """Test that PipelineVariable objects in hyperparameters survive safe_serialize.""" + """Test that ModelTrainer correctly preserves PipelineVariable objects in hyperparameters.""" - def test_hyperparameters_with_pipeline_variable_integer(self): - """ParameterInteger in hyperparameters should be passed through as-is.""" + def test_hyperparameters_preserves_pipeline_variable_integer(self): + """ParameterInteger in hyperparameters should be preserved in ModelTrainer.""" max_depth = ParameterInteger(name="MaxDepth", default_value=5) trainer = ModelTrainer( training_image=DEFAULT_IMAGE, @@ -193,13 +192,10 @@ def test_hyperparameters_with_pipeline_variable_integer(self): output_data_config=DEFAULT_OUTPUT, hyperparameters={"max_depth": max_depth}, ) - # safe_serialize should return the PipelineVariable object directly - from sagemaker.train.utils import safe_serialize - result = safe_serialize(max_depth) - assert result is max_depth + assert trainer.hyperparameters["max_depth"] is max_depth - def test_hyperparameters_with_pipeline_variable_string(self): - """ParameterString in hyperparameters should be passed through as-is.""" + def test_hyperparameters_preserves_pipeline_variable_string(self): + """ParameterString in hyperparameters should be preserved in ModelTrainer.""" optimizer = ParameterString(name="Optimizer", default_value="sgd") trainer = ModelTrainer( training_image=DEFAULT_IMAGE, @@ -209,12 +205,10 @@ def test_hyperparameters_with_pipeline_variable_string(self): output_data_config=DEFAULT_OUTPUT, hyperparameters={"optimizer": optimizer}, ) - from sagemaker.train.utils import safe_serialize - result = safe_serialize(optimizer) - assert result is optimizer + assert trainer.hyperparameters["optimizer"] is optimizer - def test_hyperparameters_with_mixed_pipeline_and_regular_values(self): - """Mixed PipelineVariable and regular values should both serialize correctly.""" + def test_hyperparameters_preserves_mixed_pipeline_and_regular_values(self): + """Mixed PipelineVariable and regular values should all be preserved.""" max_depth = ParameterInteger(name="MaxDepth", default_value=5) trainer = ModelTrainer( training_image=DEFAULT_IMAGE, @@ -228,10 +222,8 @@ def test_hyperparameters_with_mixed_pipeline_and_regular_values(self): "objective": "binary:logistic", }, ) - from sagemaker.train.utils import safe_serialize - # PipelineVariable should be returned as-is - assert safe_serialize(max_depth) is max_depth - # Float should be JSON-serialized - assert safe_serialize(0.1) == "0.1" - # String should be returned as-is - assert safe_serialize("binary:logistic") == "binary:logistic" + # PipelineVariable should be preserved as-is + assert trainer.hyperparameters["max_depth"] is max_depth + # Regular values should also be preserved + assert trainer.hyperparameters["eta"] == 0.1 + assert trainer.hyperparameters["objective"] == "binary:logistic" diff --git a/sagemaker-train/tests/unit/train/test_utils.py b/sagemaker-train/tests/unit/train/test_utils.py index e5d593dc91..1c72ee73b6 100644 --- a/sagemaker-train/tests/unit/train/test_utils.py +++ b/sagemaker-train/tests/unit/train/test_utils.py @@ -54,6 +54,12 @@ def test_safe_serialize_with_string_returns_string_as_is(): assert safe_serialize("12345") == "12345" +def test_safe_serialize_with_json_like_string_returns_as_is(): + """A string that looks like JSON should be returned as-is, not double-serialized.""" + json_str = '{"key": "value"}' + assert safe_serialize(json_str) == json_str + + def test_safe_serialize_with_int_returns_json_string(): assert safe_serialize(5) == "5" assert safe_serialize(0) == "0"