diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index a6cff4913..ff130cd39 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -463,9 +463,7 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, row_limit: Optional[int] = None, - query_tags: Optional[ - Dict[str, Optional[str]] - ] = None, # TODO: implement query_tags for SEA backend + query_tags: Optional[Dict[str, Optional[str]]] = None, ) -> Union[SeaResultSet, None]: """ Execute a SQL command using the SEA backend. @@ -532,6 +530,7 @@ def execute_command( row_limit=row_limit, parameters=sea_parameters if sea_parameters else None, result_compression=result_compression, + query_tags=query_tags, ) response_data = self._http_client._make_request( diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index ad046ff54..9a11e70df 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -31,6 +31,7 @@ class ExecuteStatementRequest: wait_timeout: str = "10s" on_wait_timeout: str = "CONTINUE" row_limit: Optional[int] = None + query_tags: Optional[Dict[str, Optional[str]]] = None def to_dict(self) -> Dict[str, Any]: """Convert the request to a dictionary for JSON serialization.""" @@ -60,6 +61,14 @@ def to_dict(self) -> Dict[str, Any]: for param in self.parameters ] + # SEA API expects query_tags as an array of {key, value} objects. + # None values are represented by omitting the "value" field. + if self.query_tags: + result["query_tags"] = [ + {"key": k, "value": v} if v is not None else {"key": k} + for k, v in self.query_tags.items() + ] + return result diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 26a898cb8..0b2f39b2a 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -185,7 +185,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter - "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter + "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter "unsupported_param": "value", # Unsupported parameter } catalog = "test_catalog" @@ -197,7 +197,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i "session_confs": { "ansi_mode": "FALSE", "statement_timeout": "3600", - "query_tags": "team:marketing,dashboard:abc123", + "query_tags": "team:marketing,dashboard:abc123", }, "catalog": catalog, "schema": schema, @@ -416,6 +416,112 @@ def test_command_execution_advanced( ) assert "Command failed" in str(excinfo.value) + def _execute_response(self): + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + def _run_execute_command(self, sea_client, sea_session_id, mock_cursor, **kwargs): + """Helper to invoke execute_command with default args.""" + return sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + **kwargs, + ) + + def test_execute_command_query_tags_string_values( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags with string values are included in the request payload.""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command( + sea_client, + sea_session_id, + mock_cursor, + query_tags={"env": "prod", "team": "data"}, + ) + _, kwargs = mock_http_client._make_request.call_args + assert kwargs["data"]["query_tags"] == [ + {"key": "env", "value": "prod"}, + {"key": "team", "value": "data"}, + ] + + def test_execute_command_query_tags_none_value( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags with a None value omit the value field (key-only tag).""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command( + sea_client, + sea_session_id, + mock_cursor, + query_tags={"env": "prod", "team": None}, + ) + _, kwargs = mock_http_client._make_request.call_args + assert kwargs["data"]["query_tags"] == [ + {"key": "env", "value": "prod"}, + {"key": "team"}, + ] + + def test_execute_command_no_query_tags_omitted( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags field is absent from the request when not provided.""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command(sea_client, sea_session_id, mock_cursor) + _, kwargs = mock_http_client._make_request.call_args + assert "query_tags" not in kwargs["data"] + + def test_execute_command_empty_query_tags_omitted( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Empty query_tags dict is treated as absent — field omitted from request.""" + mock_http_client._make_request.return_value = self._execute_response() + with patch.object(sea_client, "_response_to_result_set"): + self._run_execute_command( + sea_client, sea_session_id, mock_cursor, query_tags={} + ) + _, kwargs = mock_http_client._make_request.call_args + assert "query_tags" not in kwargs["data"] + + def test_execute_command_async_query_tags( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """query_tags are included in async execute requests (execute_async path).""" + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-async", + "status": {"state": "PENDING"}, + } + sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, + enforce_embedded_schema_correctness=False, + query_tags={"job": "nightly-etl"}, + ) + _, kwargs = mock_http_client._make_request.call_args + assert kwargs["data"]["query_tags"] == [{"key": "job", "value": "nightly-etl"}] + def test_command_management( self, sea_client,