Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,7 @@ def execute_command(
async_op: bool,
enforce_embedded_schema_correctness: bool,
row_limit: Optional[int] = None,
query_tags: Optional[
Dict[str, Optional[str]]
] = None, # TODO: implement query_tags for SEA backend
query_tags: Optional[Dict[str, Optional[str]]] = None,
) -> Union[SeaResultSet, None]:
"""
Execute a SQL command using the SEA backend.
Expand Down Expand Up @@ -532,6 +530,7 @@ def execute_command(
row_limit=row_limit,
parameters=sea_parameters if sea_parameters else None,
result_compression=result_compression,
query_tags=query_tags,
)

response_data = self._http_client._make_request(
Expand Down
9 changes: 9 additions & 0 deletions src/databricks/sql/backend/sea/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ExecuteStatementRequest:
wait_timeout: str = "10s"
on_wait_timeout: str = "CONTINUE"
row_limit: Optional[int] = None
query_tags: Optional[Dict[str, Optional[str]]] = None

def to_dict(self) -> Dict[str, Any]:
"""Convert the request to a dictionary for JSON serialization."""
Expand Down Expand Up @@ -60,6 +61,14 @@ def to_dict(self) -> Dict[str, Any]:
for param in self.parameters
]

# SEA API expects query_tags as an array of {key, value} objects.
# None values are represented by omitting the "value" field.
if self.query_tags:
result["query_tags"] = [
{"key": k, "value": v} if v is not None else {"key": k}
for k, v in self.query_tags.items()
]

return result


Expand Down
110 changes: 108 additions & 2 deletions tests/unit/test_sea_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
session_config = {
"ANSI_MODE": "FALSE", # Supported parameter
"STATEMENT_TIMEOUT": "3600", # Supported parameter
"QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter
"QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter
"unsupported_param": "value", # Unsupported parameter
}
catalog = "test_catalog"
Expand All @@ -197,7 +197,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
"session_confs": {
"ansi_mode": "FALSE",
"statement_timeout": "3600",
"query_tags": "team:marketing,dashboard:abc123",
"query_tags": "team:marketing,dashboard:abc123",
},
"catalog": catalog,
"schema": schema,
Expand Down Expand Up @@ -416,6 +416,112 @@ def test_command_execution_advanced(
)
assert "Command failed" in str(excinfo.value)

def _execute_response(self):
return {
"statement_id": "test-statement-123",
"status": {"state": "SUCCEEDED"},
"manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0},
"result": {"data": []},
}

def _run_execute_command(self, sea_client, sea_session_id, mock_cursor, **kwargs):
"""Helper to invoke execute_command with default args."""
return sea_client.execute_command(
operation="SELECT 1",
session_id=sea_session_id,
max_rows=100,
max_bytes=1000,
lz4_compression=False,
cursor=mock_cursor,
use_cloud_fetch=False,
parameters=[],
async_op=False,
enforce_embedded_schema_correctness=False,
**kwargs,
)

def test_execute_command_query_tags_string_values(
self, sea_client, mock_http_client, mock_cursor, sea_session_id
):
"""query_tags with string values are included in the request payload."""
mock_http_client._make_request.return_value = self._execute_response()
with patch.object(sea_client, "_response_to_result_set"):
self._run_execute_command(
sea_client,
sea_session_id,
mock_cursor,
query_tags={"env": "prod", "team": "data"},
)
_, kwargs = mock_http_client._make_request.call_args
assert kwargs["data"]["query_tags"] == [
{"key": "env", "value": "prod"},
{"key": "team", "value": "data"},
]

def test_execute_command_query_tags_none_value(
self, sea_client, mock_http_client, mock_cursor, sea_session_id
):
"""query_tags with a None value omit the value field (key-only tag)."""
mock_http_client._make_request.return_value = self._execute_response()
with patch.object(sea_client, "_response_to_result_set"):
self._run_execute_command(
sea_client,
sea_session_id,
mock_cursor,
query_tags={"env": "prod", "team": None},
)
_, kwargs = mock_http_client._make_request.call_args
assert kwargs["data"]["query_tags"] == [
{"key": "env", "value": "prod"},
{"key": "team"},
]

def test_execute_command_no_query_tags_omitted(
self, sea_client, mock_http_client, mock_cursor, sea_session_id
):
"""query_tags field is absent from the request when not provided."""
mock_http_client._make_request.return_value = self._execute_response()
with patch.object(sea_client, "_response_to_result_set"):
self._run_execute_command(sea_client, sea_session_id, mock_cursor)
_, kwargs = mock_http_client._make_request.call_args
assert "query_tags" not in kwargs["data"]

def test_execute_command_empty_query_tags_omitted(
self, sea_client, mock_http_client, mock_cursor, sea_session_id
):
"""Empty query_tags dict is treated as absent — field omitted from request."""
mock_http_client._make_request.return_value = self._execute_response()
with patch.object(sea_client, "_response_to_result_set"):
self._run_execute_command(
sea_client, sea_session_id, mock_cursor, query_tags={}
)
_, kwargs = mock_http_client._make_request.call_args
assert "query_tags" not in kwargs["data"]

def test_execute_command_async_query_tags(
self, sea_client, mock_http_client, mock_cursor, sea_session_id
):
"""query_tags are included in async execute requests (execute_async path)."""
mock_http_client._make_request.return_value = {
"statement_id": "test-statement-async",
"status": {"state": "PENDING"},
}
sea_client.execute_command(
operation="SELECT 1",
session_id=sea_session_id,
max_rows=100,
max_bytes=1000,
lz4_compression=False,
cursor=mock_cursor,
use_cloud_fetch=False,
parameters=[],
async_op=True,
enforce_embedded_schema_correctness=False,
query_tags={"job": "nightly-etl"},
)
_, kwargs = mock_http_client._make_request.call_args
assert kwargs["data"]["query_tags"] == [{"key": "job", "value": "nightly-etl"}]

def test_command_management(
self,
sea_client,
Expand Down
Loading