From 0b225ac1084d5f0c12dee3eb51f42cdb4d14c693 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Fri, 6 Mar 2026 15:48:15 +0000 Subject: [PATCH 1/9] feat: Added a batch completed callback to the data client mutations batcher --- google/cloud/bigtable/data/_async/client.py | 4 + .../bigtable/data/_async/mutations_batcher.py | 18 ++++- google/cloud/bigtable/data/_helpers.py | 59 ++++++++++++++ .../bigtable/data/_sync_autogen/client.py | 5 +- .../data/_sync_autogen/mutations_batcher.py | 13 ++- google/cloud/bigtable/table.py | 44 ++-------- tests/system/data/test_system_async.py | 37 +++++++++ tests/system/data/test_system_autogen.py | 29 +++++++ .../data/_async/test_mutations_batcher.py | 80 ++++++++++++++++++- .../_sync_autogen/test_mutations_batcher.py | 77 +++++++++++++++++- tests/unit/data/test__helpers.py | 68 ++++++++++++++++ 11 files changed, 392 insertions(+), 42 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 35fe42814..957faf7a1 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -16,6 +16,7 @@ from __future__ import annotations from typing import ( + Callable, cast, Any, AsyncIterable, @@ -115,6 +116,7 @@ if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery + from google.rpc import status_pb2 if CrossSync.is_async: from google.cloud.bigtable.data._async.mutations_batcher import ( @@ -1437,6 +1439,7 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + _batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None, ) -> "MutationsBatcherAsync": """ Returns a new mutations batcher instance. @@ -1472,6 +1475,7 @@ def mutations_batcher( batch_operation_timeout=batch_operation_timeout, batch_attempt_timeout=batch_attempt_timeout, batch_retryable_errors=batch_retryable_errors, + _batch_completed_callback=_batch_completed_callback, ) @CrossSync.convert diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 6d87ff5d2..2fe9cab94 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -14,7 +14,7 @@ # from __future__ import annotations -from typing import Sequence, TYPE_CHECKING, cast +from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast import atexit import warnings from collections import deque @@ -24,6 +24,10 @@ from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import ( + _populate_statuses_from_mutations_exception_group, +) + from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data.mutations import ( @@ -33,6 +37,9 @@ from google.cloud.bigtable.data._cross_sync import CrossSync +from google.rpc import code_pb2 +from google.rpc import status_pb2 + if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -223,6 +230,7 @@ def __init__( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + _batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None, ): self._operation_timeout, self._attempt_timeout = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, target @@ -269,6 +277,7 @@ def __init__( self._newest_exceptions: deque[Exception] = deque( maxlen=self._exception_list_limit ) + self._user_batch_completed_callback = _batch_completed_callback # clean up on program exit atexit.register(self._on_exit) @@ -380,6 +389,7 @@ async def _execute_mutate_rows( list of FailedMutationEntryError objects for mutations that failed. FailedMutationEntryError objects will not contain index information """ + statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch) try: operation = CrossSync._MutateRowsOperation( self._target.client._gapic_client, @@ -391,6 +401,8 @@ async def _execute_mutate_rows( ) await operation.start() except MutationsExceptionGroup as e: + _populate_statuses_from_mutations_exception_group(statuses, e) + # strip index information from exceptions, since it is not useful in a batch context for subexc in e.exceptions: subexc.index = None @@ -398,6 +410,10 @@ async def _execute_mutate_rows( finally: # mark batch as complete in flow control await self._flow_control.remove_from_flow(batch) + + # Call batch done callback with list of statuses. + if self._user_batch_completed_callback: + self._user_batch_completed_callback(statuses) return [] def _add_exceptions(self, excs: list[Exception]): diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index f88fa3078..5a6dd888d 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -26,6 +26,10 @@ from google.api_core import retry as retries from google.api_core.retry import RetryFailureReason from google.cloud.bigtable.data.exceptions import RetryExceptionGroup +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.rpc import code_pb2 +from google.rpc import status_pb2 + if TYPE_CHECKING: import grpc @@ -224,6 +228,61 @@ def _align_timeouts(operation: float, attempt: float | None) -> tuple[float, flo return operation, final_attempt +def _populate_statuses_from_mutations_exception_group( + statuses: list[status_pb2.Status], exc_group: MutationsExceptionGroup +): + """ + Helper function that populates a list of Status objects with exception information from + the exception group. + + Args: + statuses: The initial list of Status objects + exc_group: The exception group from a mutate rows operation + """ + # We exception handle as follows: + # + # 1. Each exception in the error group is a FailedMutationEntryError, and its + # cause is either a singular exception or a RetryExceptionGroup consisting of + # multiple exceptions. + # + # 2. In the case of a singular exception, if the error does not have a gRPC status + # code, we return a status code of UNKNOWN. + # + # 3. In the case of a RetryExceptionGroup, we use terminal exception in the exception + # group and process that. + for error in exc_group.exceptions: + cause = error.__cause__ + if isinstance(cause, RetryExceptionGroup): + statuses[error.index] = _get_status(cause.exceptions[-1]) + else: + statuses[error.index] = _get_status(cause) + + +def _get_status(exc: Exception) -> status_pb2.Status: + """ + Helper function that returns a Status object corresponding to the given exception. + + Args: + exc: An exception to be converted into a Status. + Returns: + status_pb2.Status: A Status proto object. + """ + if ( + isinstance(exc, core_exceptions.GoogleAPICallError) + and exc.grpc_status_code is not None + ): + return status_pb2.Status( + code=exc.grpc_status_code.value[0], + message=exc.message, + details=exc.details, + ) + + return status_pb2.Status( + code=code_pb2.Code.UNKNOWN, + message=str(exc), + ) + + def _validate_timeouts( operation_timeout: float, attempt_timeout: float | None, allow_none: bool = False ): diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 88136ddad..78c5969da 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -17,7 +17,7 @@ # This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations -from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING +from typing import Callable, cast, Any, Optional, Set, Sequence, TYPE_CHECKING import abc import time import warnings @@ -87,6 +87,7 @@ if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery + from google.rpc import status_pb2 from google.cloud.bigtable.data._sync_autogen.mutations_batcher import ( MutationsBatcher, ) @@ -1190,6 +1191,7 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + _batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None, ) -> "MutationsBatcher": """Returns a new mutations batcher instance. @@ -1224,6 +1226,7 @@ def mutations_batcher( batch_operation_timeout=batch_operation_timeout, batch_attempt_timeout=batch_attempt_timeout, batch_retryable_errors=batch_retryable_errors, + _batch_completed_callback=_batch_completed_callback, ) def mutate_row( diff --git a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py index c14606de9..a9ae22d42 100644 --- a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py @@ -16,7 +16,7 @@ # This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations -from typing import Sequence, TYPE_CHECKING, cast +from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast import atexit import warnings from collections import deque @@ -25,10 +25,15 @@ from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import ( + _populate_statuses_from_mutations_exception_group, +) from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT from google.cloud.bigtable.data.mutations import Mutation from google.cloud.bigtable.data._cross_sync import CrossSync +from google.rpc import code_pb2 +from google.rpc import status_pb2 if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -192,6 +197,7 @@ def __init__( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + _batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None, ): (self._operation_timeout, self._attempt_timeout) = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, target @@ -233,6 +239,7 @@ def __init__( self._newest_exceptions: deque[Exception] = deque( maxlen=self._exception_list_limit ) + self._user_batch_completed_callback = _batch_completed_callback atexit.register(self._on_exit) def _timer_routine(self, interval: float | None) -> None: @@ -324,6 +331,7 @@ def _execute_mutate_rows( list[FailedMutationEntryError]: list of FailedMutationEntryError objects for mutations that failed. FailedMutationEntryError objects will not contain index information""" + statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch) try: operation = CrossSync._Sync_Impl._MutateRowsOperation( self._target.client._gapic_client, @@ -335,11 +343,14 @@ def _execute_mutate_rows( ) operation.start() except MutationsExceptionGroup as e: + _populate_statuses_from_mutations_exception_group(statuses, e) for subexc in e.exceptions: subexc.index = None return list(e.exceptions) finally: self._flow_control.remove_from_flow(batch) + if self._user_batch_completed_callback: + self._user_batch_completed_callback(statuses) return [] def _add_exceptions(self, excs: list[Exception]): diff --git a/google/cloud/bigtable/table.py b/google/cloud/bigtable/table.py index 8136c3f9a..b478d3f01 100644 --- a/google/cloud/bigtable/table.py +++ b/google/cloud/bigtable/table.py @@ -17,7 +17,6 @@ from typing import Set import warnings -from google.api_core.exceptions import GoogleAPICallError from google.api_core.exceptions import Aborted from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import NotFound @@ -31,10 +30,10 @@ from google.cloud.bigtable.column_family import _gc_rule_from_pb from google.cloud.bigtable.column_family import ColumnFamily from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - MutationsExceptionGroup, +from google.cloud.bigtable.data._helpers import ( + _populate_statuses_from_mutations_exception_group, ) +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.batcher import MutationsBatcher from google.cloud.bigtable.batcher import FLUSH_COUNT, MAX_MUTATION_SIZE @@ -774,41 +773,12 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT): retryable_errors=retryable_errors, ) except MutationsExceptionGroup as mut_exc_group: - # We exception handle as follows: - # - # 1. Each exception in the error group is a FailedMutationEntryError, and its - # cause is either a singular exception or a RetryExceptionGroup consisting of - # multiple exceptions. - # - # 2. In the case of a singular exception, if the error does not have a gRPC status - # code, we return a status code of UNKNOWN. - # - # 3. In the case of a RetryExceptionGroup, we use terminal exception in the exception - # group and process that. - for error in mut_exc_group.exceptions: - cause = error.__cause__ - if isinstance(cause, RetryExceptionGroup): - return_statuses[error.index] = self._get_status( - cause.exceptions[-1] - ) - else: - return_statuses[error.index] = self._get_status(cause) - - return return_statuses - - @staticmethod - def _get_status(error): - if isinstance(error, GoogleAPICallError) and error.grpc_status_code is not None: - return status_pb2.Status( - code=error.grpc_status_code.value[0], - message=error.message, - details=error.details, + _populate_statuses_from_mutations_exception_group( + return_statuses, + mut_exc_group, ) - return status_pb2.Status( - code=code_pb2.Code.UNKNOWN, - message=str(error), - ) + return return_statuses def sample_row_keys(self): """Read a sample of row keys in the table. diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index c96570b76..4aa589027 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -482,6 +482,43 @@ async def test_mutations_batcher_timer_flush(self, client, target, temp_rows): # ensure cell is updated assert (await self._retrieve_cell_value(target, row_key)) == new_value + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("target") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_completed_callback( + self, client, target, temp_rows + ): + """ + test batcher with batch completed callback. It should be called when the batcher flushes. + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.rpc import code_pb2, status_pb2 + + import mock + + callback = mock.Mock() + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + target, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + async with target.mutations_batcher( + flush_interval=flush_interval, _batch_completed_callback=callback + ) as batcher: + await batcher.append(bulk_mutation) + await CrossSync.yield_to_event_loop() + assert len(batcher._staged_entries) == 1 + await CrossSync.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)]) + # ensure cell is updated + assert (await self._retrieve_cell_value(target, row_key)) == new_value + @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("target") @CrossSync.Retry( diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 44895808a..9a8f522d8 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -385,6 +385,35 @@ def test_mutations_batcher_timer_flush(self, client, target, temp_rows): assert len(batcher._staged_entries) == 0 assert self._retrieve_cell_value(target, row_key) == new_value + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("target") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_completed_callback(self, client, target, temp_rows): + """test batcher with batch completed callback. It should be called when the batcher flushes.""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.rpc import code_pb2, status_pb2 + import mock + + callback = mock.Mock() + new_value = uuid.uuid4().hex.encode() + (row_key, mutation) = self._create_row_and_mutation( + target, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + with target.mutations_batcher( + flush_interval=flush_interval, _batch_completed_callback=callback + ) as batcher: + batcher.append(bulk_mutation) + CrossSync._Sync_Impl.yield_to_event_loop() + assert len(batcher._staged_entries) == 1 + CrossSync._Sync_Impl.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)]) + assert self._retrieve_cell_value(target, row_key) == new_value + @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index e8df3eb42..01300ec7b 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -483,7 +483,7 @@ def test_default_argument_consistency(self): batcher_init_signature.pop("target") # both should have same number of arguments assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) - assert len(get_batcher_signature) == 8 # update if expected params change + assert len(get_batcher_signature) == 9 # update if expected params change # both should have same argument names assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) # both should have same default values @@ -962,8 +962,86 @@ async def test__execute_mutate_rows_returns_errors(self): table.default_mutate_rows_attempt_timeout = 13 table.default_mutate_rows_retryable_errors = () async with self._make_one(table) as instance: + batch = [self._make_mutation(), self._make_mutation()] + result = await instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + # indices should be set to None + assert result[0].index is None + assert result[1].index is None + + @CrossSync.pytest + async def test__execute_mutate_rows_batch_completed_callback(self): + from google.rpc import code_pb2, status_pb2 + + with mock.patch.object(CrossSync, "_MutateRowsOperation") as mutate_rows: + mutate_rows.return_value = CrossSync.Mock() + start_operation = mutate_rows().start + table = mock.Mock() + table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + callback = mock.Mock() + async with self._make_one( + table, _batch_completed_callback=callback + ) as instance: batch = [self._make_mutation()] result = await instance._execute_mutate_rows(batch) + callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)]) + assert start_operation.call_count == 1 + args, kwargs = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + kwargs["operation_timeout"] == 17 + kwargs["attempt_timeout"] == 13 + assert result == [] + + @CrossSync.pytest + async def test__execute_mutate_rows_batch_completed_callback_errors(self): + from google.api_core import exceptions + from google.cloud.bigtable.data.exceptions import ( + MutationsExceptionGroup, + FailedMutationEntryError, + ) + from google.rpc import code_pb2, status_pb2 + + with mock.patch.object(CrossSync._MutateRowsOperation, "start") as mutate_rows: + err1 = FailedMutationEntryError( + 1, mock.Mock(), exceptions.DataLoss("test error") + ) + err2 = FailedMutationEntryError( + 2, mock.Mock(), exceptions.DataLoss("test error") + ) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + callback = mock.Mock() + async with self._make_one( + table, _batch_completed_callback=callback + ) as instance: + batch = [ + self._make_mutation(), + self._make_mutation(), + self._make_mutation(), + ] + result = await instance._execute_mutate_rows(batch) + callback.assert_called_once_with( + [ + status_pb2.Status(code=code_pb2.OK), + status_pb2.Status( + code=code_pb2.DATA_LOSS, message="test error" + ), + status_pb2.Status( + code=code_pb2.DATA_LOSS, message="test error" + ), + ] + ) assert len(result) == 2 assert result[0] == err1 assert result[1] == err2 diff --git a/tests/unit/data/_sync_autogen/test_mutations_batcher.py b/tests/unit/data/_sync_autogen/test_mutations_batcher.py index 60a6708ba..6dfa3b9fa 100644 --- a/tests/unit/data/_sync_autogen/test_mutations_batcher.py +++ b/tests/unit/data/_sync_autogen/test_mutations_batcher.py @@ -430,7 +430,7 @@ def test_default_argument_consistency(self): ) batcher_init_signature.pop("target") assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) - assert len(get_batcher_signature) == 8 + assert len(get_batcher_signature) == 9 assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) for arg_name in get_batcher_signature.keys(): assert ( @@ -843,8 +843,83 @@ def test__execute_mutate_rows_returns_errors(self): table.default_mutate_rows_attempt_timeout = 13 table.default_mutate_rows_retryable_errors = () with self._make_one(table) as instance: + batch = [self._make_mutation(), self._make_mutation()] + result = instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + assert result[0].index is None + assert result[1].index is None + + def test__execute_mutate_rows_batch_completed_callback(self): + from google.rpc import code_pb2, status_pb2 + + with mock.patch.object( + CrossSync._Sync_Impl, "_MutateRowsOperation" + ) as mutate_rows: + mutate_rows.return_value = CrossSync._Sync_Impl.Mock() + start_operation = mutate_rows().start + table = mock.Mock() + table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + callback = mock.Mock() + with self._make_one(table, _batch_completed_callback=callback) as instance: batch = [self._make_mutation()] result = instance._execute_mutate_rows(batch) + callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)]) + assert start_operation.call_count == 1 + (args, kwargs) = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + kwargs["operation_timeout"] == 17 + kwargs["attempt_timeout"] == 13 + assert result == [] + + def test__execute_mutate_rows_batch_completed_callback_errors(self): + from google.api_core import exceptions + from google.cloud.bigtable.data.exceptions import ( + MutationsExceptionGroup, + FailedMutationEntryError, + ) + from google.rpc import code_pb2, status_pb2 + + with mock.patch.object( + CrossSync._Sync_Impl._MutateRowsOperation, "start" + ) as mutate_rows: + err1 = FailedMutationEntryError( + 1, mock.Mock(), exceptions.DataLoss("test error") + ) + err2 = FailedMutationEntryError( + 2, mock.Mock(), exceptions.DataLoss("test error") + ) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + callback = mock.Mock() + with self._make_one(table, _batch_completed_callback=callback) as instance: + batch = [ + self._make_mutation(), + self._make_mutation(), + self._make_mutation(), + ] + result = instance._execute_mutate_rows(batch) + callback.assert_called_once_with( + [ + status_pb2.Status(code=code_pb2.OK), + status_pb2.Status( + code=code_pb2.DATA_LOSS, message="test error" + ), + status_pb2.Status( + code=code_pb2.DATA_LOSS, message="test error" + ), + ] + ) assert len(result) == 2 assert result[0] == err1 assert result[1] == err2 diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 20ad972d7..ad8b2d298 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -16,8 +16,11 @@ import grpc from google.api_core import exceptions as core_exceptions import google.cloud.bigtable.data._helpers as _helpers +import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.rpc import code_pb2, status_pb2 + import mock @@ -264,6 +267,71 @@ def test_rst_stream_aware_predicate( assert predicate(exception) is expected_is_retryable +class TestPopulateStatusesFromMutationExceptionGroup: + @pytest.mark.parametrize( + "cause_exc,expected_status", + [ + ( + core_exceptions.DeadlineExceeded( + "Operation timed out after 40 seconds" + ), + status_pb2.Status( + code=code_pb2.DEADLINE_EXCEEDED, + message="Operation timed out after 40 seconds", + ), + ), + ( + RuntimeError("Something happened"), + status_pb2.Status(code=code_pb2.UNKNOWN, message="Something happened"), + ), + ( + bt_exceptions.RetryExceptionGroup( + excs=[ + core_exceptions.ServiceUnavailable("Service Unavailable"), + core_exceptions.ServiceUnavailable("Service Unavailable"), + core_exceptions.DeadlineExceeded( + "Operation timed out after 40 seconds" + ), + ] + ), + status_pb2.Status( + code=code_pb2.DEADLINE_EXCEEDED, + message="Operation timed out after 40 seconds", + ), + ), + ( + bt_exceptions.RetryExceptionGroup( + excs=[ + core_exceptions.ServiceUnavailable("Service Unavailable"), + core_exceptions.ServiceUnavailable("Service Unavailable"), + RuntimeError("Something happened"), + ] + ), + status_pb2.Status(code=code_pb2.UNKNOWN, message="Something happened"), + ), + ], + ) + def test_populate_statuses_from_mutation_exception_group( + self, cause_exc, expected_status + ): + statuses = [status_pb2.Status(code=code_pb2.OK)] + + mutation_exception_group = bt_exceptions.MutationsExceptionGroup( + excs=[ + bt_exceptions.FailedMutationEntryError( + failed_idx=0, failed_mutation_entry=mock.Mock(), cause=cause_exc + ) + ], + total_entries=1, + message="Mutations failed.", + ) + + _helpers._populate_statuses_from_mutations_exception_group( + statuses, mutation_exception_group + ) + assert statuses[0] == expected_status + + class TestGetRetryableErrors: @pytest.mark.parametrize( "input_codes,input_table,expected", From 4e3d8bba1eb46c8cadf765e5c3cc323945bf9525 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Mon, 9 Mar 2026 16:39:52 +0000 Subject: [PATCH 2/9] Added input vvalidation + unit tests --- google/cloud/bigtable/data/_helpers.py | 11 ++++---- .../data/_async/test_mutations_batcher.py | 8 +++--- .../_sync_autogen/test_mutations_batcher.py | 8 +++--- tests/unit/data/test__helpers.py | 25 +++++++++++++++++++ 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 5a6dd888d..eaa5a2b36 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -251,11 +251,12 @@ def _populate_statuses_from_mutations_exception_group( # 3. In the case of a RetryExceptionGroup, we use terminal exception in the exception # group and process that. for error in exc_group.exceptions: - cause = error.__cause__ - if isinstance(cause, RetryExceptionGroup): - statuses[error.index] = _get_status(cause.exceptions[-1]) - else: - statuses[error.index] = _get_status(cause) + if isinstance(error.index, int) and 0 <= error.index < len(statuses): + cause = error.__cause__ + if isinstance(cause, RetryExceptionGroup): + statuses[error.index] = _get_status(cause.exceptions[-1]) + else: + statuses[error.index] = _get_status(cause) def _get_status(exc: Exception) -> status_pb2.Status: diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 01300ec7b..c1a1289f9 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -941,8 +941,8 @@ async def test__execute_mutate_rows(self): assert args[0] == table.client._gapic_client assert args[1] == table assert args[2] == batch - kwargs["operation_timeout"] == 17 - kwargs["attempt_timeout"] == 13 + assert kwargs["operation_timeout"] == 17 + assert kwargs["attempt_timeout"] == 13 assert result == [] @CrossSync.pytest @@ -996,8 +996,8 @@ async def test__execute_mutate_rows_batch_completed_callback(self): assert args[0] == table.client._gapic_client assert args[1] == table assert args[2] == batch - kwargs["operation_timeout"] == 17 - kwargs["attempt_timeout"] == 13 + assert kwargs["operation_timeout"] == 17 + assert kwargs["attempt_timeout"] == 13 assert result == [] @CrossSync.pytest diff --git a/tests/unit/data/_sync_autogen/test_mutations_batcher.py b/tests/unit/data/_sync_autogen/test_mutations_batcher.py index 6dfa3b9fa..d7b47ae9d 100644 --- a/tests/unit/data/_sync_autogen/test_mutations_batcher.py +++ b/tests/unit/data/_sync_autogen/test_mutations_batcher.py @@ -821,8 +821,8 @@ def test__execute_mutate_rows(self): assert args[0] == table.client._gapic_client assert args[1] == table assert args[2] == batch - kwargs["operation_timeout"] == 17 - kwargs["attempt_timeout"] == 13 + assert kwargs["operation_timeout"] == 17 + assert kwargs["attempt_timeout"] == 13 assert result == [] def test__execute_mutate_rows_returns_errors(self): @@ -875,8 +875,8 @@ def test__execute_mutate_rows_batch_completed_callback(self): assert args[0] == table.client._gapic_client assert args[1] == table assert args[2] == batch - kwargs["operation_timeout"] == 17 - kwargs["attempt_timeout"] == 13 + assert kwargs["operation_timeout"] == 17 + assert kwargs["attempt_timeout"] == 13 assert result == [] def test__execute_mutate_rows_batch_completed_callback_errors(self): diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index ad8b2d298..6eeebe6e1 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -331,6 +331,31 @@ def test_populate_statuses_from_mutation_exception_group( ) assert statuses[0] == expected_status + @pytest.mark.parametrize( + "index", + [ + 100, + None, + ], + ) + def test_populate_statuses_from_mutation_exception_group_out_of_bounds(self, index): + statuses = [status_pb2.Status(code=code_pb2.OK)] + + mutation_exception_group = bt_exceptions.MutationsExceptionGroup( + excs=[ + bt_exceptions.FailedMutationEntryError( + failed_idx=index, failed_mutation_entry=mock.Mock(), cause=Exception("Boom!") + ) + ], + total_entries=1, + message="Mutations failed.", + ) + + _helpers._populate_statuses_from_mutations_exception_group( + statuses, mutation_exception_group + ) + assert statuses[0] == status_pb2.Status(code=code_pb2.OK) + class TestGetRetryableErrors: @pytest.mark.parametrize( From f03977c304a04c3c0b5fd8dc930c20f0eb1cff4e Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Mon, 9 Mar 2026 19:07:53 +0000 Subject: [PATCH 3/9] linting --- tests/unit/data/test__helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 6eeebe6e1..78662e584 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -344,7 +344,9 @@ def test_populate_statuses_from_mutation_exception_group_out_of_bounds(self, ind mutation_exception_group = bt_exceptions.MutationsExceptionGroup( excs=[ bt_exceptions.FailedMutationEntryError( - failed_idx=index, failed_mutation_entry=mock.Mock(), cause=Exception("Boom!") + failed_idx=index, + failed_mutation_entry=mock.Mock(), + cause=Exception("Boom!"), ) ], total_entries=1, From 701eebb9fd65009dd5acc4d2ed5cad02e11ed3f3 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Mon, 9 Mar 2026 21:15:35 +0000 Subject: [PATCH 4/9] fixed mypy --- google/cloud/bigtable/data/_async/client.py | 4 +++- google/cloud/bigtable/data/_async/mutations_batcher.py | 4 +++- google/cloud/bigtable/data/_helpers.py | 2 +- google/cloud/bigtable/data/_sync_autogen/client.py | 4 +++- google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py | 4 +++- 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 957faf7a1..34666baaa 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -1439,7 +1439,9 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - _batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None, + _batch_completed_callback: Optional[ + Callable[[list[status_pb2.Status]], None] + ] = None, ) -> "MutationsBatcherAsync": """ Returns a new mutations batcher instance. diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 2fe9cab94..d76547844 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -230,7 +230,9 @@ def __init__( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - _batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None, + _batch_completed_callback: Optional[ + Callable[[list[status_pb2.Status]], None] + ] = None, ): self._operation_timeout, self._attempt_timeout = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, target diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index eaa5a2b36..63329b5ec 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -272,7 +272,7 @@ def _get_status(exc: Exception) -> status_pb2.Status: isinstance(exc, core_exceptions.GoogleAPICallError) and exc.grpc_status_code is not None ): - return status_pb2.Status( + return status_pb2.Status( # type: ignore[unreachable] code=exc.grpc_status_code.value[0], message=exc.message, details=exc.details, diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 78c5969da..96a71d2ce 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -1191,7 +1191,9 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - _batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None, + _batch_completed_callback: Optional[ + Callable[[list[status_pb2.Status]], None] + ] = None, ) -> "MutationsBatcher": """Returns a new mutations batcher instance. diff --git a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py index a9ae22d42..79122e770 100644 --- a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py @@ -197,7 +197,9 @@ def __init__( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - _batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None, + _batch_completed_callback: Optional[ + Callable[[list[status_pb2.Status]], None] + ] = None, ): (self._operation_timeout, self._attempt_timeout) = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, target From 8320f9dbb63b34ebeb37b4ec2030cd9d713a9ecb Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Tue, 10 Mar 2026 17:00:49 +0000 Subject: [PATCH 5/9] Fixed helper function and mutations batcher constructor --- google/cloud/bigtable/data/_async/client.py | 3 - .../bigtable/data/_async/mutations_batcher.py | 13 ++-- google/cloud/bigtable/data/_helpers.py | 12 ++-- .../bigtable/data/_sync_autogen/client.py | 6 +- .../data/_sync_autogen/mutations_batcher.py | 15 +++-- google/cloud/bigtable/table.py | 15 +++-- .../data/_async/test_mutations_batcher.py | 12 ++-- .../_sync_autogen/test_mutations_batcher.py | 8 ++- tests/unit/data/test__helpers.py | 59 ++++++++----------- 9 files changed, 64 insertions(+), 79 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 34666baaa..be62dcd28 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -1439,9 +1439,6 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - _batch_completed_callback: Optional[ - Callable[[list[status_pb2.Status]], None] - ] = None, ) -> "MutationsBatcherAsync": """ Returns a new mutations batcher instance. diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index d76547844..6d6aefc9a 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -25,7 +25,7 @@ from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import ( - _populate_statuses_from_mutations_exception_group, + _get_statuses_from_mutations_exception_group, ) from google.cloud.bigtable.data._helpers import TABLE_DEFAULT @@ -230,9 +230,6 @@ def __init__( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - _batch_completed_callback: Optional[ - Callable[[list[status_pb2.Status]], None] - ] = None, ): self._operation_timeout, self._attempt_timeout = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, target @@ -279,7 +276,7 @@ def __init__( self._newest_exceptions: deque[Exception] = deque( maxlen=self._exception_list_limit ) - self._user_batch_completed_callback = _batch_completed_callback + self._user_batch_completed_callback = None # clean up on program exit atexit.register(self._on_exit) @@ -391,7 +388,7 @@ async def _execute_mutate_rows( list of FailedMutationEntryError objects for mutations that failed. FailedMutationEntryError objects will not contain index information """ - statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch) + statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len(batch) try: operation = CrossSync._MutateRowsOperation( self._target.client._gapic_client, @@ -403,12 +400,14 @@ async def _execute_mutate_rows( ) await operation.start() except MutationsExceptionGroup as e: - _populate_statuses_from_mutations_exception_group(statuses, e) + statuses = _get_statuses_from_mutations_exception_group(e, len(batch)) # strip index information from exceptions, since it is not useful in a batch context for subexc in e.exceptions: subexc.index = None return list(e.exceptions) + else: + statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch) finally: # mark batch as complete in flow control await self._flow_control.remove_from_flow(batch) diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 63329b5ec..6ecbd6662 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -228,16 +228,18 @@ def _align_timeouts(operation: float, attempt: float | None) -> tuple[float, flo return operation, final_attempt -def _populate_statuses_from_mutations_exception_group( - statuses: list[status_pb2.Status], exc_group: MutationsExceptionGroup -): +def _get_statuses_from_mutations_exception_group( + exc_group: MutationsExceptionGroup, batch_size: int +) -> list[status_pb2.Status]: """ Helper function that populates a list of Status objects with exception information from the exception group. Args: - statuses: The initial list of Status objects exc_group: The exception group from a mutate rows operation + batch_size: How many RowMutationGroups were provided to the batch + Returns: + list[status_pb2.Status]: A list of Status proto objects """ # We exception handle as follows: # @@ -250,6 +252,7 @@ def _populate_statuses_from_mutations_exception_group( # # 3. In the case of a RetryExceptionGroup, we use terminal exception in the exception # group and process that. + statuses = [status_pb2.Status(code=code_pb2.OK)] * batch_size for error in exc_group.exceptions: if isinstance(error.index, int) and 0 <= error.index < len(statuses): cause = error.__cause__ @@ -257,6 +260,7 @@ def _populate_statuses_from_mutations_exception_group( statuses[error.index] = _get_status(cause.exceptions[-1]) else: statuses[error.index] = _get_status(cause) + return statuses def _get_status(exc: Exception) -> status_pb2.Status: diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 96a71d2ce..9c5b90adf 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -17,7 +17,7 @@ # This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations -from typing import Callable, cast, Any, Optional, Set, Sequence, TYPE_CHECKING +from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING import abc import time import warnings @@ -87,7 +87,6 @@ if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery - from google.rpc import status_pb2 from google.cloud.bigtable.data._sync_autogen.mutations_batcher import ( MutationsBatcher, ) @@ -1191,9 +1190,6 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - _batch_completed_callback: Optional[ - Callable[[list[status_pb2.Status]], None] - ] = None, ) -> "MutationsBatcher": """Returns a new mutations batcher instance. diff --git a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py index 79122e770..c72a79fc7 100644 --- a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py @@ -16,7 +16,7 @@ # This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations -from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast +from typing import Sequence, TYPE_CHECKING, cast import atexit import warnings from collections import deque @@ -26,7 +26,7 @@ from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import ( - _populate_statuses_from_mutations_exception_group, + _get_statuses_from_mutations_exception_group, ) from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT @@ -197,9 +197,6 @@ def __init__( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - _batch_completed_callback: Optional[ - Callable[[list[status_pb2.Status]], None] - ] = None, ): (self._operation_timeout, self._attempt_timeout) = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, target @@ -241,7 +238,7 @@ def __init__( self._newest_exceptions: deque[Exception] = deque( maxlen=self._exception_list_limit ) - self._user_batch_completed_callback = _batch_completed_callback + self._user_batch_completed_callback = None atexit.register(self._on_exit) def _timer_routine(self, interval: float | None) -> None: @@ -333,7 +330,7 @@ def _execute_mutate_rows( list[FailedMutationEntryError]: list of FailedMutationEntryError objects for mutations that failed. FailedMutationEntryError objects will not contain index information""" - statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch) + statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len(batch) try: operation = CrossSync._Sync_Impl._MutateRowsOperation( self._target.client._gapic_client, @@ -345,10 +342,12 @@ def _execute_mutate_rows( ) operation.start() except MutationsExceptionGroup as e: - _populate_statuses_from_mutations_exception_group(statuses, e) + statuses = _get_statuses_from_mutations_exception_group(e, len(batch)) for subexc in e.exceptions: subexc.index = None return list(e.exceptions) + else: + statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch) finally: self._flow_control.remove_from_flow(batch) if self._user_batch_completed_callback: diff --git a/google/cloud/bigtable/table.py b/google/cloud/bigtable/table.py index b478d3f01..5a99999bc 100644 --- a/google/cloud/bigtable/table.py +++ b/google/cloud/bigtable/table.py @@ -31,7 +31,7 @@ from google.cloud.bigtable.column_family import ColumnFamily from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import ( - _populate_statuses_from_mutations_exception_group, + _get_statuses_from_mutations_exception_group, ) from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -761,9 +761,9 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT): mutation_entries = [ RowMutationEntry(row.row_key, row._get_mutations()) for row in rows ] - return_statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len( + return_statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len( mutation_entries - ) # By default, return status OKs for everything + ) try: self._table_impl.bulk_mutate_rows( @@ -773,9 +773,12 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT): retryable_errors=retryable_errors, ) except MutationsExceptionGroup as mut_exc_group: - _populate_statuses_from_mutations_exception_group( - return_statuses, - mut_exc_group, + return_statuses = _get_statuses_from_mutations_exception_group( + mut_exc_group, len(mutation_entries) + ) + else: + return_statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len( + mutation_entries ) return return_statuses diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index c1a1289f9..9a06d692f 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -483,7 +483,7 @@ def test_default_argument_consistency(self): batcher_init_signature.pop("target") # both should have same number of arguments assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) - assert len(get_batcher_signature) == 9 # update if expected params change + assert len(get_batcher_signature) == 8 # update if expected params change # both should have same argument names assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) # both should have same default values @@ -985,9 +985,8 @@ async def test__execute_mutate_rows_batch_completed_callback(self): table.default_mutate_rows_attempt_timeout = 13 table.default_mutate_rows_retryable_errors = () callback = mock.Mock() - async with self._make_one( - table, _batch_completed_callback=callback - ) as instance: + async with self._make_one(table) as instance: + instance._user_batch_completed_callback = callback batch = [self._make_mutation()] result = await instance._execute_mutate_rows(batch) callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)]) @@ -1022,9 +1021,8 @@ async def test__execute_mutate_rows_batch_completed_callback_errors(self): table.default_mutate_rows_attempt_timeout = 13 table.default_mutate_rows_retryable_errors = () callback = mock.Mock() - async with self._make_one( - table, _batch_completed_callback=callback - ) as instance: + async with self._make_one(table) as instance: + instance._user_batch_completed_callback = callback batch = [ self._make_mutation(), self._make_mutation(), diff --git a/tests/unit/data/_sync_autogen/test_mutations_batcher.py b/tests/unit/data/_sync_autogen/test_mutations_batcher.py index d7b47ae9d..413c7b955 100644 --- a/tests/unit/data/_sync_autogen/test_mutations_batcher.py +++ b/tests/unit/data/_sync_autogen/test_mutations_batcher.py @@ -430,7 +430,7 @@ def test_default_argument_consistency(self): ) batcher_init_signature.pop("target") assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) - assert len(get_batcher_signature) == 9 + assert len(get_batcher_signature) == 8 assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) for arg_name in get_batcher_signature.keys(): assert ( @@ -866,7 +866,8 @@ def test__execute_mutate_rows_batch_completed_callback(self): table.default_mutate_rows_attempt_timeout = 13 table.default_mutate_rows_retryable_errors = () callback = mock.Mock() - with self._make_one(table, _batch_completed_callback=callback) as instance: + with self._make_one(table) as instance: + instance._user_batch_completed_callback = callback batch = [self._make_mutation()] result = instance._execute_mutate_rows(batch) callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)]) @@ -902,7 +903,8 @@ def test__execute_mutate_rows_batch_completed_callback_errors(self): table.default_mutate_rows_attempt_timeout = 13 table.default_mutate_rows_retryable_errors = () callback = mock.Mock() - with self._make_one(table, _batch_completed_callback=callback) as instance: + with self._make_one(table) as instance: + instance._user_batch_completed_callback = callback batch = [ self._make_mutation(), self._make_mutation(), diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 78662e584..07d0692f1 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -267,11 +267,12 @@ def test_rst_stream_aware_predicate( assert predicate(exception) is expected_is_retryable -class TestPopulateStatusesFromMutationExceptionGroup: +class TestGetStatusesFromMutationsExceptionGroup: @pytest.mark.parametrize( - "cause_exc,expected_status", + "failed_idx,cause_exc,expected_status", [ ( + 0, core_exceptions.DeadlineExceeded( "Operation timed out after 40 seconds" ), @@ -281,10 +282,12 @@ class TestPopulateStatusesFromMutationExceptionGroup: ), ), ( + 0, RuntimeError("Something happened"), status_pb2.Status(code=code_pb2.UNKNOWN, message="Something happened"), ), ( + 0, bt_exceptions.RetryExceptionGroup( excs=[ core_exceptions.ServiceUnavailable("Service Unavailable"), @@ -300,6 +303,7 @@ class TestPopulateStatusesFromMutationExceptionGroup: ), ), ( + 0, bt_exceptions.RetryExceptionGroup( excs=[ core_exceptions.ServiceUnavailable("Service Unavailable"), @@ -309,54 +313,37 @@ class TestPopulateStatusesFromMutationExceptionGroup: ), status_pb2.Status(code=code_pb2.UNKNOWN, message="Something happened"), ), + ( + 100, + RuntimeError("Something happened"), + status_pb2.Status(code=code_pb2.OK), + ), + ( + None, + RuntimeError("Something happened"), + status_pb2.Status(code=code_pb2.OK), + ), ], ) - def test_populate_statuses_from_mutation_exception_group( - self, cause_exc, expected_status + def test_get_statuses_from_mutations_exception_group( + self, failed_idx, cause_exc, expected_status ): - statuses = [status_pb2.Status(code=code_pb2.OK)] - mutation_exception_group = bt_exceptions.MutationsExceptionGroup( excs=[ bt_exceptions.FailedMutationEntryError( - failed_idx=0, failed_mutation_entry=mock.Mock(), cause=cause_exc - ) - ], - total_entries=1, - message="Mutations failed.", - ) - - _helpers._populate_statuses_from_mutations_exception_group( - statuses, mutation_exception_group - ) - assert statuses[0] == expected_status - - @pytest.mark.parametrize( - "index", - [ - 100, - None, - ], - ) - def test_populate_statuses_from_mutation_exception_group_out_of_bounds(self, index): - statuses = [status_pb2.Status(code=code_pb2.OK)] - - mutation_exception_group = bt_exceptions.MutationsExceptionGroup( - excs=[ - bt_exceptions.FailedMutationEntryError( - failed_idx=index, + failed_idx=failed_idx, failed_mutation_entry=mock.Mock(), - cause=Exception("Boom!"), + cause=cause_exc, ) ], total_entries=1, message="Mutations failed.", ) - _helpers._populate_statuses_from_mutations_exception_group( - statuses, mutation_exception_group + statuses = _helpers._get_statuses_from_mutations_exception_group( + mutation_exception_group, 1 ) - assert statuses[0] == status_pb2.Status(code=code_pb2.OK) + assert statuses[0] == expected_status class TestGetRetryableErrors: From 0741d0b397a251d4df6d57c203b059ee8c5d5df7 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Tue, 10 Mar 2026 18:38:27 +0000 Subject: [PATCH 6/9] None case in _get_status --- google/cloud/bigtable/data/_async/client.py | 1 - google/cloud/bigtable/data/_helpers.py | 6 +++--- .../cloud/bigtable/data/_sync_autogen/client.py | 1 - tests/unit/data/test__helpers.py | 16 ++++++++++++++++ 4 files changed, 19 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index be62dcd28..61f932735 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -1474,7 +1474,6 @@ def mutations_batcher( batch_operation_timeout=batch_operation_timeout, batch_attempt_timeout=batch_attempt_timeout, batch_retryable_errors=batch_retryable_errors, - _batch_completed_callback=_batch_completed_callback, ) @CrossSync.convert diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 6ecbd6662..0f411f88a 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -16,7 +16,7 @@ """ from __future__ import annotations -from typing import Callable, Sequence, List, Tuple, TYPE_CHECKING, Union +from typing import Callable, Sequence, List, Optional, Tuple, TYPE_CHECKING, Union import time import enum from collections import namedtuple @@ -263,7 +263,7 @@ def _get_statuses_from_mutations_exception_group( return statuses -def _get_status(exc: Exception) -> status_pb2.Status: +def _get_status(exc: Optional[Exception]) -> status_pb2.Status: """ Helper function that returns a Status object corresponding to the given exception. @@ -284,7 +284,7 @@ def _get_status(exc: Exception) -> status_pb2.Status: return status_pb2.Status( code=code_pb2.Code.UNKNOWN, - message=str(exc), + message=str(exc) if exc else "An unknown error has occurred", ) diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 9c5b90adf..88136ddad 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -1224,7 +1224,6 @@ def mutations_batcher( batch_operation_timeout=batch_operation_timeout, batch_attempt_timeout=batch_attempt_timeout, batch_retryable_errors=batch_retryable_errors, - _batch_completed_callback=_batch_completed_callback, ) def mutate_row( diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 07d0692f1..875d3c9e1 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -313,6 +313,22 @@ class TestGetStatusesFromMutationsExceptionGroup: ), status_pb2.Status(code=code_pb2.UNKNOWN, message="Something happened"), ), + ( + 0, + bt_exceptions.RetryExceptionGroup( + excs=[ + core_exceptions.ServiceUnavailable("Service Unavailable"), + core_exceptions.ServiceUnavailable("Service Unavailable"), + None, + ] + ), + status_pb2.Status(code=code_pb2.UNKNOWN, message="An unknown error has occurred"), + ), + ( + 0, + None, + status_pb2.Status(code=code_pb2.UNKNOWN, message="An unknown error has occurred"), + ), ( 100, RuntimeError("Something happened"), From 89935392c71874786f744a6d5839bd3c0edcb3aa Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Tue, 10 Mar 2026 18:51:52 +0000 Subject: [PATCH 7/9] Fixed unit tests --- tests/unit/data/test__helpers.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 875d3c9e1..7cd450295 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -313,17 +313,6 @@ class TestGetStatusesFromMutationsExceptionGroup: ), status_pb2.Status(code=code_pb2.UNKNOWN, message="Something happened"), ), - ( - 0, - bt_exceptions.RetryExceptionGroup( - excs=[ - core_exceptions.ServiceUnavailable("Service Unavailable"), - core_exceptions.ServiceUnavailable("Service Unavailable"), - None, - ] - ), - status_pb2.Status(code=code_pb2.UNKNOWN, message="An unknown error has occurred"), - ), ( 0, None, From 20fb4cb8a225d441beb2d1e850a42141a50d97ca Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Tue, 10 Mar 2026 19:25:48 +0000 Subject: [PATCH 8/9] fixed system tests --- tests/system/data/test_system_async.py | 5 ++--- tests/system/data/test_system_autogen.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 4aa589027..69680326d 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -507,9 +507,8 @@ async def test_mutations_batcher_completed_callback( ) bulk_mutation = RowMutationEntry(row_key, [mutation]) flush_interval = 0.1 - async with target.mutations_batcher( - flush_interval=flush_interval, _batch_completed_callback=callback - ) as batcher: + async with target.mutations_batcher(flush_interval=flush_interval) as batcher: + batcher._user_batch_completed_callback = callback await batcher.append(bulk_mutation) await CrossSync.yield_to_event_loop() assert len(batcher._staged_entries) == 1 diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 9a8f522d8..0b954fd66 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -403,9 +403,8 @@ def test_mutations_batcher_completed_callback(self, client, target, temp_rows): ) bulk_mutation = RowMutationEntry(row_key, [mutation]) flush_interval = 0.1 - with target.mutations_batcher( - flush_interval=flush_interval, _batch_completed_callback=callback - ) as batcher: + with target.mutations_batcher(flush_interval=flush_interval) as batcher: + batcher._user_batch_completed_callback = callback batcher.append(bulk_mutation) CrossSync._Sync_Impl.yield_to_event_loop() assert len(batcher._staged_entries) == 1 From c4da6eb92fb8c82a0d26ac3248370df0c5ed0037 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Tue, 10 Mar 2026 20:16:35 +0000 Subject: [PATCH 9/9] linting --- tests/unit/data/test__helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 7cd450295..6b6ecdb3a 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -316,7 +316,9 @@ class TestGetStatusesFromMutationsExceptionGroup: ( 0, None, - status_pb2.Status(code=code_pb2.UNKNOWN, message="An unknown error has occurred"), + status_pb2.Status( + code=code_pb2.UNKNOWN, message="An unknown error has occurred" + ), ), ( 100,