Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions sagemaker-train/src/sagemaker/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,22 +168,26 @@ def convert_unassigned_to_none(instance) -> Any:
return instance


def safe_serialize(data):
def safe_serialize(data) -> "str | PipelineVariable":
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type "str | PipelineVariable" is a forward reference string, but this module should use from __future__ import annotations at the top to enable PEP 604 union syntax properly. Also, this is a behavioral change to the function's contract — previously it always returned str, now it can return PipelineVariable. This could cause TypeError in downstream callers that expect a str (e.g., calling .encode(), concatenation, etc.). Have you audited all call sites of safe_serialize to ensure they handle a PipelineVariable return value correctly?

Suggestion: Add from __future__ import annotations at the module top, and change the annotation to:

def safe_serialize(data) -> str | PipelineVariable:

Also consider using Union[str, PipelineVariable] with an explicit import if from __future__ import annotations is not already used in this module.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing type annotation for the data parameter. Per SDK conventions, all public functions must have type annotations for parameters and return types.

def safe_serialize(data: Any) -> str | PipelineVariable:

"""Serialize the data without wrapping strings in quotes.

This function handles the following cases:
1. If `data` is a string, it returns the string as-is without wrapping in quotes.
2. If `data` is of type `PipelineVariable`, it returns the json representation of the PipelineVariable
3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
the JSON-encoded string using `json.dumps()`.
4. If `data` cannot be serialized (e.g., a custom object), it returns the string
representation of the data using `str(data)`.
1. If ``data`` is a string, it returns the string as-is without wrapping in quotes.
2. If ``data`` is of type :class:`~sagemaker.core.workflow.parameters.PipelineVariable`,
it returns the object directly so that pipeline serialization can handle it
downstream. Callers should be aware that the return value may be a
``PipelineVariable`` rather than a plain ``str``.
3. If ``data`` is serializable (e.g., a dictionary, list, int, float), it returns
the JSON-encoded string using ``json.dumps()``.
4. If ``data`` cannot be serialized (e.g., a custom object), it returns the string
representation of the data using ``str(data)``.

Args:
data (Any): The data to serialize.

Returns:
str: The serialized JSON-compatible string or the string representation of the input.
str | PipelineVariable: The serialized JSON-compatible string, the string
representation of the input, or the original ``PipelineVariable`` object.
"""
if isinstance(data, str):
return data
Expand All @@ -192,7 +196,14 @@ def safe_serialize(data):
try:
return json.dumps(data)
except TypeError:
return str(data)
try:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The nested try/except is a reasonable defensive measure, but catching any TypeError raised by str(data) and silently returning the raw object is risky. This means any object whose __str__ raises TypeError (not just PipelineVariable) will be returned as-is, potentially causing unexpected behavior downstream. Consider being more explicit:

except TypeError:
    try:
        return str(data)
    except TypeError:
        if isinstance(data, PipelineVariable):
            return data
        raise

This way, only known PipelineVariable objects get the pass-through treatment, and truly broken objects still raise.

return str(data)
except TypeError:
# PipelineVariable.__str__ raises TypeError by design.
# If the isinstance check above didn't catch it (e.g. import
# path mismatch or reload issues), return the object directly
# so pipeline serialization can handle it downstream.
return data


def _run_clone_command_silent(repo_url, dest_dir):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
from sagemaker.core.workflow.parameters import ParameterString
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
from sagemaker.train.model_trainer import ModelTrainer, Mode
from sagemaker.train.configs import (
Compute,
Expand Down Expand Up @@ -176,3 +176,54 @@ def test_training_image_rejects_invalid_type(self):
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)


class TestModelTrainerHyperparametersPipelineVariable:
"""Test that ModelTrainer correctly preserves PipelineVariable objects in hyperparameters."""

def test_hyperparameters_preserves_pipeline_variable_integer(self):
    """ParameterInteger in hyperparameters should be preserved in ModelTrainer."""
    depth_param = ParameterInteger(name="MaxDepth", default_value=5)
    common_kwargs = dict(
        training_image=DEFAULT_IMAGE,
        role=DEFAULT_ROLE,
        compute=DEFAULT_COMPUTE,
        stopping_condition=DEFAULT_STOPPING,
        output_data_config=DEFAULT_OUTPUT,
    )
    trainer = ModelTrainer(hyperparameters={"max_depth": depth_param}, **common_kwargs)
    # Identity check, not equality: the very same object must survive assignment.
    assert trainer.hyperparameters["max_depth"] is depth_param

def test_hyperparameters_preserves_pipeline_variable_string(self):
"""ParameterString in hyperparameters should be preserved in ModelTrainer."""
optimizer = ParameterString(name="Optimizer", default_value="sgd")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests verify that ModelTrainer preserves PipelineVariable objects in the hyperparameters dict, but they don't test the actual serialization path (i.e., when safe_serialize is called during job creation). Consider adding a test that mocks the training job creation to verify that PipelineVariable values survive the full serialization pipeline, not just assignment.

trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
hyperparameters={"optimizer": optimizer},
)
assert trainer.hyperparameters["optimizer"] is optimizer

def test_hyperparameters_preserves_mixed_pipeline_and_regular_values(self):
    """Mixed PipelineVariable and regular values should all be preserved."""
    depth_param = ParameterInteger(name="MaxDepth", default_value=5)
    hyperparams = {
        "max_depth": depth_param,
        "eta": 0.1,
        "objective": "binary:logistic",
    }
    trainer = ModelTrainer(
        training_image=DEFAULT_IMAGE,
        role=DEFAULT_ROLE,
        compute=DEFAULT_COMPUTE,
        stopping_condition=DEFAULT_STOPPING,
        output_data_config=DEFAULT_OUTPUT,
        hyperparameters=hyperparams,
    )
    # The pipeline variable must come back as the identical object ...
    assert trainer.hyperparameters["max_depth"] is depth_param
    # ... while plain values are simply equal.
    assert trainer.hyperparameters["eta"] == 0.1
    assert trainer.hyperparameters["objective"] == "binary:logistic"
96 changes: 96 additions & 0 deletions sagemaker-train/tests/unit/train/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for sagemaker.train.utils – specifically safe_serialize."""
from __future__ import absolute_import
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per PEP 484 / SDK conventions, new modules should use from __future__ import annotations instead of from __future__ import absolute_import (which is a Python 2 artifact and unnecessary in Python 3).


import json

from sagemaker.train.utils import safe_serialize
from sagemaker.core.workflow.parameters import (
ParameterInteger,
ParameterString,
ParameterFloat,
)


# ---------------------------------------------------------------------------
# PipelineVariable inputs – should be returned as-is (identity)
# ---------------------------------------------------------------------------

def test_safe_serialize_with_pipeline_variable_integer_returns_object_directly():
    """An integer pipeline parameter must pass through as the same object."""
    depth = ParameterInteger(name="MaxDepth", default_value=5)
    assert safe_serialize(depth) is depth


def test_safe_serialize_with_pipeline_variable_string_returns_object_directly():
    """A string pipeline parameter must pass through as the same object."""
    optimizer = ParameterString(name="Optimizer", default_value="sgd")
    assert safe_serialize(optimizer) is optimizer


def test_safe_serialize_with_pipeline_variable_float_returns_object_directly():
    """A float pipeline parameter must pass through as the same object."""
    lr = ParameterFloat(name="LearningRate", default_value=0.01)
    assert safe_serialize(lr) is lr


# ---------------------------------------------------------------------------
# Regular / primitive inputs
# ---------------------------------------------------------------------------

def test_safe_serialize_with_string_returns_string_as_is():
    """Plain strings come back unchanged, never JSON-quoted."""
    for text in ("hello", "12345"):
        assert safe_serialize(text) == text


def test_safe_serialize_with_json_like_string_returns_as_is():
    """A string that looks like JSON should be returned as-is, not double-serialized."""
    payload = '{"key": "value"}'
    result = safe_serialize(payload)
    assert result == payload


def test_safe_serialize_with_int_returns_json_string():
    """Integers are JSON-encoded to their decimal string form."""
    assert safe_serialize(0) == "0"
    assert safe_serialize(5) == "5"


def test_safe_serialize_with_dict_returns_json_string():
    """Dictionaries round-trip through json.dumps."""
    payload = {"key": "value", "num": 1}
    assert safe_serialize(payload) == json.dumps(payload)


def test_safe_serialize_with_bool_returns_json_string():
    """Booleans serialize to JSON literals, not Python's True/False repr."""
    for flag, encoded in ((True, "true"), (False, "false")):
        assert safe_serialize(flag) == encoded


def test_safe_serialize_with_custom_object_returns_str():
    """Objects json.dumps cannot handle fall back to str()."""
    class Widget:
        def __str__(self):
            return "CustomObject"

    assert safe_serialize(Widget()) == "CustomObject"


def test_safe_serialize_with_none_returns_json_null():
    """None maps to the JSON literal null."""
    result = safe_serialize(None)
    assert result == "null"


def test_safe_serialize_with_list_returns_json_string():
    """Lists are JSON-encoded with json.dumps default spacing."""
    values = [1, 2, 3]
    assert safe_serialize(values) == "[1, 2, 3]"


def test_safe_serialize_with_empty_string():
    """The empty string is still a string, so it comes back verbatim."""
    empty = ""
    assert safe_serialize(empty) == empty
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good test coverage overall! Consider adding a test for the specific edge case this PR is fixing — where isinstance check for PipelineVariable might fail (e.g., import path mismatch). You could mock isinstance or create a mock object whose __str__ raises TypeError to verify the nested except path:

def test_safe_serialize_with_object_whose_str_raises_typeerror():
    """Objects whose __str__ raises TypeError should be returned as-is."""
    class BadStr:
        def __str__(self):
            raise TypeError("cannot convert")
    obj = BadStr()
    result = safe_serialize(obj)
    assert result is obj

This directly tests the new fallback code path added in this PR.

Loading