diff --git a/sagemaker-mlops/src/sagemaker/mlops/workflow/steps.py b/sagemaker-mlops/src/sagemaker/mlops/workflow/steps.py index 5e0eb3dda3..76e90a5309 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/workflow/steps.py +++ b/sagemaker-mlops/src/sagemaker/mlops/workflow/steps.py @@ -205,7 +205,7 @@ def _find_dependencies_in_step_arguments( else: dependencies.add(self._get_step_name_from_str(referenced_step, step_map)) - from sagemaker.core.workflow.function_step import DelayedReturn + from sagemaker.mlops.workflow.function_step import DelayedReturn # TODO: we can remove the if-elif once move the validators to JsonGet constructor if isinstance(pipeline_variable, JsonGet): diff --git a/sagemaker-mlops/tests/unit/workflow/test_steps.py b/sagemaker-mlops/tests/unit/workflow/test_steps.py index 06cc729ceb..d4108abcff 100644 --- a/sagemaker-mlops/tests/unit/workflow/test_steps.py +++ b/sagemaker-mlops/tests/unit/workflow/test_steps.py @@ -252,7 +252,7 @@ def test_step_find_dependencies_in_depends_on_list_with_string(): def test_step_validate_json_get_property_file_reference_invalid_step_type(): from sagemaker.mlops.workflow.steps import Step, StepTypeEnum - from sagemaker.core.workflow.functions import JsonGet + step = Mock(spec=Step) step.name = "current-step" @@ -383,7 +383,45 @@ def test_step_validate_json_get_function_with_property_file(): step_map = {"processing-step": processing_step} - Step._validate_json_get_function(step, json_get, step_map) + Step._validate_json_get_function(step, json_get, step_map) + + +def test_step_find_dependencies_in_step_arguments_with_json_get(): + from sagemaker.mlops.workflow.steps import Step, StepTypeEnum + from sagemaker.core.workflow.functions import JsonGet + + + + + + + + + + + + + from sagemaker.mlops.workflow.steps import Step, StepTypeEnum + + from sagemaker.core.workflow.functions import JsonGet + + + + + + + + + + + + + + + + + + def test_step_find_dependencies_in_step_arguments_with_json_get(): @@ -408,74 +446,61 @@ def test_step_find_dependencies_in_step_arguments_with_json_get(): obj = {"key": json_get} - with patch('sagemaker.mlops.workflow.steps.TYPE_CHECKING', False): - with patch.dict('sys.modules', {'sagemaker.core.workflow.function_step': Mock()}): - dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1}) - assert "step1" in dependencies + dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1}) + assert "step1" in dependencies def test_step_find_dependencies_in_step_arguments_with_delayed_return(): - from unittest.mock import patch from sagemaker.mlops.workflow.steps import Step, StepTypeEnum + from sagemaker.mlops.workflow.function_step import DelayedReturn from sagemaker.core.workflow.functions import JsonGet - from sagemaker.core.helper.pipeline_variable import PipelineVariable - + step1 = Mock(spec=Step) step1.name = "step1" step1.step_type = StepTypeEnum.PROCESSING step1.property_files = [] step1.arguments = {} - + json_get = Mock(spec=JsonGet) json_get.property_file = None - - delayed_return_class = type('DelayedReturn', (PipelineVariable,), {}) - delayed_return = Mock(spec=delayed_return_class) + + delayed_return = Mock(spec=DelayedReturn) delayed_return._referenced_steps = [step1] delayed_return._to_json_get = Mock(return_value=json_get) - delayed_return.__class__ = delayed_return_class - + step2 = Mock(spec=Step) step2.name = "step2" step2._validate_json_get_function = Mock() step2._get_step_name_from_str = Step._get_step_name_from_str - + obj = {"key": delayed_return} - - mock_module = Mock() - mock_module.DelayedReturn = delayed_return_class - - with patch.dict('sys.modules', {'sagemaker.core.workflow.function_step': mock_module}): - dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1}) - assert "step1" in dependencies + + dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1}) + assert "step1" in dependencies + + def test_step_find_dependencies_in_step_arguments_with_string_reference(): - from unittest.mock import patch from sagemaker.mlops.workflow.steps import Step from sagemaker.core.helper.pipeline_variable import PipelineVariable - + step1 = Mock(spec=Step) step1.name = "step1" - + pipeline_var = Mock(spec=PipelineVariable) pipeline_var._referenced_steps = ["step1"] - + step2 = Mock(spec=Step) step2.name = "step2" step2._get_step_name_from_str = Step._get_step_name_from_str - + obj = {"key": pipeline_var} - + step_map = {"step1": step1} - - delayed_return_class = type('DelayedReturn', (PipelineVariable,), {}) - mock_module = Mock() - mock_module.DelayedReturn = delayed_return_class - - with patch.dict('sys.modules', {'sagemaker.core.workflow.function_step': mock_module}): - dependencies = Step._find_dependencies_in_step_arguments(step2, obj, step_map) - assert "step1" in dependencies + + dependencies = Step._find_dependencies_in_step_arguments(step2, obj, step_map) + assert "step1" in dependencies def test_tuning_step_requires_step_args():