diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 35fe42814..61f932735 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -16,6 +16,7 @@ from __future__ import annotations from typing import ( + Callable, cast, Any, AsyncIterable, @@ -115,6 +116,7 @@ if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery + from google.rpc import status_pb2 if CrossSync.is_async: from google.cloud.bigtable.data._async.mutations_batcher import ( diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 6d87ff5d2..6d6aefc9a 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -14,7 +14,7 @@ # from __future__ import annotations -from typing import Sequence, TYPE_CHECKING, cast +from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast import atexit import warnings from collections import deque @@ -24,6 +24,10 @@ from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import ( + _get_statuses_from_mutations_exception_group, +) + from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data.mutations import ( @@ -33,6 +37,9 @@ from google.cloud.bigtable.data._cross_sync import CrossSync +from google.rpc import code_pb2 +from google.rpc import status_pb2 + if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -269,6 +276,7 @@ def __init__( self._newest_exceptions: deque[Exception] = deque( maxlen=self._exception_list_limit ) + self._user_batch_completed_callback = None # clean up on program exit atexit.register(self._on_exit) @@ -380,6 +388,7 @@ async def _execute_mutate_rows( list of FailedMutationEntryError objects for mutations that failed. FailedMutationEntryError objects will not contain index information """ + statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len(batch) try: operation = CrossSync._MutateRowsOperation( self._target.client._gapic_client, @@ -391,13 +400,21 @@ async def _execute_mutate_rows( ) await operation.start() except MutationsExceptionGroup as e: + statuses = _get_statuses_from_mutations_exception_group(e, len(batch)) + # strip index information from exceptions, since it is not useful in a batch context for subexc in e.exceptions: subexc.index = None return list(e.exceptions) + else: + statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch) finally: # mark batch as complete in flow control await self._flow_control.remove_from_flow(batch) + + # Call batch done callback with list of statuses. + if self._user_batch_completed_callback: + self._user_batch_completed_callback(statuses) return [] def _add_exceptions(self, excs: list[Exception]): diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index f88fa3078..0f411f88a 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -16,7 +16,7 @@ """ from __future__ import annotations -from typing import Callable, Sequence, List, Tuple, TYPE_CHECKING, Union +from typing import Callable, Sequence, List, Optional, Tuple, TYPE_CHECKING, Union import time import enum from collections import namedtuple @@ -26,6 +26,10 @@ from google.api_core import retry as retries from google.api_core.retry import RetryFailureReason from google.cloud.bigtable.data.exceptions import RetryExceptionGroup +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.rpc import code_pb2 +from google.rpc import status_pb2 + if TYPE_CHECKING: import grpc @@ -224,6 +228,66 @@ def _align_timeouts(operation: float, attempt: float | None) -> tuple[float, flo return operation, final_attempt +def _get_statuses_from_mutations_exception_group( + exc_group: MutationsExceptionGroup, batch_size: int +) -> list[status_pb2.Status]: + """ + Helper function that populates a list of Status objects with exception information from + the exception group. + + Args: + exc_group: The exception group from a mutate rows operation + batch_size: How many RowMutationGroups were provided to the batch + Returns: + list[status_pb2.Status]: A list of Status proto objects + """ + # We exception handle as follows: + # + # 1. Each exception in the error group is a FailedMutationEntryError, and its + # cause is either a singular exception or a RetryExceptionGroup consisting of + # multiple exceptions. + # + # 2. In the case of a singular exception, if the error does not have a gRPC status + # code, we return a status code of UNKNOWN. + # + # 3. In the case of a RetryExceptionGroup, we use terminal exception in the exception + # group and process that. + statuses = [status_pb2.Status(code=code_pb2.OK)] * batch_size + for error in exc_group.exceptions: + if isinstance(error.index, int) and 0 <= error.index < len(statuses): + cause = error.__cause__ + if isinstance(cause, RetryExceptionGroup): + statuses[error.index] = _get_status(cause.exceptions[-1]) + else: + statuses[error.index] = _get_status(cause) + return statuses + + +def _get_status(exc: Optional[Exception]) -> status_pb2.Status: + """ + Helper function that returns a Status object corresponding to the given exception. + + Args: + exc: An exception to be converted into a Status. + Returns: + status_pb2.Status: A Status proto object. + """ + if ( + isinstance(exc, core_exceptions.GoogleAPICallError) + and exc.grpc_status_code is not None + ): + return status_pb2.Status( # type: ignore[unreachable] + code=exc.grpc_status_code.value[0], + message=exc.message, + details=exc.details, + ) + + return status_pb2.Status( + code=code_pb2.Code.UNKNOWN, + message=str(exc) if exc else "An unknown error has occurred", + ) + + def _validate_timeouts( operation_timeout: float, attempt_timeout: float | None, allow_none: bool = False ): diff --git a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py index c14606de9..c72a79fc7 100644 --- a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py @@ -25,10 +25,15 @@ from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import ( + _get_statuses_from_mutations_exception_group, +) from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT from google.cloud.bigtable.data.mutations import Mutation from google.cloud.bigtable.data._cross_sync import CrossSync +from google.rpc import code_pb2 +from google.rpc import status_pb2 if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -233,6 +238,7 @@ def __init__( self._newest_exceptions: deque[Exception] = deque( maxlen=self._exception_list_limit ) + self._user_batch_completed_callback = None atexit.register(self._on_exit) def _timer_routine(self, interval: float | None) -> None: @@ -324,6 +330,7 @@ def _execute_mutate_rows( list[FailedMutationEntryError]: list of FailedMutationEntryError objects for mutations that failed. FailedMutationEntryError objects will not contain index information""" + statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len(batch) try: operation = CrossSync._Sync_Impl._MutateRowsOperation( self._target.client._gapic_client, @@ -335,11 +342,16 @@ def _execute_mutate_rows( ) operation.start() except MutationsExceptionGroup as e: + statuses = _get_statuses_from_mutations_exception_group(e, len(batch)) for subexc in e.exceptions: subexc.index = None return list(e.exceptions) + else: + statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch) finally: self._flow_control.remove_from_flow(batch) + if self._user_batch_completed_callback: + self._user_batch_completed_callback(statuses) return [] def _add_exceptions(self, excs: list[Exception]): diff --git a/google/cloud/bigtable/table.py b/google/cloud/bigtable/table.py index 465383b60..ee1d335ea 100644 --- a/google/cloud/bigtable/table.py +++ b/google/cloud/bigtable/table.py @@ -17,7 +17,6 @@ from typing import Set import warnings -from google.api_core.exceptions import GoogleAPICallError from google.api_core.exceptions import Aborted from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import NotFound @@ -31,12 +30,12 @@ from google.cloud.bigtable.column_family import _gc_rule_from_pb from google.cloud.bigtable.column_family import ColumnFamily from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - MutationsExceptionGroup, +from google.cloud.bigtable.data._helpers import ( + _get_statuses_from_mutations_exception_group, ) -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.batcher import MutationsBatcher from google.cloud.bigtable.batcher import FLUSH_COUNT, MAX_MUTATION_SIZE from google.cloud.bigtable.encryption_info import EncryptionInfo @@ -767,9 +766,9 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT): mutation_entries = [ RowMutationEntry(row.row_key, row._get_mutations()) for row in rows ] - return_statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len( + return_statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len( mutation_entries - ) # By default, return status OKs for everything + ) try: self._table_impl.bulk_mutate_rows( @@ -779,41 +778,15 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT): retryable_errors=retryable_errors, ) except MutationsExceptionGroup as mut_exc_group: - # We exception handle as follows: - # - # 1. Each exception in the error group is a FailedMutationEntryError, and its - # cause is either a singular exception or a RetryExceptionGroup consisting of - # multiple exceptions. - # - # 2. In the case of a singular exception, if the error does not have a gRPC status - # code, we return a status code of UNKNOWN. - # - # 3. In the case of a RetryExceptionGroup, we use terminal exception in the exception - # group and process that. - for error in mut_exc_group.exceptions: - cause = error.__cause__ - if isinstance(cause, RetryExceptionGroup): - return_statuses[error.index] = self._get_status( - cause.exceptions[-1] - ) - else: - return_statuses[error.index] = self._get_status(cause) - - return return_statuses - - @staticmethod - def _get_status(error): - if isinstance(error, GoogleAPICallError) and error.grpc_status_code is not None: - return status_pb2.Status( - code=error.grpc_status_code.value[0], - message=error.message, - details=error.details, + return_statuses = _get_statuses_from_mutations_exception_group( + mut_exc_group, len(mutation_entries) + ) + else: + return_statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len( + mutation_entries ) - return status_pb2.Status( - code=code_pb2.Code.UNKNOWN, - message=str(error), - ) + return return_statuses def sample_row_keys(self): """Read a sample of row keys in the table. diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index c96570b76..69680326d 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -482,6 +482,42 @@ async def test_mutations_batcher_timer_flush(self, client, target, temp_rows): # ensure cell is updated assert (await self._retrieve_cell_value(target, row_key)) == new_value + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("target") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_completed_callback( + self, client, target, temp_rows + ): + """ + test batcher with batch completed callback. It should be called when the batcher flushes. + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.rpc import code_pb2, status_pb2 + + import mock + + callback = mock.Mock() + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + target, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + async with target.mutations_batcher(flush_interval=flush_interval) as batcher: + batcher._user_batch_completed_callback = callback + await batcher.append(bulk_mutation) + await CrossSync.yield_to_event_loop() + assert len(batcher._staged_entries) == 1 + await CrossSync.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)]) + # ensure cell is updated + assert (await self._retrieve_cell_value(target, row_key)) == new_value + @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("target") @CrossSync.Retry( diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 44895808a..0b954fd66 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -385,6 +385,34 @@ def test_mutations_batcher_timer_flush(self, client, target, temp_rows): assert len(batcher._staged_entries) == 0 assert self._retrieve_cell_value(target, row_key) == new_value + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("target") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_completed_callback(self, client, target, temp_rows): + """test batcher with batch completed callback. It should be called when the batcher flushes.""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.rpc import code_pb2, status_pb2 + import mock + + callback = mock.Mock() + new_value = uuid.uuid4().hex.encode() + (row_key, mutation) = self._create_row_and_mutation( + target, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + with target.mutations_batcher(flush_interval=flush_interval) as batcher: + batcher._user_batch_completed_callback = callback + batcher.append(bulk_mutation) + CrossSync._Sync_Impl.yield_to_event_loop() + assert len(batcher._staged_entries) == 1 + CrossSync._Sync_Impl.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)]) + assert self._retrieve_cell_value(target, row_key) == new_value + @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry( diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index e8df3eb42..9a06d692f 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -941,8 +941,8 @@ async def test__execute_mutate_rows(self): assert args[0] == table.client._gapic_client assert args[1] == table assert args[2] == batch - kwargs["operation_timeout"] == 17 - kwargs["attempt_timeout"] == 13 + assert kwargs["operation_timeout"] == 17 + assert kwargs["attempt_timeout"] == 13 assert result == [] @CrossSync.pytest @@ -962,8 +962,84 @@ async def test__execute_mutate_rows_returns_errors(self): table.default_mutate_rows_attempt_timeout = 13 table.default_mutate_rows_retryable_errors = () async with self._make_one(table) as instance: + batch = [self._make_mutation(), self._make_mutation()] + result = await instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + # indices should be set to None + assert result[0].index is None + assert result[1].index is None + + @CrossSync.pytest + async def test__execute_mutate_rows_batch_completed_callback(self): + from google.rpc import code_pb2, status_pb2 + + with mock.patch.object(CrossSync, "_MutateRowsOperation") as mutate_rows: + mutate_rows.return_value = CrossSync.Mock() + start_operation = mutate_rows().start + table = mock.Mock() + table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + callback = mock.Mock() + async with self._make_one(table) as instance: + instance._user_batch_completed_callback = callback batch = [self._make_mutation()] result = await instance._execute_mutate_rows(batch) + callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)]) + assert start_operation.call_count == 1 + args, kwargs = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + assert kwargs["operation_timeout"] == 17 + assert kwargs["attempt_timeout"] == 13 + assert result == [] + + @CrossSync.pytest + async def test__execute_mutate_rows_batch_completed_callback_errors(self): + from google.api_core import exceptions + from google.cloud.bigtable.data.exceptions import ( + MutationsExceptionGroup, + FailedMutationEntryError, + ) + from google.rpc import code_pb2, status_pb2 + + with mock.patch.object(CrossSync._MutateRowsOperation, "start") as mutate_rows: + err1 = FailedMutationEntryError( + 1, mock.Mock(), exceptions.DataLoss("test error") + ) + err2 = FailedMutationEntryError( + 2, mock.Mock(), exceptions.DataLoss("test error") + ) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + callback = mock.Mock() + async with self._make_one(table) as instance: + instance._user_batch_completed_callback = callback + batch = [ + self._make_mutation(), + self._make_mutation(), + self._make_mutation(), + ] + result = await instance._execute_mutate_rows(batch) + callback.assert_called_once_with( + [ + status_pb2.Status(code=code_pb2.OK), + status_pb2.Status( + code=code_pb2.DATA_LOSS, message="test error" + ), + status_pb2.Status( + code=code_pb2.DATA_LOSS, message="test error" + ), + ] + ) assert len(result) == 2 assert result[0] == err1 assert result[1] == err2 diff --git a/tests/unit/data/_sync_autogen/test_mutations_batcher.py b/tests/unit/data/_sync_autogen/test_mutations_batcher.py index 60a6708ba..413c7b955 100644 --- a/tests/unit/data/_sync_autogen/test_mutations_batcher.py +++ b/tests/unit/data/_sync_autogen/test_mutations_batcher.py @@ -821,8 +821,8 @@ def test__execute_mutate_rows(self): assert args[0] == table.client._gapic_client assert args[1] == table assert args[2] == batch - kwargs["operation_timeout"] == 17 - kwargs["attempt_timeout"] == 13 + assert kwargs["operation_timeout"] == 17 + assert kwargs["attempt_timeout"] == 13 assert result == [] def test__execute_mutate_rows_returns_errors(self): @@ -843,8 +843,85 @@ def test__execute_mutate_rows_returns_errors(self): table.default_mutate_rows_attempt_timeout = 13 table.default_mutate_rows_retryable_errors = () with self._make_one(table) as instance: + batch = [self._make_mutation(), self._make_mutation()] + result = instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + assert result[0].index is None + assert result[1].index is None + + def test__execute_mutate_rows_batch_completed_callback(self): + from google.rpc import code_pb2, status_pb2 + + with mock.patch.object( + CrossSync._Sync_Impl, "_MutateRowsOperation" + ) as mutate_rows: + mutate_rows.return_value = CrossSync._Sync_Impl.Mock() + start_operation = mutate_rows().start + table = mock.Mock() + table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + callback = mock.Mock() + with self._make_one(table) as instance: + instance._user_batch_completed_callback = callback batch = [self._make_mutation()] result = instance._execute_mutate_rows(batch) + callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)]) + assert start_operation.call_count == 1 + (args, kwargs) = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + assert kwargs["operation_timeout"] == 17 + assert kwargs["attempt_timeout"] == 13 + assert result == [] + + def test__execute_mutate_rows_batch_completed_callback_errors(self): + from google.api_core import exceptions + from google.cloud.bigtable.data.exceptions import ( + MutationsExceptionGroup, + FailedMutationEntryError, + ) + from google.rpc import code_pb2, status_pb2 + + with mock.patch.object( + CrossSync._Sync_Impl._MutateRowsOperation, "start" + ) as mutate_rows: + err1 = FailedMutationEntryError( + 1, mock.Mock(), exceptions.DataLoss("test error") + ) + err2 = FailedMutationEntryError( + 2, mock.Mock(), exceptions.DataLoss("test error") + ) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + callback = mock.Mock() + with self._make_one(table) as instance: + instance._user_batch_completed_callback = callback + batch = [ + self._make_mutation(), + self._make_mutation(), + self._make_mutation(), + ] + result = instance._execute_mutate_rows(batch) + callback.assert_called_once_with( + [ + status_pb2.Status(code=code_pb2.OK), + status_pb2.Status( + code=code_pb2.DATA_LOSS, message="test error" + ), + status_pb2.Status( + code=code_pb2.DATA_LOSS, message="test error" + ), + ] + ) assert len(result) == 2 assert result[0] == err1 assert result[1] == err2 diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 20ad972d7..6b6ecdb3a 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -16,8 +16,11 @@ import grpc from google.api_core import exceptions as core_exceptions import google.cloud.bigtable.data._helpers as _helpers +import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.rpc import code_pb2, status_pb2 + import mock @@ -264,6 +267,92 @@ def test_rst_stream_aware_predicate( assert predicate(exception) is expected_is_retryable +class TestGetStatusesFromMutationsExceptionGroup: + @pytest.mark.parametrize( + "failed_idx,cause_exc,expected_status", + [ + ( + 0, + core_exceptions.DeadlineExceeded( + "Operation timed out after 40 seconds" + ), + status_pb2.Status( + code=code_pb2.DEADLINE_EXCEEDED, + message="Operation timed out after 40 seconds", + ), + ), + ( + 0, + RuntimeError("Something happened"), + status_pb2.Status(code=code_pb2.UNKNOWN, message="Something happened"), + ), + ( + 0, + bt_exceptions.RetryExceptionGroup( + excs=[ + core_exceptions.ServiceUnavailable("Service Unavailable"), + core_exceptions.ServiceUnavailable("Service Unavailable"), + core_exceptions.DeadlineExceeded( + "Operation timed out after 40 seconds" + ), + ] + ), + status_pb2.Status( + code=code_pb2.DEADLINE_EXCEEDED, + message="Operation timed out after 40 seconds", + ), + ), + ( + 0, + bt_exceptions.RetryExceptionGroup( + excs=[ + core_exceptions.ServiceUnavailable("Service Unavailable"), + core_exceptions.ServiceUnavailable("Service Unavailable"), + RuntimeError("Something happened"), + ] + ), + status_pb2.Status(code=code_pb2.UNKNOWN, message="Something happened"), + ), + ( + 0, + None, + status_pb2.Status( + code=code_pb2.UNKNOWN, message="An unknown error has occurred" + ), + ), + ( + 100, + RuntimeError("Something happened"), + status_pb2.Status(code=code_pb2.OK), + ), + ( + None, + RuntimeError("Something happened"), + status_pb2.Status(code=code_pb2.OK), + ), + ], + ) + def test_get_statuses_from_mutations_exception_group( + self, failed_idx, cause_exc, expected_status + ): + mutation_exception_group = bt_exceptions.MutationsExceptionGroup( + excs=[ + bt_exceptions.FailedMutationEntryError( + failed_idx=failed_idx, + failed_mutation_entry=mock.Mock(), + cause=cause_exc, + ) + ], + total_entries=1, + message="Mutations failed.", + ) + + statuses = _helpers._get_statuses_from_mutations_exception_group( + mutation_exception_group, 1 + ) + assert statuses[0] == expected_status + + class TestGetRetryableErrors: @pytest.mark.parametrize( "input_codes,input_table,expected",