Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sagemaker-mlops/src/sagemaker/mlops/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
__version__ = "0.1.0"

# Pipeline and configuration
from sagemaker.mlops.workflow.pipeline import Pipeline, PipelineGraph
from sagemaker.mlops.workflow.pipeline import Pipeline, PipelineGraph, PipelineExecution
from sagemaker.mlops.workflow.pipeline_experiment_config import (
PipelineExperimentConfig,
PipelineExperimentConfigProperty,
Expand Down Expand Up @@ -74,6 +74,7 @@
__all__ = [
# Pipeline and configuration
"Pipeline",
"PipelineExecution",
"PipelineGraph",
"PipelineExperimentConfig",
"PipelineExperimentConfigProperty",
Expand Down
23 changes: 18 additions & 5 deletions sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def start(
specified, uses the latest version ID.

Returns:
A `_PipelineExecution` instance, if successful.
A `PipelineExecution` instance, if successful.
"""
if selective_execution_config is not None:
if (
Expand Down Expand Up @@ -438,7 +438,7 @@ def start(
lambda: self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs),
botocore_client_error_code="AccessDeniedException",
)
return _PipelineExecution(
return PipelineExecution(
arn=response["PipelineExecutionArn"],
sagemaker_session=self.sagemaker_session,
)
Expand Down Expand Up @@ -602,7 +602,7 @@ def _get_parameters_for_execution(self, pipeline_execution_arn: str) -> Dict[str
Returns:
A parameter dict from the execution.
"""
pipeline_execution = _PipelineExecution(
pipeline_execution = PipelineExecution(
arn=pipeline_execution_arn,
sagemaker_session=self.sagemaker_session,
)
Expand Down Expand Up @@ -950,8 +950,21 @@ def _generate_step_map(steps: Sequence[Step], step_map: dict):


@attr.s
class _PipelineExecution:
"""Internal class for encapsulating pipeline execution instances.
class PipelineExecution:
"""Encapsulates a pipeline execution instance.

This class can be used to interact with pipeline executions that were
started from any source (Python SDK, Studio UI, console, etc.).

Example::

execution = PipelineExecution(
arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/my-pipeline/execution/abc123",
sagemaker_session=sagemaker_session,
)
execution.describe()
execution.wait()
execution.list_steps()

Attributes:
arn (str): The arn of the pipeline execution.
Expand Down
30 changes: 15 additions & 15 deletions sagemaker-mlops/tests/unit/workflow/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_pipeline_get_latest_execution_arn_none(mock_session, mock_step):


def test_pipeline_build_parameters_from_execution(mock_session, mock_step):
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
from sagemaker.mlops.workflow.pipeline import PipelineExecution
pipeline = Pipeline(name="test-pipeline", steps=[mock_step], sagemaker_session=mock_session)

mock_session.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
Expand Down Expand Up @@ -268,43 +268,43 @@ def test_pipeline_delete_triggers_not_found(mock_session, mock_step):


def test_pipeline_execution_stop(mock_session):
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
from sagemaker.mlops.workflow.pipeline import PipelineExecution

execution = _PipelineExecution(arn="arn", sagemaker_session=mock_session)
execution = PipelineExecution(arn="arn", sagemaker_session=mock_session)
execution.stop()
mock_session.sagemaker_client.stop_pipeline_execution.assert_called_once()


def test_pipeline_execution_describe(mock_session):
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
from sagemaker.mlops.workflow.pipeline import PipelineExecution

execution = _PipelineExecution(arn="arn", sagemaker_session=mock_session)
execution = PipelineExecution(arn="arn", sagemaker_session=mock_session)
execution.describe()
mock_session.sagemaker_client.describe_pipeline_execution.assert_called_once()


def test_pipeline_execution_list_steps(mock_session):
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
from sagemaker.mlops.workflow.pipeline import PipelineExecution

mock_session.sagemaker_client.list_pipeline_execution_steps.return_value = {"PipelineExecutionSteps": []}
execution = _PipelineExecution(arn="arn", sagemaker_session=mock_session)
execution = PipelineExecution(arn="arn", sagemaker_session=mock_session)
result = execution.list_steps()
assert result == []


def test_pipeline_execution_list_parameters(mock_session):
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
from sagemaker.mlops.workflow.pipeline import PipelineExecution

execution = _PipelineExecution(arn="arn", sagemaker_session=mock_session)
execution = PipelineExecution(arn="arn", sagemaker_session=mock_session)
execution.list_parameters(max_results=10, next_token="token")
mock_session.sagemaker_client.list_pipeline_parameters_for_execution.assert_called_once()


def test_pipeline_execution_wait(mock_session):
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
from sagemaker.mlops.workflow.pipeline import PipelineExecution
import botocore.waiter

execution = _PipelineExecution(arn="arn", sagemaker_session=mock_session)
execution = PipelineExecution(arn="arn", sagemaker_session=mock_session)
with patch("botocore.waiter.create_waiter_with_client") as mock_waiter:
mock_waiter.return_value.wait = Mock()
execution.wait(delay=10, max_attempts=5)
Expand Down Expand Up @@ -476,22 +476,22 @@ def test_pipeline_list_versions(mock_session, mock_step):


def test_pipeline_execution_result_waiter_error(mock_session):
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
from sagemaker.mlops.workflow.pipeline import PipelineExecution
from botocore.exceptions import WaiterError

execution = _PipelineExecution(arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test/execution/exec-id", sagemaker_session=mock_session)
execution = PipelineExecution(arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test/execution/exec-id", sagemaker_session=mock_session)

with patch.object(execution, "wait", side_effect=WaiterError("name", "reason", {})):
with pytest.raises(WaiterError):
execution.result("step1")


def test_pipeline_execution_result_terminal_failure(mock_session):
from sagemaker.mlops.workflow.pipeline import _PipelineExecution
from sagemaker.mlops.workflow.pipeline import PipelineExecution
from botocore.exceptions import WaiterError
from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT

execution = _PipelineExecution(arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test/execution/exec-id", sagemaker_session=mock_session)
execution = PipelineExecution(arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test/execution/exec-id", sagemaker_session=mock_session)
mock_session.sagemaker_client.list_pipeline_execution_steps.return_value = {
"PipelineExecutionSteps": [{"StepName": "step1", "Metadata": {"TrainingJob": {"Arn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/job"}}}]
}
Expand Down
Loading