
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504)#5689

Draft
aviruthen wants to merge 2 commits into aws:master from aviruthen:fix/bug-pipeline-parameters-parameterinteger-5504

@aviruthen
Collaborator

Description

The safe_serialize function in sagemaker-train/src/sagemaker/train/utils.py already has a PipelineVariable isinstance check (lines 185-186), but the user's SDK version (3.3.1) may not include it. In addition, the fallback except TypeError: return str(data) block is unsafe: PipelineVariable.__str__() intentionally raises TypeError, so if the isinstance check ever fails to match (e.g., import path mismatch, reload issues), str(data) raises a second TypeError from inside the except block. The fix needs to: (1) ensure the PipelineVariable isinstance check is solid, (2) make the except fallback more robust by handling the case where str() also raises TypeError, and (3) add the missing unit tests for safe_serialize in the train utils module, covering PipelineVariable inputs. There are currently no tests for the safe_serialize in sagemaker-train/src/sagemaker/train/utils.py.
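
The failure mode described above can be reproduced without the SDK. The following is a minimal sketch, not the SDK's code: FakePipelineVariable is a hypothetical stand-in whose __str__ raises TypeError, and naive_safe_serialize is an assumed shape of the json.dumps / str(data) fallback with no isinstance guard.

```python
import json


class FakePipelineVariable:
    """Hypothetical stand-in: like PipelineVariable, __str__ raises TypeError."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        raise TypeError("Pipeline variables do not support __str__")


def naive_safe_serialize(data):
    """Assumed shape of the fallback path, with no isinstance guard."""
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except TypeError:
        # json.dumps cannot encode the object, so control lands here;
        # str(data) then raises a second TypeError that propagates.
        return str(data)


def raises_type_error(value):
    """Return True if naive_safe_serialize(value) raises TypeError."""
    try:
        naive_safe_serialize(value)
    except TypeError:
        return True
    return False
```

In this sketch, raises_type_error(FakePipelineVariable("MaxDepth")) is True, while plain values such as 3 or {"a": 1} serialize normally.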

Related Issue

Related issue: 5504

Changes Made

  • sagemaker-train/src/sagemaker/train/utils.py
  • sagemaker-train/tests/unit/train/test_utils.py
  • sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py

AI-Generated PR

This PR was automatically generated by the PySDK Issue Agent.

  • Confidence score: 85%
  • Classification: bug
  • SDK version target: V3

Merge Checklist

  • Changes are backward compatible
  • Commit message follows prefix: description format
  • Unit tests added/updated
  • Integration tests added (if applicable)
  • Documentation updated (if applicable)

Member

@mufaddal-rohawala left a comment

🤖 AI Code Review

This PR fixes a bug where PipelineVariable objects (ParameterInteger, ParameterString) fail in safe_serialize because their str() raises TypeError. The fallback approach using hasattr(data, 'expr') is fragile because it relies on duck typing; a more robust solution would be to check for PipelineVariable explicitly. The tests are a good addition but have some issues, including unused variables and redundant test logic.

    except TypeError:
        # PipelineVariable.__str__ raises TypeError by design.
        # If the isinstance check above didn't catch it (e.g. import
        # path mismatch), fall back to returning the object directly
Member

Using hasattr(data, 'expr') as a duck-typing check for PipelineVariable is fragile — any object with an expr attribute would match this condition. Since there's already a PipelineVariable isinstance check earlier in the function (lines 185-186), a more robust approach would be to repeat that isinstance check here, or better yet, import and check against the specific parameter types. This would also be more maintainable.

Alternatively, consider catching TypeError more broadly:

except TypeError:
    try:
        return str(data)
    except TypeError:
        # PipelineVariable.__str__ raises TypeError by design.
        # Return the object directly so pipeline serialization can handle it.
        return data

Returning data unconditionally in the inner except is safer than a duck-type check that could silently pass through unrelated objects. If the object truly can't be serialized or stringified, it's likely a PipelineVariable that should be passed through. If you want to be defensive, re-use the PipelineVariable isinstance check instead of hasattr.
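
The difference between the two checks discussed above can be sketched as follows. Both classes are hypothetical: FakePipelineVariable stands in for the SDK's PipelineVariable, the expr payload is illustrative only, and UnrelatedQuery is an invented example of an object that merely happens to have an expr attribute.

```python
class FakePipelineVariable:
    """Hypothetical stand-in for the SDK's PipelineVariable base class."""

    @property
    def expr(self):
        # Illustrative payload only; not the SDK's actual expression format.
        return {"Get": "Parameters.MaxDepth"}


class UnrelatedQuery:
    """An unrelated object that happens to expose an `expr` attribute."""

    expr = "SELECT 1"


def looks_like_pipeline_variable(data):
    """Duck-typed check: matches anything with an `expr` attribute."""
    return hasattr(data, "expr")


def is_pipeline_variable(data):
    """Explicit check: matches only the known type and its subclasses."""
    return isinstance(data, FakePipelineVariable)
```

Here looks_like_pipeline_variable(UnrelatedQuery()) is True even though the object has nothing to do with pipelines, whereas is_pipeline_variable(UnrelatedQuery()) is False.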

output_data_config=DEFAULT_OUTPUT,
hyperparameters={"max_depth": max_depth},
)
# safe_serialize should return the PipelineVariable object directly
Member

Imports should be at the top of the file, not inside test functions. Move from sagemaker.train.utils import safe_serialize to the module-level imports. This applies to all three new test methods.



class TestModelTrainerHyperparametersPipelineVariable:
    """Test that PipelineVariable objects in hyperparameters survive safe_serialize."""
Member

These tests in test_model_trainer_pipeline_variable.py are essentially duplicates of the tests already added in test_utils.py. The safe_serialize unit tests in test_utils.py are sufficient for testing the serialization behavior. These tests should instead verify the integration — that ModelTrainer correctly preserves PipelineVariable objects in its hyperparameters dict after construction, rather than just re-testing safe_serialize.

 from sagemaker.core.helper.session_helper import Session
 from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
-from sagemaker.core.workflow.parameters import ParameterString
+from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat
Member

ParameterFloat and PipelineSession are imported but never used in the new tests. Remove the unused imports.

Member

@mufaddal-rohawala left a comment

🤖 AI Code Review

This PR fixes a bug where PipelineVariable objects (ParameterInteger, ParameterString) fail in safe_serialize due to their str() raising TypeError. The fix adds a nested try/except in the fallback path and includes good test coverage. However, there are several issues: the return type annotation is problematic, the function signature change could cause downstream issues, and there's a missing from __future__ import annotations import.



-def safe_serialize(data):
+def safe_serialize(data) -> "str | PipelineVariable":
Member

The return type "str | PipelineVariable" is written as a string annotation, but this module should use from __future__ import annotations at the top to enable PEP 604 union syntax properly. Also, this is a behavioral change to the function's contract: previously it always returned str, and now it can return PipelineVariable. This could break downstream callers that expect a str (e.g., AttributeError when calling .encode(), TypeError on string concatenation). Have you audited all call sites of safe_serialize to ensure they handle a PipelineVariable return value correctly?

Suggestion: Add from __future__ import annotations at the module top, and change the annotation to:

def safe_serialize(data) -> str | PipelineVariable:

Also consider using Union[str, PipelineVariable] with an explicit import if from __future__ import annotations is not already used in this module.
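
The contract concern raised above can be illustrated with a short sketch. Everything here is hypothetical: FakePipelineVariable stands in for the SDK's PipelineVariable, safe_serialize_sketch is a simplified version of the widened contract using the Union spelling, and to_bytes is an invented caller that still assumes a str return value.

```python
import json
from typing import Any, Union


class FakePipelineVariable:
    """Hypothetical stand-in for the SDK's PipelineVariable."""


def safe_serialize_sketch(data: Any) -> Union[str, FakePipelineVariable]:
    """Simplified widened contract: str for plain values, pass-through otherwise."""
    if isinstance(data, (str, FakePipelineVariable)):
        return data
    return json.dumps(data)


def to_bytes(value: Any) -> bytes:
    """Hypothetical caller that still assumes the old str-only contract."""
    return safe_serialize_sketch(value).encode("utf-8")


def fails_on_pipeline_variable() -> bool:
    """Return True if the str-only caller breaks on a pass-through value."""
    try:
        to_bytes(FakePipelineVariable())
    except AttributeError:
        # The stand-in object has no .encode() method.
        return True
    return False
```

In this sketch the old-style caller works for plain values but raises AttributeError once a pipeline variable is passed through, which is the kind of call site an audit would need to find.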

         return json.dumps(data)
     except TypeError:
-        return str(data)
+        try:
Member

The nested try/except is a reasonable defensive measure, but catching a bare TypeError from str(data) and silently returning the raw object is risky. This means any object whose __str__ raises TypeError (not just PipelineVariable) will be returned as-is, potentially causing unexpected behavior downstream. Consider being more explicit:

except TypeError:
    try:
        return str(data)
    except TypeError:
        if isinstance(data, PipelineVariable):
            return data
        raise

This way, only known PipelineVariable objects get the pass-through treatment, and truly broken objects still raise.
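
A runnable sketch of the guarded fallback suggested above, under stated assumptions: FakePipelineVariable and BadStr are hypothetical classes, and the early isinstance check for pipeline variables is deliberately left out so that the nested except path is the one being exercised.

```python
import json


class FakePipelineVariable:
    """Hypothetical stand-in: __str__ raises TypeError by design."""

    def __str__(self):
        raise TypeError("Pipeline variables do not support __str__")


class BadStr:
    """Unrelated object whose __str__ is simply broken."""

    def __str__(self):
        raise TypeError("cannot convert")


def safe_serialize_sketch(data):
    """Fallback that passes through only known pipeline variables."""
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except TypeError:
        try:
            return str(data)
        except TypeError:
            if isinstance(data, FakePipelineVariable):
                return data
            raise  # unknown objects still surface the error


def reraises_for(value):
    """Return True if safe_serialize_sketch(value) raises TypeError."""
    try:
        safe_serialize_sketch(value)
    except TypeError:
        return True
    return False
```

Under this variant a FakePipelineVariable is returned unchanged, while BadStr still raises. A test that expects any object with a raising __str__ to be returned as-is would therefore not pass against this version, so the two behaviors need to be reconciled before choosing which test to keep.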



-def safe_serialize(data):
+def safe_serialize(data) -> "str | PipelineVariable":
Member

Missing type annotation for the data parameter. Per SDK conventions, all public functions must have type annotations for parameters and return types.

def safe_serialize(data: Any) -> str | PipelineVariable:

# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for sagemaker.train.utils – specifically safe_serialize."""
from __future__ import absolute_import
Member

Per SDK conventions, new modules should use from __future__ import annotations (introduced by PEP 563) instead of from __future__ import absolute_import, which is a Python 2 artifact and has no effect in Python 3.



def test_safe_serialize_with_empty_string():
    assert safe_serialize("") == ""
Member

Good test coverage overall! Consider adding a test for the specific edge case this PR is fixing — where isinstance check for PipelineVariable might fail (e.g., import path mismatch). You could mock isinstance or create a mock object whose __str__ raises TypeError to verify the nested except path:

def test_safe_serialize_with_object_whose_str_raises_typeerror():
    """Objects whose __str__ raises TypeError should be returned as-is."""
    class BadStr:
        def __str__(self):
            raise TypeError("cannot convert")
    obj = BadStr()
    result = safe_serialize(obj)
    assert result is obj

This directly tests the new fallback code path added in this PR.


    def test_hyperparameters_preserves_pipeline_variable_string(self):
        """ParameterString in hyperparameters should be preserved in ModelTrainer."""
        optimizer = ParameterString(name="Optimizer", default_value="sgd")
Member

These tests verify that ModelTrainer preserves PipelineVariable objects in the hyperparameters dict, but they don't test the actual serialization path (i.e., when safe_serialize is called during job creation). Consider adding a test that mocks the training job creation to verify that PipelineVariable values survive the full serialization pipeline, not just assignment.
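
One way to picture the serialization-path check requested above is the sketch below. It does not use ModelTrainer or any SDK API: FakePipelineVariable is a hypothetical stand-in, safe_serialize_sketch is a simplified version of the function under review, and serialize_hyperparameters is an invented helper that assumes job creation applies safe_serialize to each hyperparameter value.

```python
import json


class FakePipelineVariable:
    """Hypothetical stand-in: __str__ raises TypeError by design."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        raise TypeError("Pipeline variables do not support __str__")


def safe_serialize_sketch(data):
    """Simplified version of the function under review."""
    if isinstance(data, str):
        return data
    if isinstance(data, FakePipelineVariable):
        return data
    try:
        return json.dumps(data)
    except TypeError:
        return str(data)


def serialize_hyperparameters(hyperparameters):
    """Hypothetical job-creation step: serialize every hyperparameter value."""
    return {key: safe_serialize_sketch(value) for key, value in hyperparameters.items()}


max_depth = FakePipelineVariable("MaxDepth")
serialized = serialize_hyperparameters(
    {"max_depth": max_depth, "eta": 0.1, "objective": "reg:squarederror"}
)
```

After this runs, serialized["max_depth"] is the same object as max_depth, and the plain values are strings. A real test would replace the invented helper with a mocked training job creation call and assert on the arguments it receives.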
