From 0f4ccb2d5813ab02a0228aa62f8c91f8beebc837 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 27 Mar 2026 16:01:35 -0400 Subject: [PATCH 1/2] fix: Enhancements Needed for Secure Tar Extraction (5560) --- .../src/sagemaker/core/common_utils.py | 6 +- .../tests/unit/test_common_utils.py | 230 ++++++++++++++++++ 2 files changed, 233 insertions(+), 3 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index b8d9ca6866..037ef6a3b9 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -647,7 +647,7 @@ def _validate_source_directory(source_directory): # Check if the source path is under any sensitive directory for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if abs_source != "/" and abs_source.startswith(sensitive_path): + if abs_source != "/" and os.path.commonpath([abs_source, sensitive_path]) == sensitive_path: raise ValueError( f"source_directory cannot access sensitive system paths. " f"Got: {source_directory} (resolved to {abs_source})" @@ -673,7 +673,7 @@ def _validate_dependency_path(dependency): # Check if the dependency path is under any sensitive directory for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if abs_dependency != "/" and abs_dependency.startswith(sensitive_path): + if abs_dependency != "/" and os.path.commonpath([abs_dependency, sensitive_path]) == sensitive_path: raise ValueError( f"dependency path cannot access sensitive system paths. " f"Got: {dependency} (resolved to {abs_dependency})" @@ -689,7 +689,7 @@ def _create_or_update_code_dir( # Validate that code_dir does not resolve to a sensitive system path for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path): + if resolved_code_dir != "/" and os.path.commonpath([resolved_code_dir, sensitive_path]) == sensitive_path: raise ValueError( f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}" ) diff --git a/sagemaker-core/tests/unit/test_common_utils.py b/sagemaker-core/tests/unit/test_common_utils.py index 8aeb496922..a24a572f34 100644 --- a/sagemaker-core/tests/unit/test_common_utils.py +++ b/sagemaker-core/tests/unit/test_common_utils.py @@ -1139,6 +1139,236 @@ def test_custom_extractall_tarfile_basic(self, tmp_path): assert (extract_path / "file.txt").exists() + def test_custom_extractall_tarfile_without_data_filter(self, tmp_path): + """Test custom_extractall_tarfile uses safe members with extract_path as base when data_filter unavailable.""" + from sagemaker.core.common_utils import custom_extractall_tarfile + + # Create tar file + source = tmp_path / "source" + source.mkdir() + (source / "file.txt").write_text("content") + + tar_path = tmp_path / "test.tar.gz" + with tarfile.open(tar_path, "w:gz") as tar: + tar.add(source / "file.txt", arcname="file.txt") + + extract_path = tmp_path / "extract" + extract_path.mkdir() + + with patch('sagemaker.core.common_utils.tarfile') as mock_tarfile_module: + # Remove data_filter to force fallback path + if hasattr(mock_tarfile_module, 'data_filter'): + delattr(mock_tarfile_module, 'data_filter') + + with tarfile.open(tar_path, "r:gz") as tar: + with patch('sagemaker.core.common_utils._get_safe_members') as mock_safe: + mock_safe.return_value = tar.getmembers() + custom_extractall_tarfile(tar, str(extract_path)) + # Verify _get_safe_members was called with members list and base path + mock_safe.assert_called_once() + call_args = mock_safe.call_args + # First arg should be a list of TarInfo members + assert isinstance(call_args[0][0], list) + # Second arg should be the resolved extract path (not CWD) + from sagemaker.core.common_utils import _get_resolved_path + expected_base = _get_resolved_path(str(extract_path)) + assert call_args[0][1] == expected_base + + +class TestIsBadPath: + """Test _is_bad_path function for secure tar extraction.""" + + def test_is_bad_path_safe_relative(self): + """Test _is_bad_path returns False for safe relative paths.""" + from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path + + base = _get_resolved_path("/tmp/safe") + assert _is_bad_path("safe/file.txt", base) is False + + def test_is_bad_path_actual_escape(self): + """Test _is_bad_path returns True for paths escaping base.""" + from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path + + base = _get_resolved_path("/tmp/safe") + assert _is_bad_path("/etc/passwd", base) is True + + def test_is_bad_path_traversal(self): + """Test _is_bad_path detects parent directory traversal.""" + from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path + + base = _get_resolved_path("/tmp/safe") + assert _is_bad_path("../../etc/passwd", base) is True + + def test_is_bad_path_prefix_collision(self): + """Test _is_bad_path does NOT flag /tmp/safe2 when base is /tmp/safe. + + This is the key test for the startswith() bug fix - /tmp/safe2 starts with + /tmp/safe but is NOT actually under /tmp/safe. + """ + from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path + + base = _get_resolved_path("/tmp/safe") + # /tmp/safe2 is NOT under /tmp/safe, but startswith would incorrectly say it is + # With the fix using commonpath, this should correctly identify it as outside base + assert _is_bad_path("/tmp/safe2/file.txt", base) is True + + +class TestGetSafeMembers: + """Test _get_safe_members function for secure tar extraction.""" + + def test_get_safe_members_accepts_member_list_and_base(self): + """Test _get_safe_members works with a list of TarInfo mocks and a base path.""" + from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path + + base = _get_resolved_path("/tmp/extract") + + mock_member = Mock() + mock_member.name = "safe/file.txt" + mock_member.issym = Mock(return_value=False) + mock_member.islnk = Mock(return_value=False) + + members = [mock_member] + safe = list(_get_safe_members(members, base)) + assert len(safe) == 1 + assert mock_member in safe + + def test_get_safe_members_filters_bad_paths(self): + """Test _get_safe_members filters out members with bad paths.""" + from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path + + base = _get_resolved_path("/tmp/extract") + + mock_safe = Mock() + mock_safe.name = "safe/file.txt" + mock_safe.issym = Mock(return_value=False) + mock_safe.islnk = Mock(return_value=False) + + mock_bad = Mock() + mock_bad.name = "/etc/passwd" + mock_bad.issym = Mock(return_value=False) + mock_bad.islnk = Mock(return_value=False) + + with patch('sagemaker.core.common_utils._is_bad_path') as mock_is_bad: + mock_is_bad.side_effect = lambda name, base: name == "/etc/passwd" + safe = list(_get_safe_members([mock_safe, mock_bad], base)) + assert len(safe) == 1 + assert mock_safe in safe + + def test_get_safe_members_filters_bad_symlinks(self): + """Test _get_safe_members filters out bad symlinks.""" + from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path + + base = _get_resolved_path("/tmp/extract") + + mock_safe = Mock() + mock_safe.name = "safe/file.txt" + mock_safe.issym = Mock(return_value=False) + mock_safe.islnk = Mock(return_value=False) + + mock_symlink = Mock() + mock_symlink.name = "bad/symlink" + mock_symlink.issym = Mock(return_value=True) + mock_symlink.islnk = Mock(return_value=False) + mock_symlink.linkname = "/etc/passwd" + + with patch('sagemaker.core.common_utils._is_bad_path', return_value=False): + with patch('sagemaker.core.common_utils._is_bad_link') as mock_is_bad_link: + mock_is_bad_link.return_value = True + safe = list(_get_safe_members([mock_safe, mock_symlink], base)) + assert len(safe) == 1 + assert mock_safe in safe + + def test_get_safe_members_filters_bad_hardlinks(self): + """Test _get_safe_members filters out bad hardlinks.""" + from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path + + base = _get_resolved_path("/tmp/extract") + + mock_safe = Mock() + mock_safe.name = "safe/file.txt" + mock_safe.issym = Mock(return_value=False) + mock_safe.islnk = Mock(return_value=False) + + mock_hardlink = Mock() + mock_hardlink.name = "bad/hardlink" + mock_hardlink.issym = Mock(return_value=False) + mock_hardlink.islnk = Mock(return_value=True) + mock_hardlink.linkname = "/etc/passwd" + + with patch('sagemaker.core.common_utils._is_bad_path', return_value=False): + with patch('sagemaker.core.common_utils._is_bad_link') as mock_is_bad_link: + mock_is_bad_link.return_value = True + safe = list(_get_safe_members([mock_safe, mock_hardlink], base)) + assert len(safe) == 1 + assert mock_safe in safe + + +class TestValidateSourceDirectorySecurity: + """Test _validate_source_directory prefix collision fix.""" + + def test_validate_source_directory_blocks_sensitive_path(self): + """Test that actual sensitive paths are blocked.""" + from sagemaker.core.common_utils import _validate_source_directory + + with pytest.raises(ValueError, match="sensitive system paths"): + _validate_source_directory("/etc/secrets") + + def test_validate_source_directory_prefix_collision(self): + """Test that /etcetera is NOT blocked when /etc is in sensitive paths. + + This tests the fix for the startswith() prefix collision vulnerability. + """ + from sagemaker.core.common_utils import _validate_source_directory + + # /etcetera should NOT be blocked - it's not under /etc + # With the old startswith() check, this would incorrectly raise ValueError + try: + _validate_source_directory("/etcetera") + except ValueError: + pytest.fail("_validate_source_directory incorrectly blocked /etcetera due to prefix collision with /etc") + + def test_validate_source_directory_s3_path(self): + """Test that S3 paths are allowed.""" + from sagemaker.core.common_utils import _validate_source_directory + + _validate_source_directory("s3://my-bucket/my-prefix") + + def test_validate_source_directory_none(self): + """Test that None is allowed.""" + from sagemaker.core.common_utils import _validate_source_directory + + _validate_source_directory(None) + + +class TestValidateDependencyPathSecurity: + """Test _validate_dependency_path prefix collision fix.""" + + def test_validate_dependency_path_blocks_sensitive_path(self): + """Test that actual sensitive paths are blocked.""" + from sagemaker.core.common_utils import _validate_dependency_path + + with pytest.raises(ValueError, match="sensitive system paths"): + _validate_dependency_path("/root/.bashrc") + + def test_validate_dependency_path_prefix_collision(self): + """Test that /rootkit is NOT blocked when /root is in sensitive paths. + + This tests the fix for the startswith() prefix collision vulnerability. + """ + from sagemaker.core.common_utils import _validate_dependency_path + + # /rootkit should NOT be blocked - it's not under /root + try: + _validate_dependency_path("/rootkit") + except ValueError: + pytest.fail("_validate_dependency_path incorrectly blocked /rootkit due to prefix collision with /root") + + def test_validate_dependency_path_none(self): + """Test that None is allowed.""" + from sagemaker.core.common_utils import _validate_dependency_path + + _validate_dependency_path(None) + class TestCanModelPackageSourceUriAutopopulate: """Test can_model_package_source_uri_autopopulate function.""" From e13e8176293b14c37e8ff0f24ba7294f13cfe744 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 27 Mar 2026 16:07:07 -0400 Subject: [PATCH 2/2] fix: address review comments (iteration #1) --- .../src/sagemaker/core/common_utils.py | 33 ++- .../tests/unit/test_common_utils.py | 250 +++++++++--------- 2 files changed, 145 insertions(+), 138 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index 037ef6a3b9..ec6571cb57 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -647,7 +647,10 @@ def _validate_source_directory(source_directory): # Check if the source path is under any sensitive directory for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if abs_source != "/" and os.path.commonpath([abs_source, sensitive_path]) == sensitive_path: + if abs_source != "/" and ( + os.path.commonpath([abs_source, sensitive_path]) + == sensitive_path + ): raise ValueError( f"source_directory cannot access sensitive system paths. " f"Got: {source_directory} (resolved to {abs_source})" @@ -673,7 +676,10 @@ def _validate_dependency_path(dependency): # Check if the dependency path is under any sensitive directory for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if abs_dependency != "/" and os.path.commonpath([abs_dependency, sensitive_path]) == sensitive_path: + if abs_dependency != "/" and ( + os.path.commonpath([abs_dependency, sensitive_path]) + == sensitive_path + ): raise ValueError( f"dependency path cannot access sensitive system paths. " f"Got: {dependency} (resolved to {abs_dependency})" @@ -686,10 +692,13 @@ def _create_or_update_code_dir( """Placeholder docstring""" code_dir = os.path.join(model_dir, "code") resolved_code_dir = _get_resolved_path(code_dir) - + # Validate that code_dir does not resolve to a sensitive system path for sensitive_path in _SENSITIVE_SYSTEM_PATHS: - if resolved_code_dir != "/" and os.path.commonpath([resolved_code_dir, sensitive_path]) == sensitive_path: + if resolved_code_dir != "/" and ( + os.path.commonpath([resolved_code_dir, sensitive_path]) + == sensitive_path + ): raise ValueError( f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}" ) @@ -1688,7 +1697,8 @@ def _is_bad_path(path, base): bool: True if the path is not rooted under the base directory, False otherwise. """ # joinpath will ignore base if path is absolute - return not _get_resolved_path(joinpath(base, path)).startswith(base) + resolved = _get_resolved_path(joinpath(base, path)) + return os.path.commonpath([resolved, base]) != base def _is_bad_link(info, base): @@ -1708,19 +1718,18 @@ def _is_bad_link(info, base): return _is_bad_path(info.linkname, base=tip) -def _get_safe_members(members): +def _get_safe_members(members, base): """A generator that yields members that are safe to extract. It filters out bad paths and bad links. Args: - members (list): A list of members to check. + members (list): A list of TarInfo members to check. + base (str): The resolved base directory for extraction. Yields: tarfile.TarInfo: The tar file info. """ - base = _get_resolved_path("") - for file_info in members: if _is_bad_path(file_info.name, base): logger.error("%s is blocked (illegal path)", file_info.name) @@ -1783,7 +1792,11 @@ def custom_extractall_tarfile(tar, extract_path): if hasattr(tarfile, "data_filter"): tar.extractall(path=extract_path, filter="data") else: - tar.extractall(path=extract_path, members=_get_safe_members(tar)) + base = _get_resolved_path(extract_path) + tar.extractall( + path=extract_path, + members=_get_safe_members(tar.getmembers(), base), + ) # Re-validate extracted paths to catch symlink race conditions _validate_extracted_paths(extract_path) diff --git a/sagemaker-core/tests/unit/test_common_utils.py b/sagemaker-core/tests/unit/test_common_utils.py index a24a572f34..c6886cd9e6 100644 --- a/sagemaker-core/tests/unit/test_common_utils.py +++ b/sagemaker-core/tests/unit/test_common_utils.py @@ -49,6 +49,12 @@ get_module, resolve_value_from_config, get_sagemaker_config_value, + _get_resolved_path, + _is_bad_path, + _get_safe_members, + _validate_source_directory, + _validate_dependency_path, + custom_extractall_tarfile, ) @@ -1119,8 +1125,6 @@ class TestCustomExtractallTarfile: def test_custom_extractall_tarfile_basic(self, tmp_path): """Test basic tar extraction.""" - from sagemaker.core.common_utils import custom_extractall_tarfile - # Create tar file source = tmp_path / "source" source.mkdir() @@ -1140,9 +1144,7 @@ def test_custom_extractall_tarfile_basic(self, tmp_path): assert (extract_path / "file.txt").exists() def test_custom_extractall_tarfile_without_data_filter(self, tmp_path): - """Test custom_extractall_tarfile uses safe members with extract_path as base when data_filter unavailable.""" - from sagemaker.core.common_utils import custom_extractall_tarfile - + """Test fallback path passes correct args to _get_safe_members.""" # Create tar file source = tmp_path / "source" source.mkdir() @@ -1155,61 +1157,67 @@ def test_custom_extractall_tarfile_without_data_filter(self, tmp_path): extract_path = tmp_path / "extract" extract_path.mkdir() - with patch('sagemaker.core.common_utils.tarfile') as mock_tarfile_module: - # Remove data_filter to force fallback path - if hasattr(mock_tarfile_module, 'data_filter'): - delattr(mock_tarfile_module, 'data_filter') - - with tarfile.open(tar_path, "r:gz") as tar: - with patch('sagemaker.core.common_utils._get_safe_members') as mock_safe: - mock_safe.return_value = tar.getmembers() - custom_extractall_tarfile(tar, str(extract_path)) - # Verify _get_safe_members was called with members list and base path - mock_safe.assert_called_once() - call_args = mock_safe.call_args - # First arg should be a list of TarInfo members - assert isinstance(call_args[0][0], list) - # Second arg should be the resolved extract path (not CWD) - from sagemaker.core.common_utils import _get_resolved_path - expected_base = _get_resolved_path(str(extract_path)) - assert call_args[0][1] == expected_base + with tarfile.open(tar_path, "r:gz") as tar: + with patch( + 'sagemaker.core.common_utils._get_safe_members' + ) as mock_safe: + mock_safe.return_value = tar.getmembers() + # Temporarily remove data_filter to force fallback + saved = getattr(tarfile, 'data_filter', None) + try: + if hasattr(tarfile, 'data_filter'): + delattr(tarfile, 'data_filter') + custom_extractall_tarfile( + tar, str(extract_path) + ) + finally: + if saved is not None: + tarfile.data_filter = saved + + mock_safe.assert_called_once() + call_args = mock_safe.call_args + # First arg: list of TarInfo members + assert isinstance(call_args[0][0], list) + assert all( + isinstance(m, tarfile.TarInfo) + for m in call_args[0][0] + ) + # Second arg: resolved extract path (not CWD) + expected_base = _get_resolved_path( + str(extract_path) + ) + assert call_args[0][1] == expected_base class TestIsBadPath: """Test _is_bad_path function for secure tar extraction.""" - def test_is_bad_path_safe_relative(self): - """Test _is_bad_path returns False for safe relative paths.""" - from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path - - base = _get_resolved_path("/tmp/safe") - assert _is_bad_path("safe/file.txt", base) is False - - def test_is_bad_path_actual_escape(self): - """Test _is_bad_path returns True for paths escaping base.""" - from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path - + def test_is_bad_path_safe_file_under_base(self, tmp_path): + """Test _is_bad_path returns False for file under base.""" + base = _get_resolved_path(str(tmp_path)) + # A relative path that resolves under base when joined + sub = os.path.join(str(tmp_path), "subdir") + os.makedirs(sub, exist_ok=True) + # Use a path relative to base + assert _is_bad_path("subdir/file.txt", base) is False + + def test_is_bad_path_absolute_escape(self): + """Test _is_bad_path returns True for absolute path outside base.""" base = _get_resolved_path("/tmp/safe") assert _is_bad_path("/etc/passwd", base) is True def test_is_bad_path_traversal(self): """Test _is_bad_path detects parent directory traversal.""" - from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path - base = _get_resolved_path("/tmp/safe") assert _is_bad_path("../../etc/passwd", base) is True def test_is_bad_path_prefix_collision(self): - """Test _is_bad_path does NOT flag /tmp/safe2 when base is /tmp/safe. + """Test _is_bad_path correctly flags prefix collision. - This is the key test for the startswith() bug fix - /tmp/safe2 starts with - /tmp/safe but is NOT actually under /tmp/safe. + /tmp/safe2 starts with /tmp/safe but is NOT under /tmp/safe. + The old startswith() check would miss this; commonpath catches it. """ - from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path - base = _get_resolved_path("/tmp/safe") - # /tmp/safe2 is NOT under /tmp/safe, but startswith would incorrectly say it is - # With the fix using commonpath, this should correctly identify it as outside base assert _is_bad_path("/tmp/safe2/file.txt", base) is True @@ -1217,9 +1225,7 @@ class TestGetSafeMembers: """Test _get_safe_members function for secure tar extraction.""" def test_get_safe_members_accepts_member_list_and_base(self): - """Test _get_safe_members works with a list of TarInfo mocks and a base path.""" - from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path - + """Test _get_safe_members with a list of TarInfo mocks and base.""" base = _get_resolved_path("/tmp/extract") mock_member = Mock() @@ -1233,9 +1239,7 @@ def test_get_safe_members_accepts_member_list_and_base(self): assert mock_member in safe def test_get_safe_members_filters_bad_paths(self): - """Test _get_safe_members filters out members with bad paths.""" - from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path - + """Test _get_safe_members filters members with bad paths.""" base = _get_resolved_path("/tmp/extract") mock_safe = Mock() @@ -1248,16 +1252,20 @@ def test_get_safe_members_filters_bad_paths(self): mock_bad.issym = Mock(return_value=False) mock_bad.islnk = Mock(return_value=False) - with patch('sagemaker.core.common_utils._is_bad_path') as mock_is_bad: - mock_is_bad.side_effect = lambda name, base: name == "/etc/passwd" - safe = list(_get_safe_members([mock_safe, mock_bad], base)) + with patch( + 'sagemaker.core.common_utils._is_bad_path' + ) as mock_is_bad: + mock_is_bad.side_effect = ( + lambda name, base: name == "/etc/passwd" + ) + safe = list( + _get_safe_members([mock_safe, mock_bad], base) + ) assert len(safe) == 1 assert mock_safe in safe def test_get_safe_members_filters_bad_symlinks(self): """Test _get_safe_members filters out bad symlinks.""" - from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path - base = _get_resolved_path("/tmp/extract") mock_safe = Mock() @@ -1271,17 +1279,24 @@ def test_get_safe_members_filters_bad_symlinks(self): mock_symlink.islnk = Mock(return_value=False) mock_symlink.linkname = "/etc/passwd" - with patch('sagemaker.core.common_utils._is_bad_path', return_value=False): - with patch('sagemaker.core.common_utils._is_bad_link') as mock_is_bad_link: + with patch( + 'sagemaker.core.common_utils._is_bad_path', + return_value=False, + ): + with patch( + 'sagemaker.core.common_utils._is_bad_link' + ) as mock_is_bad_link: mock_is_bad_link.return_value = True - safe = list(_get_safe_members([mock_safe, mock_symlink], base)) + safe = list( + _get_safe_members( + [mock_safe, mock_symlink], base + ) + ) assert len(safe) == 1 assert mock_safe in safe def test_get_safe_members_filters_bad_hardlinks(self): """Test _get_safe_members filters out bad hardlinks.""" - from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path - base = _get_resolved_path("/tmp/extract") mock_safe = Mock() @@ -1295,10 +1310,19 @@ def test_get_safe_members_filters_bad_hardlinks(self): mock_hardlink.islnk = Mock(return_value=True) mock_hardlink.linkname = "/etc/passwd" - with patch('sagemaker.core.common_utils._is_bad_path', return_value=False): - with patch('sagemaker.core.common_utils._is_bad_link') as mock_is_bad_link: + with patch( + 'sagemaker.core.common_utils._is_bad_path', + return_value=False, + ): + with patch( + 'sagemaker.core.common_utils._is_bad_link' + ) as mock_is_bad_link: mock_is_bad_link.return_value = True - safe = list(_get_safe_members([mock_safe, mock_hardlink], base)) + safe = list( + _get_safe_members( + [mock_safe, mock_hardlink], base + ) + ) assert len(safe) == 1 assert mock_safe in safe @@ -1308,35 +1332,24 @@ class TestValidateSourceDirectorySecurity: def test_validate_source_directory_blocks_sensitive_path(self): """Test that actual sensitive paths are blocked.""" - from sagemaker.core.common_utils import _validate_source_directory - with pytest.raises(ValueError, match="sensitive system paths"): _validate_source_directory("/etc/secrets") def test_validate_source_directory_prefix_collision(self): - """Test that /etcetera is NOT blocked when /etc is in sensitive paths. + """Test /etcetera is NOT blocked by /etc sensitive path. - This tests the fix for the startswith() prefix collision vulnerability. + This validates the commonpath fix for startswith() prefix + collision vulnerability. """ - from sagemaker.core.common_utils import _validate_source_directory - - # /etcetera should NOT be blocked - it's not under /etc - # With the old startswith() check, this would incorrectly raise ValueError - try: - _validate_source_directory("/etcetera") - except ValueError: - pytest.fail("_validate_source_directory incorrectly blocked /etcetera due to prefix collision with /etc") + # Should NOT raise - /etcetera is not under /etc + _validate_source_directory("/etcetera") def test_validate_source_directory_s3_path(self): """Test that S3 paths are allowed.""" - from sagemaker.core.common_utils import _validate_source_directory - _validate_source_directory("s3://my-bucket/my-prefix") def test_validate_source_directory_none(self): """Test that None is allowed.""" - from sagemaker.core.common_utils import _validate_source_directory - _validate_source_directory(None) @@ -1345,28 +1358,20 @@ class TestValidateDependencyPathSecurity: def test_validate_dependency_path_blocks_sensitive_path(self): """Test that actual sensitive paths are blocked.""" - from sagemaker.core.common_utils import _validate_dependency_path - with pytest.raises(ValueError, match="sensitive system paths"): _validate_dependency_path("/root/.bashrc") def test_validate_dependency_path_prefix_collision(self): - """Test that /rootkit is NOT blocked when /root is in sensitive paths. + """Test /rootkit is NOT blocked by /root sensitive path. - This tests the fix for the startswith() prefix collision vulnerability. + This validates the commonpath fix for startswith() prefix + collision vulnerability. """ - from sagemaker.core.common_utils import _validate_dependency_path - - # /rootkit should NOT be blocked - it's not under /root - try: - _validate_dependency_path("/rootkit") - except ValueError: - pytest.fail("_validate_dependency_path incorrectly blocked /rootkit due to prefix collision with /root") + # Should NOT raise - /rootkit is not under /root + _validate_dependency_path("/rootkit") def test_validate_dependency_path_none(self): """Test that None is allowed.""" - from sagemaker.core.common_utils import _validate_dependency_path - _validate_dependency_path(None) @@ -2462,67 +2467,62 @@ class TestValidateSourceDirectory: def test_validate_source_directory_none(self): """Test with None source directory.""" - from sagemaker.core.common_utils import _validate_source_directory - # Should not raise _validate_source_directory(None) def test_validate_source_directory_s3_path(self): """Test with S3 path.""" - from sagemaker.core.common_utils import _validate_source_directory - # Should not raise for S3 paths _validate_source_directory("s3://my-bucket/my-code") def test_validate_source_directory_valid_local_path(self): """Test with valid local path.""" - from sagemaker.core.common_utils import _validate_source_directory - with tempfile.TemporaryDirectory() as tmpdir: # Should not raise for valid local paths _validate_source_directory(tmpdir) def test_validate_source_directory_sensitive_path_aws(self): """Test rejection of ~/.aws path.""" - from sagemaker.core.common_utils import _validate_source_directory - aws_dir = os.path.expanduser("~/.aws") if os.path.exists(aws_dir): - with pytest.raises(ValueError, match="cannot access sensitive system paths"): + with pytest.raises( + ValueError, + match="cannot access sensitive system paths", + ): _validate_source_directory(aws_dir) def test_validate_source_directory_sensitive_path_ssh(self): """Test rejection of ~/.ssh path.""" - from sagemaker.core.common_utils import _validate_source_directory - ssh_dir = os.path.expanduser("~/.ssh") if os.path.exists(ssh_dir): - with pytest.raises(ValueError, match="cannot access sensitive system paths"): + with pytest.raises( + ValueError, + match="cannot access sensitive system paths", + ): _validate_source_directory(ssh_dir) def test_validate_source_directory_sensitive_path_root(self): """Test rejection of /root path.""" - from sagemaker.core.common_utils import _validate_source_directory - - # Test with /root which is a sensitive path - if os.path.exists("/root") and os.access("/root", os.R_OK): - with pytest.raises(ValueError, match="cannot access sensitive system paths"): + if ( + os.path.exists("/root") + and os.access("/root", os.R_OK) + ): + with pytest.raises( + ValueError, + match="cannot access sensitive system paths", + ): _validate_source_directory("/root") def test_validate_source_directory_symlink_resolution(self): """Test that symlinks are resolved correctly.""" - from sagemaker.core.common_utils import _validate_source_directory - with tempfile.TemporaryDirectory() as tmpdir: - # Create a real directory real_dir = os.path.join(tmpdir, "real_code") os.makedirs(real_dir) - # Create a symlink to it symlink_path = os.path.join(tmpdir, "link_to_code") os.symlink(real_dir, symlink_path) - # Should not raise - symlink should be resolved and validated + # Should not raise _validate_source_directory(symlink_path) @@ -2531,51 +2531,45 @@ class TestValidateDependencyPath: def test_validate_dependency_path_none(self): """Test with None dependency.""" - from sagemaker.core.common_utils import _validate_dependency_path - # Should not raise _validate_dependency_path(None) def test_validate_dependency_path_valid_local_path(self): """Test with valid local path.""" - from sagemaker.core.common_utils import _validate_dependency_path - with tempfile.TemporaryDirectory() as tmpdir: # Should not raise for valid local paths _validate_dependency_path(tmpdir) def test_validate_dependency_path_sensitive_path_aws(self): """Test rejection of ~/.aws path.""" - from sagemaker.core.common_utils import _validate_dependency_path - aws_dir = os.path.expanduser("~/.aws") if os.path.exists(aws_dir): - with pytest.raises(ValueError, match="cannot access sensitive system paths"): + with pytest.raises( + ValueError, + match="cannot access sensitive system paths", + ): _validate_dependency_path(aws_dir) def test_validate_dependency_path_sensitive_path_credentials(self): """Test rejection of ~/.credentials path.""" - from sagemaker.core.common_utils import _validate_dependency_path - creds_dir = os.path.expanduser("~/.credentials") if os.path.exists(creds_dir): - with pytest.raises(ValueError, match="cannot access sensitive system paths"): + with pytest.raises( + ValueError, + match="cannot access sensitive system paths", + ): _validate_dependency_path(creds_dir) def test_validate_dependency_path_symlink_resolution(self): """Test that symlinks are resolved correctly.""" - from sagemaker.core.common_utils import _validate_dependency_path - with tempfile.TemporaryDirectory() as tmpdir: - # Create a real directory real_dir = os.path.join(tmpdir, "real_lib") os.makedirs(real_dir) - # Create a symlink to it symlink_path = os.path.join(tmpdir, "link_to_lib") os.symlink(real_dir, symlink_path) - # Should not raise - symlink should be resolved and validated + # Should not raise _validate_dependency_path(symlink_path)