diff --git a/sagemaker-train/src/sagemaker/ai_registry/evaluator.py b/sagemaker-train/src/sagemaker/ai_registry/evaluator.py index 349705f08d..5d0b47939f 100644 --- a/sagemaker-train/src/sagemaker/ai_registry/evaluator.py +++ b/sagemaker-train/src/sagemaker/ai_registry/evaluator.py @@ -382,7 +382,7 @@ def _create_lambda_function(cls, name: str, source_file: str, role: Optional[str # Create Lambda function lambda_client = boto3.client("lambda") function_name = f"SageMaker-evaluator-{name}-{datetime.now().strftime('%Y%m%d_%H%M%S')}" - handler_name = f"{os.path.splitext(os.path.basename(source_file))[0]}.lambda_handler" + handler_name = "lambda_function.lambda_handler" try: lambda_response = lambda_client.create_function( diff --git a/sagemaker-train/tests/integ/ai_registry/conftest.py b/sagemaker-train/tests/integ/ai_registry/conftest.py index 9f7c2069f1..755c97f09d 100644 --- a/sagemaker-train/tests/integ/ai_registry/conftest.py +++ b/sagemaker-train/tests/integ/ai_registry/conftest.py @@ -56,6 +56,22 @@ def sample_jsonl_file(): os.unlink(f.name) +@pytest.fixture +def sample_lambda_py_file(): + """Create a raw Python Lambda file with a non-default filename to test handler derivation.""" + code = '''import json +def lambda_handler(event, context): + return {"statusCode": 200, "body": json.dumps({"score": 0.9})} +''' + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', prefix='my_custom_evaluator_', delete=False) as f: + f.write(code) + f.flush() + os.fsync(f.fileno()) + fname = f.name + yield fname + os.unlink(fname) + + @pytest.fixture def sample_lambda_code(): """Create sample Lambda function code as zip.""" diff --git a/sagemaker-train/tests/integ/ai_registry/test_evaluator.py b/sagemaker-train/tests/integ/ai_registry/test_evaluator.py index 8cb029055e..51497c6cfd 100644 --- a/sagemaker-train/tests/integ/ai_registry/test_evaluator.py +++ b/sagemaker-train/tests/integ/ai_registry/test_evaluator.py @@ -81,6 +81,47 @@ def test_create_reward_function_from_local_code(self, unique_name, sample_lambda assert evaluator.method == EvaluatorMethod.BYOC assert evaluator.reference is not None + def test_create_reward_function_from_local_py_file_and_invoke( + self, unique_name, sample_lambda_py_file, test_role, cleanup_list + ): + """End-to-end test: create evaluator from a raw .py file with non-default name and invoke it. + + Regression test for the handler name bug where the Lambda was created with an incorrect + handler derived from the source filename instead of 'lambda_function.lambda_handler'. + """ + import json + import boto3 + + evaluator = Evaluator.create( + name=unique_name, + type=REWARD_FUNCTION, + source=sample_lambda_py_file, + role=test_role, + wait=True, # wait for Lambda to be active + ) + cleanup_list.append(evaluator) + assert evaluator.method == EvaluatorMethod.BYOC + assert evaluator.reference is not None + + # Wait for Lambda to become Active before invoking + lambda_client = boto3.client("lambda") + waiter = lambda_client.get_waiter("function_active_v2") + waiter.wait(FunctionName=evaluator.reference) + + # Invoke the Lambda directly to verify the handler is correct + lambda_client = boto3.client("lambda") + response = lambda_client.invoke( + FunctionName=evaluator.reference, + InvocationType="RequestResponse", + Payload=json.dumps({"input": "test"}).encode(), + ) + assert response["StatusCode"] == 200 + assert "FunctionError" not in response, ( + f"Lambda invocation failed with error: {response.get('FunctionError')}" + ) + result = json.loads(response["Payload"].read()) + assert result.get("statusCode") == 200 + def test_get_evaluator(self, unique_name, sample_prompt_file, cleanup_list): """Test retrieving evaluator by name.""" try: diff --git a/sagemaker-train/tests/unit/ai_registry/test_evaluator.py b/sagemaker-train/tests/unit/ai_registry/test_evaluator.py index 326978d5ba..3c4b233473 100644 --- a/sagemaker-train/tests/unit/ai_registry/test_evaluator.py +++ b/sagemaker-train/tests/unit/ai_registry/test_evaluator.py @@ -77,6 +77,9 @@ def test_create_with_byoc(self, mock_air_hub, mock_boto3): assert evaluator.method == EvaluatorMethod.BYOC mock_air_hub.upload_to_s3.assert_called_once() + mock_lambda_client.create_function.assert_called_once() + call_kwargs = mock_lambda_client.create_function.call_args[1] + assert call_kwargs["Handler"] == "lambda_function.lambda_handler" @patch('sagemaker.ai_registry.evaluator.AIRHub') def test_get_all(self, mock_air_hub):