From 6f73fbff3ce7221c90456c56fb3e40aafcbb1d39 Mon Sep 17 00:00:00 2001 From: Molly He Date: Fri, 27 Mar 2026 15:07:42 -0700 Subject: [PATCH] Update accept_eula to respect user setup --- .../src/sagemaker/serve/model_builder.py | 9 ++- .../tests/unit/test_model_builder.py | 71 +++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 4e09c64abc..a065e5ff87 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -2373,6 +2373,13 @@ def _build_single_modelbuilder( "HostingArtifactUri not found in JumpStart hub metadata. " "Cannot deploy LORA adapter without base model artifacts." ) + accept_eula = getattr(self, "accept_eula", None) + if not accept_eula: + raise ValueError( + "accept_eula must be set to True to deploy this model. " + "Please set accept_eula=True on the ModelBuilder instance to confirm " + "you have read and accepted the end-user license agreement for this model." + ) container_def = ContainerDefinition( image=self.image_uri, environment=self.env_vars, @@ -2381,7 +2388,7 @@ def _build_single_modelbuilder( "s3_uri": hosting_artifact_uri, "s3_data_type": "S3Prefix", "compression_type": "None", - "model_access_config": {"accept_eula": True}, + "model_access_config": {"accept_eula": accept_eula}, } }, ) diff --git a/sagemaker-serve/tests/unit/test_model_builder.py b/sagemaker-serve/tests/unit/test_model_builder.py index e5900f5562..20ea39c0a0 100644 --- a/sagemaker-serve/tests/unit/test_model_builder.py +++ b/sagemaker-serve/tests/unit/test_model_builder.py @@ -715,3 +715,74 @@ def test_deploy_passes_inference_config_to_model_customization(self): call_kwargs = mock_deploy_mc.call_args[1] self.assertEqual(call_kwargs['inference_config'], inference_config) self.assertEqual(result, mock_endpoint) + + +class TestLoraAcceptEula(unittest.TestCase): + """Tests for accept_eula handling in the LoRA deployment path.""" + + def _make_mb(self, accept_eula=None): + mb = ModelBuilder.__new__(ModelBuilder) + mb.accept_eula = accept_eula + mb.image_uri = "some-image-uri" + mb.env_vars = {} + mb.model_name = None + mb.role_arn = "arn:aws:iam::123456789012:role/role" + mb.model = MagicMock() + mb._adapter_s3_uri = None + return mb + + def _patch_lora_deps(self, mb, hosting_uri="s3://bucket/hosting/"): + """Patch all dependencies needed to reach the LoRA ContainerDefinition block.""" + patches = [ + patch.object(mb, "_fetch_peft", return_value="LORA"), + patch.object(mb, "_fetch_hub_document_for_custom_model", + return_value={"HostingArtifactUri": hosting_uri}), + patch.object(mb, "_get_model_package_for_training_job", + return_value=MagicMock()), + ] + return patches + + def test_lora_build_raises_when_accept_eula_false(self): + mb = self._make_mb(accept_eula=False) + patches = self._patch_lora_deps(mb) + for p in patches: + p.start() + try: + with self.assertRaises(ValueError) as ctx: + mb._build_single_modelbuilder() + self.assertIn("accept_eula", str(ctx.exception)) + finally: + for p in patches: + p.stop() + + def test_lora_build_raises_when_accept_eula_not_set(self): + mb = self._make_mb(accept_eula=None) + patches = self._patch_lora_deps(mb) + for p in patches: + p.start() + try: + with self.assertRaises(ValueError) as ctx: + mb._build_single_modelbuilder() + self.assertIn("accept_eula", str(ctx.exception)) + finally: + for p in patches: + p.stop() + + @patch("sagemaker.serve.model_builder.ContainerDefinition") + @patch("sagemaker.serve.model_builder.Model") + def test_lora_build_passes_accept_eula_true(self, mock_model, mock_container_def): + mb = self._make_mb(accept_eula=True) + mock_model.create.return_value = MagicMock() + patches = self._patch_lora_deps(mb) + for p in patches: + p.start() + try: + mb._build_single_modelbuilder() + call_kwargs = mock_container_def.call_args[1] + eula_val = ( + call_kwargs["model_data_source"]["s3_data_source"]["model_access_config"]["accept_eula"] + ) + self.assertTrue(eula_val) + finally: + for p in patches: + p.stop()