Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sagemaker-core/src/sagemaker/core/workflow/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def get_code_hash(step: Entity) -> str:
source_dir = source_code.source_dir
requirements = source_code.requirements
entry_point = source_code.entry_script
return get_training_code_hash(entry_point, source_dir, requirements)
return get_training_code_hash(
entry_point, source_dir, requirements
)
return None


Expand Down
176 changes: 172 additions & 4 deletions sagemaker-core/tests/unit/workflow/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
get_processing_dependencies,
get_processing_code_hash,
get_training_code_hash,
get_code_hash,
validate_step_args_input,
override_pipeline_parameter_var,
trim_request_dict,
Expand Down Expand Up @@ -273,10 +274,14 @@ def test_get_training_code_hash_with_source_dir(self):
requirements_file.write_text("numpy==1.21.0")

result_no_deps = get_training_code_hash(
entry_point=str(entry_file), source_dir=temp_dir, dependencies=None
entry_point=str(entry_file),
source_dir=temp_dir,
dependencies=None,
)
result_with_deps = get_training_code_hash(
entry_point=str(entry_file), source_dir=temp_dir, dependencies=str(requirements_file)
entry_point=str(entry_file),
source_dir=temp_dir,
dependencies=str(requirements_file),
)

assert result_no_deps is not None
Expand All @@ -285,6 +290,33 @@ def test_get_training_code_hash_with_source_dir(self):
assert len(result_with_deps) == 64
assert result_no_deps != result_with_deps

def test_get_training_code_hash_source_dir_none_deps(
self,
):
"""Test get_training_code_hash with source_dir
and None dependencies does not raise TypeError.
"""
with tempfile.TemporaryDirectory() as temp_dir:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good: this test directly validates the reported bug scenario (dependencies=None with source_dir). However, consider also asserting that the result matches the expected hash when dependencies=[] (empty list) to confirm None and empty list are treated equivalently, which strengthens the regression test.

entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

# Should NOT raise TypeError
result_none = get_training_code_hash(
entry_point=str(entry_file),
source_dir=temp_dir,
dependencies=None,
)
# Empty list should be equivalent to None
result_empty = get_training_code_hash(
entry_point=str(entry_file),
source_dir=temp_dir,
dependencies=[],
)

assert result_none is not None
assert len(result_none) == 64
assert result_none == result_empty

def test_get_training_code_hash_entry_point_only(self):
"""Test get_training_code_hash with entry_point only"""
with tempfile.TemporaryDirectory() as temp_dir:
Expand All @@ -295,11 +327,15 @@ def test_get_training_code_hash_entry_point_only(self):

# Without dependencies
result_no_deps = get_training_code_hash(
entry_point=str(entry_file), source_dir=None, dependencies=None
entry_point=str(entry_file),
source_dir=None,
dependencies=None,
)
# With dependencies
result_with_deps = get_training_code_hash(
entry_point=str(entry_file), source_dir=None, dependencies=str(requirements_file)
entry_point=str(entry_file),
source_dir=None,
dependencies=str(requirements_file),
)

assert result_no_deps is not None
Expand All @@ -308,6 +344,33 @@ def test_get_training_code_hash_entry_point_only(self):
assert len(result_with_deps) == 64
assert result_no_deps != result_with_deps

def test_get_training_code_hash_entry_point_none_deps(
self,
):
"""Test get_training_code_hash with entry_point
and None dependencies does not raise TypeError.
"""
with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

# Should NOT raise TypeError
result_none = get_training_code_hash(
entry_point=str(entry_file),
source_dir=None,
dependencies=None,
)
# Empty list should be equivalent to None
result_empty = get_training_code_hash(
entry_point=str(entry_file),
source_dir=None,
dependencies=[],
)

assert result_none is not None
assert len(result_none) == 64
assert result_none == result_empty

def test_get_training_code_hash_s3_uri(self):
"""Test get_training_code_hash with S3 URI returns None"""
result = get_training_code_hash(
Expand All @@ -325,6 +388,111 @@ def test_get_training_code_hash_pipeline_variable(self):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: The actual source code fix is missing from this PR. The PR description mentions fixing sagemaker-core/src/sagemaker/core/workflow/utilities.py, but no changes to that file are included in the diff. The tests alone don't fix the bug — the defensive None handling in get_training_code_hash() and get_code_hash() must also be committed. Please include the source code changes.

assert result is None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skipped tests provide zero coverage. These three get_code_hash tests are all decorated with @pytest.mark.skip and will never execute in CI. Skipped tests don't prove the fix works and don't count toward coverage.

Since get_code_hash imports TrainingStep internally, you can mock the isinstance check or patch the import rather than skipping entirely. For example:

with patch('sagemaker.core.workflow.utilities.TrainingStep', new=type('TrainingStep', (), {})):
    ...

Alternatively, restructure the test to call get_training_code_hash directly (which you already do in the non-skipped tests) and add a separate focused test for the None-to-list coercion in get_code_hash.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line length exceeds 100 characters. This @pytest.mark.skip(reason=...) line is ~93 chars for the decorator alone plus the reason string, pushing well past the 100-char limit. Same issue on lines 383 and 406. Break the reason string or use a variable:

_SKIP_REASON = "Requires sagemaker-mlops module which is not installed in sagemaker-core tests"

@pytest.mark.skip(reason=_SKIP_REASON)

def test_get_code_hash_training_step_no_requirements(
self,
):
"""Test get_code_hash with TrainingStep where
SourceCode has requirements=None.
"""
# Create a fake TrainingStep class to patch isinstance
FakeTrainingStep = type(
"TrainingStep", (), {}
)

with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

mock_source_code = Mock()
mock_source_code.source_dir = temp_dir
mock_source_code.requirements = None
mock_source_code.entry_script = str(entry_file)

mock_model_trainer = Mock()
mock_model_trainer.source_code = mock_source_code

mock_step_args = Mock()
mock_step_args.func_args = [
mock_model_trainer
]

mock_step = MagicMock(spec=FakeTrainingStep)
mock_step.step_args = mock_step_args

with patch(
"sagemaker.core.workflow.utilities"
".TrainingStep",
new=FakeTrainingStep,
):
result = get_code_hash(mock_step)

assert result is not None
assert len(result) == 64

def test_get_code_hash_training_step_with_requirements(
self,
):
"""Test get_code_hash with TrainingStep where
SourceCode has valid requirements.
"""
FakeTrainingStep = type(
"TrainingStep", (), {}
)

with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")
req_file = Path(temp_dir, "requirements.txt")
req_file.write_text("numpy==1.21.0")

mock_sc_no_req = Mock()
mock_sc_no_req.source_dir = temp_dir
mock_sc_no_req.requirements = None
mock_sc_no_req.entry_script = str(entry_file)

mock_sc_with_req = Mock()
mock_sc_with_req.source_dir = temp_dir
mock_sc_with_req.requirements = str(req_file)
mock_sc_with_req.entry_script = str(entry_file)

mock_mt_no_req = Mock()
mock_mt_no_req.source_code = mock_sc_no_req

mock_mt_with_req = Mock()
mock_mt_with_req.source_code = mock_sc_with_req

mock_step_no_req = MagicMock(
spec=FakeTrainingStep
)
mock_step_no_req.step_args = Mock()
mock_step_no_req.step_args.func_args = [
mock_mt_no_req
]

mock_step_with_req = MagicMock(
spec=FakeTrainingStep
)
mock_step_with_req.step_args = Mock()
mock_step_with_req.step_args.func_args = [
mock_mt_with_req
]

with patch(
"sagemaker.core.workflow.utilities"
".TrainingStep",
new=FakeTrainingStep,
):
result_no_req = get_code_hash(
mock_step_no_req
)
result_with_req = get_code_hash(
mock_step_with_req
)

assert result_no_req is not None
assert result_with_req is not None
assert result_no_req != result_with_req

def test_validate_step_args_input_valid(self):
"""Test validate_step_args_input with valid input"""
step_args = _StepArguments(
Expand Down
Loading