-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) #5689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -168,22 +168,26 @@ def convert_unassigned_to_none(instance) -> Any: | |
| return instance | ||
|
|
||
|
|
||
| def safe_serialize(data): | ||
| def safe_serialize(data) -> "str | PipelineVariable": | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing type annotation for the def safe_serialize(data: Any) -> str | PipelineVariable: |
||
| """Serialize the data without wrapping strings in quotes. | ||
|
|
||
| This function handles the following cases: | ||
| 1. If `data` is a string, it returns the string as-is without wrapping in quotes. | ||
| 2. If `data` is of type `PipelineVariable`, it returns the json representation of the PipelineVariable | ||
| 3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns | ||
| the JSON-encoded string using `json.dumps()`. | ||
| 4. If `data` cannot be serialized (e.g., a custom object), it returns the string | ||
| representation of the data using `str(data)`. | ||
| 1. If ``data`` is a string, it returns the string as-is without wrapping in quotes. | ||
| 2. If ``data`` is of type :class:`~sagemaker.core.workflow.parameters.PipelineVariable`, | ||
| it returns the object directly so that pipeline serialization can handle it | ||
| downstream. Callers should be aware that the return value may be a | ||
| ``PipelineVariable`` rather than a plain ``str``. | ||
| 3. If ``data`` is serializable (e.g., a dictionary, list, int, float), it returns | ||
| the JSON-encoded string using ``json.dumps()``. | ||
| 4. If ``data`` cannot be serialized (e.g., a custom object), it returns the string | ||
| representation of the data using ``str(data)``. | ||
|
|
||
| Args: | ||
| data (Any): The data to serialize. | ||
|
|
||
| Returns: | ||
| str: The serialized JSON-compatible string or the string representation of the input. | ||
| str | PipelineVariable: The serialized JSON-compatible string, the string | ||
| representation of the input, or the original ``PipelineVariable`` object. | ||
| """ | ||
| if isinstance(data, str): | ||
| return data | ||
|
|
@@ -192,7 +196,14 @@ def safe_serialize(data): | |
| try: | ||
| return json.dumps(data) | ||
| except TypeError: | ||
| return str(data) | ||
| try: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The nested try/except is a reasonable defensive measure, but catching a bare except TypeError:
try:
return str(data)
except TypeError:
if isinstance(data, PipelineVariable):
return data
raiseThis way, only known |
||
| return str(data) | ||
| except TypeError: | ||
| # PipelineVariable.__str__ raises TypeError by design. | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # If the isinstance check above didn't catch it (e.g. import | ||
| # path mismatch or reload issues), return the object directly | ||
| # so pipeline serialization can handle it downstream. | ||
| return data | ||
|
|
||
|
|
||
| def _run_clone_command_silent(repo_url, dest_dir): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,7 @@ | |
|
|
||
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar | ||
| from sagemaker.core.workflow.parameters import ParameterString | ||
| from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger | ||
| from sagemaker.train.model_trainer import ModelTrainer, Mode | ||
| from sagemaker.train.configs import ( | ||
| Compute, | ||
|
|
@@ -176,3 +176,54 @@ def test_training_image_rejects_invalid_type(self): | |
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
|
|
||
|
|
||
| class TestModelTrainerHyperparametersPipelineVariable: | ||
| """Test that ModelTrainer correctly preserves PipelineVariable objects in hyperparameters.""" | ||
|
|
||
| def test_hyperparameters_preserves_pipeline_variable_integer(self): | ||
| """ParameterInteger in hyperparameters should be preserved in ModelTrainer.""" | ||
| max_depth = ParameterInteger(name="MaxDepth", default_value=5) | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={"max_depth": max_depth}, | ||
| ) | ||
| assert trainer.hyperparameters["max_depth"] is max_depth | ||
|
|
||
| def test_hyperparameters_preserves_pipeline_variable_string(self): | ||
| """ParameterString in hyperparameters should be preserved in ModelTrainer.""" | ||
| optimizer = ParameterString(name="Optimizer", default_value="sgd") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These tests verify that |
||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={"optimizer": optimizer}, | ||
| ) | ||
| assert trainer.hyperparameters["optimizer"] is optimizer | ||
|
|
||
| def test_hyperparameters_preserves_mixed_pipeline_and_regular_values(self): | ||
| """Mixed PipelineVariable and regular values should all be preserved.""" | ||
| max_depth = ParameterInteger(name="MaxDepth", default_value=5) | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={ | ||
| "max_depth": max_depth, | ||
| "eta": 0.1, | ||
| "objective": "binary:logistic", | ||
| }, | ||
| ) | ||
| # PipelineVariable should be preserved as-is | ||
| assert trainer.hyperparameters["max_depth"] is max_depth | ||
| # Regular values should also be preserved | ||
| assert trainer.hyperparameters["eta"] == 0.1 | ||
| assert trainer.hyperparameters["objective"] == "binary:logistic" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Unit tests for sagemaker.train.utils – specifically safe_serialize.""" | ||
| from __future__ import absolute_import | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Per PEP 484 / SDK conventions, new modules should use |
||
|
|
||
| import json | ||
|
|
||
| from sagemaker.train.utils import safe_serialize | ||
| from sagemaker.core.workflow.parameters import ( | ||
| ParameterInteger, | ||
| ParameterString, | ||
| ParameterFloat, | ||
| ) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # PipelineVariable inputs – should be returned as-is (identity) | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| def test_safe_serialize_with_pipeline_variable_integer_returns_object_directly(): | ||
| param = ParameterInteger(name="MaxDepth", default_value=5) | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
|
|
||
|
|
||
| def test_safe_serialize_with_pipeline_variable_string_returns_object_directly(): | ||
| param = ParameterString(name="Optimizer", default_value="sgd") | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
|
|
||
|
|
||
| def test_safe_serialize_with_pipeline_variable_float_returns_object_directly(): | ||
| param = ParameterFloat(name="LearningRate", default_value=0.01) | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Regular / primitive inputs | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| def test_safe_serialize_with_string_returns_string_as_is(): | ||
| assert safe_serialize("hello") == "hello" | ||
| assert safe_serialize("12345") == "12345" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_json_like_string_returns_as_is(): | ||
| """A string that looks like JSON should be returned as-is, not double-serialized.""" | ||
| json_str = '{"key": "value"}' | ||
| assert safe_serialize(json_str) == json_str | ||
|
|
||
|
|
||
| def test_safe_serialize_with_int_returns_json_string(): | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| assert safe_serialize(5) == "5" | ||
| assert safe_serialize(0) == "0" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_dict_returns_json_string(): | ||
| data = {"key": "value", "num": 1} | ||
| assert safe_serialize(data) == json.dumps(data) | ||
|
|
||
|
|
||
| def test_safe_serialize_with_bool_returns_json_string(): | ||
| assert safe_serialize(True) == "true" | ||
| assert safe_serialize(False) == "false" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_custom_object_returns_str(): | ||
| class CustomObject: | ||
| def __str__(self): | ||
| return "CustomObject" | ||
|
|
||
| obj = CustomObject() | ||
| assert safe_serialize(obj) == "CustomObject" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_none_returns_json_null(): | ||
| assert safe_serialize(None) == "null" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_list_returns_json_string(): | ||
| assert safe_serialize([1, 2, 3]) == "[1, 2, 3]" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_empty_string(): | ||
| assert safe_serialize("") == "" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good test coverage overall! Consider adding a test for the specific edge case this PR is fixing — where def test_safe_serialize_with_object_whose_str_raises_typeerror():
"""Objects whose __str__ raises TypeError should be returned as-is."""
class BadStr:
def __str__(self):
raise TypeError("cannot convert")
obj = BadStr()
result = safe_serialize(obj)
assert result is objThis directly tests the new fallback code path added in this PR. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type
"str | PipelineVariable"is a forward reference string, but this module should usefrom __future__ import annotationsat the top to enable PEP 604 union syntax properly. Also, this is a behavioral change to the function's contract — previously it always returnedstr, now it can returnPipelineVariable. This could causeTypeErrorin downstream callers that expect astr(e.g., calling.encode(), concatenation, etc.). Have you audited all call sites ofsafe_serializeto ensure they handle aPipelineVariablereturn value correctly?Suggestion: Add
from __future__ import annotationsat the module top, and change the annotation to:Also consider using
Union[str, PipelineVariable]with an explicit import iffrom __future__ import annotationsis not already used in this module.