Skip to content
Merged
2 changes: 2 additions & 0 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from __future__ import annotations

from typing import (
Callable,
cast,
Any,
AsyncIterable,
Expand Down Expand Up @@ -115,6 +116,7 @@
if TYPE_CHECKING:
from google.cloud.bigtable.data._helpers import RowKeySamples
from google.cloud.bigtable.data._helpers import ShardedQuery
from google.rpc import status_pb2

if CrossSync.is_async:
from google.cloud.bigtable.data._async.mutations_batcher import (
Expand Down
19 changes: 18 additions & 1 deletion google/cloud/bigtable/data/_async/mutations_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
from __future__ import annotations

from typing import Sequence, TYPE_CHECKING, cast
from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast
import atexit
import warnings
from collections import deque
Expand All @@ -24,6 +24,10 @@
from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
from google.cloud.bigtable.data._helpers import _get_retryable_errors
from google.cloud.bigtable.data._helpers import _get_timeouts
from google.cloud.bigtable.data._helpers import (
_get_statuses_from_mutations_exception_group,
)

from google.cloud.bigtable.data._helpers import TABLE_DEFAULT

from google.cloud.bigtable.data.mutations import (
Expand All @@ -33,6 +37,9 @@

from google.cloud.bigtable.data._cross_sync import CrossSync

from google.rpc import code_pb2
from google.rpc import status_pb2

if TYPE_CHECKING:
from google.cloud.bigtable.data.mutations import RowMutationEntry

Expand Down Expand Up @@ -269,6 +276,7 @@ def __init__(
self._newest_exceptions: deque[Exception] = deque(
maxlen=self._exception_list_limit
)
self._user_batch_completed_callback = None
# clean up on program exit
atexit.register(self._on_exit)

Expand Down Expand Up @@ -380,6 +388,7 @@ async def _execute_mutate_rows(
list of FailedMutationEntryError objects for mutations that failed.
FailedMutationEntryError objects will not contain index information
"""
statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len(batch)
try:
operation = CrossSync._MutateRowsOperation(
self._target.client._gapic_client,
Expand All @@ -391,13 +400,21 @@ async def _execute_mutate_rows(
)
await operation.start()
except MutationsExceptionGroup as e:
statuses = _get_statuses_from_mutations_exception_group(e, len(batch))

# strip index information from exceptions, since it is not useful in a batch context
for subexc in e.exceptions:
subexc.index = None
return list(e.exceptions)
else:
statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch)
finally:
# mark batch as complete in flow control
await self._flow_control.remove_from_flow(batch)

# Call batch done callback with list of statuses.
if self._user_batch_completed_callback:
self._user_batch_completed_callback(statuses)
return []

def _add_exceptions(self, excs: list[Exception]):
Expand Down
66 changes: 65 additions & 1 deletion google/cloud/bigtable/data/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
from __future__ import annotations

from typing import Callable, Sequence, List, Tuple, TYPE_CHECKING, Union
from typing import Callable, Sequence, List, Optional, Tuple, TYPE_CHECKING, Union
import time
import enum
from collections import namedtuple
Expand All @@ -26,6 +26,10 @@
from google.api_core import retry as retries
from google.api_core.retry import RetryFailureReason
from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
from google.rpc import code_pb2
from google.rpc import status_pb2


if TYPE_CHECKING:
import grpc
Expand Down Expand Up @@ -224,6 +228,66 @@ def _align_timeouts(operation: float, attempt: float | None) -> tuple[float, flo
return operation, final_attempt


def _get_statuses_from_mutations_exception_group(
    exc_group: MutationsExceptionGroup, batch_size: int
) -> list[status_pb2.Status]:
    """
    Helper function that builds a per-entry list of Status objects from the
    exception information carried by a mutate rows exception group.

    Entries that are not referenced by any exception in the group are reported
    with an OK status.

    Args:
        exc_group: The exception group from a mutate rows operation
        batch_size: How many mutation entries were provided to the batch
    Returns:
        list[status_pb2.Status]: A list of ``batch_size`` Status proto objects,
            positioned by the index carried on each failed entry
    """
    # We exception handle as follows:
    #
    # 1. Each exception in the error group is a FailedMutationEntryError, and its
    #    cause is either a singular exception or a RetryExceptionGroup consisting of
    #    multiple exceptions.
    #
    # 2. In the case of a singular exception, if the error does not have a gRPC status
    #    code, we return a status code of UNKNOWN.
    #
    # 3. In the case of a RetryExceptionGroup, we use the terminal exception in the
    #    exception group and process that.

    # Build one distinct Status per slot. `[status] * batch_size` would repeat a
    # single mutable proto object, so mutating one returned entry would silently
    # change every other OK entry as well.
    statuses = [status_pb2.Status(code=code_pb2.OK) for _ in range(batch_size)]
    for error in exc_group.exceptions:
        # Errors without a usable index (None, or outside the batch) cannot be
        # mapped to a slot and are skipped.
        if isinstance(error.index, int) and 0 <= error.index < len(statuses):
            cause = error.__cause__
            if isinstance(cause, RetryExceptionGroup):
                # Per the review discussion on this change: an exception group
                # cannot be constructed without items (the custom group raises
                # ValueError("exceptions must be a non-empty sequence")), so
                # indexing [-1] is not expected to raise IndexError here.
                statuses[error.index] = _get_status(cause.exceptions[-1])
            else:
                statuses[error.index] = _get_status(cause)
    return statuses


def _get_status(exc: Optional[Exception]) -> status_pb2.Status:
    """
    Helper function that converts an exception into a Status proto.

    Args:
        exc: The exception to convert. May be None.
    Returns:
        status_pb2.Status: For a GoogleAPICallError that carries a gRPC status
            code, a Status with that code plus the error's message and details.
            Otherwise a Status with code UNKNOWN whose message is ``str(exc)``,
            or a generic message when ``exc`` is falsy.
    """
    # Guard clause: anything that is not an API call error with a gRPC status
    # code is reported as UNKNOWN.
    if (
        not isinstance(exc, core_exceptions.GoogleAPICallError)
        or exc.grpc_status_code is None
    ):
        fallback_message = str(exc) if exc else "An unknown error has occurred"
        return status_pb2.Status(
            code=code_pb2.Code.UNKNOWN,
            message=fallback_message,
        )

    # grpc_status_code.value is indexed at [0] to obtain the numeric code.
    return status_pb2.Status(  # type: ignore[unreachable]
        code=exc.grpc_status_code.value[0],
        message=exc.message,
        details=exc.details,
    )


def _validate_timeouts(
operation_timeout: float, attempt_timeout: float | None, allow_none: bool = False
):
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@
from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
from google.cloud.bigtable.data._helpers import _get_retryable_errors
from google.cloud.bigtable.data._helpers import _get_timeouts
from google.cloud.bigtable.data._helpers import (
_get_statuses_from_mutations_exception_group,
)
from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
from google.cloud.bigtable.data.mutations import Mutation
from google.cloud.bigtable.data._cross_sync import CrossSync
from google.rpc import code_pb2
from google.rpc import status_pb2

if TYPE_CHECKING:
from google.cloud.bigtable.data.mutations import RowMutationEntry
Expand Down Expand Up @@ -233,6 +238,7 @@ def __init__(
self._newest_exceptions: deque[Exception] = deque(
maxlen=self._exception_list_limit
)
self._user_batch_completed_callback = None
atexit.register(self._on_exit)

def _timer_routine(self, interval: float | None) -> None:
Expand Down Expand Up @@ -324,6 +330,7 @@ def _execute_mutate_rows(
list[FailedMutationEntryError]:
list of FailedMutationEntryError objects for mutations that failed.
FailedMutationEntryError objects will not contain index information"""
statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len(batch)
try:
operation = CrossSync._Sync_Impl._MutateRowsOperation(
self._target.client._gapic_client,
Expand All @@ -335,11 +342,16 @@ def _execute_mutate_rows(
)
operation.start()
except MutationsExceptionGroup as e:
statuses = _get_statuses_from_mutations_exception_group(e, len(batch))
for subexc in e.exceptions:
subexc.index = None
return list(e.exceptions)
else:
statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch)
finally:
self._flow_control.remove_from_flow(batch)
if self._user_batch_completed_callback:
self._user_batch_completed_callback(statuses)
return []

def _add_exceptions(self, excs: list[Exception]):
Expand Down
53 changes: 13 additions & 40 deletions google/cloud/bigtable/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Set
import warnings

from google.api_core.exceptions import GoogleAPICallError
from google.api_core.exceptions import Aborted
from google.api_core.exceptions import DeadlineExceeded
from google.api_core.exceptions import NotFound
Expand All @@ -31,12 +30,12 @@
from google.cloud.bigtable.column_family import _gc_rule_from_pb
from google.cloud.bigtable.column_family import ColumnFamily
from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
from google.cloud.bigtable.data.exceptions import (
RetryExceptionGroup,
MutationsExceptionGroup,
from google.cloud.bigtable.data._helpers import (
_get_statuses_from_mutations_exception_group,
)
from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
from google.cloud.bigtable.data.mutations import RowMutationEntry
from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
from google.cloud.bigtable.batcher import MutationsBatcher
from google.cloud.bigtable.batcher import FLUSH_COUNT, MAX_MUTATION_SIZE
from google.cloud.bigtable.encryption_info import EncryptionInfo
Expand Down Expand Up @@ -767,9 +766,9 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT):
mutation_entries = [
RowMutationEntry(row.row_key, row._get_mutations()) for row in rows
]
return_statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(
return_statuses = [status_pb2.Status(code=code_pb2.Code.UNKNOWN)] * len(
mutation_entries
) # By default, return status OKs for everything
)

try:
self._table_impl.bulk_mutate_rows(
Expand All @@ -779,41 +778,15 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT):
retryable_errors=retryable_errors,
)
except MutationsExceptionGroup as mut_exc_group:
# We exception handle as follows:
#
# 1. Each exception in the error group is a FailedMutationEntryError, and its
# cause is either a singular exception or a RetryExceptionGroup consisting of
# multiple exceptions.
#
# 2. In the case of a singular exception, if the error does not have a gRPC status
# code, we return a status code of UNKNOWN.
#
# 3. In the case of a RetryExceptionGroup, we use terminal exception in the exception
# group and process that.
for error in mut_exc_group.exceptions:
cause = error.__cause__
if isinstance(cause, RetryExceptionGroup):
return_statuses[error.index] = self._get_status(
cause.exceptions[-1]
)
else:
return_statuses[error.index] = self._get_status(cause)

return return_statuses

@staticmethod
def _get_status(error):
if isinstance(error, GoogleAPICallError) and error.grpc_status_code is not None:
return status_pb2.Status(
code=error.grpc_status_code.value[0],
message=error.message,
details=error.details,
return_statuses = _get_statuses_from_mutations_exception_group(
mut_exc_group, len(mutation_entries)
)
else:
return_statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(
mutation_entries
)

return status_pb2.Status(
code=code_pb2.Code.UNKNOWN,
message=str(error),
)
return return_statuses

def sample_row_keys(self):
"""Read a sample of row keys in the table.
Expand Down
36 changes: 36 additions & 0 deletions tests/system/data/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,42 @@ async def test_mutations_batcher_timer_flush(self, client, target, temp_rows):
# ensure cell is updated
assert (await self._retrieve_cell_value(target, row_key)) == new_value

@pytest.mark.usefixtures("client")
@pytest.mark.usefixtures("target")
@CrossSync.Retry(
    predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
)
@CrossSync.pytest
async def test_mutations_batcher_completed_callback(
    self, client, target, temp_rows
):
    """
    The batch-completed callback set on the batcher should be invoked exactly
    once, with a single OK status, when the timer flushes one staged entry.
    """
    import mock

    from google.cloud.bigtable.data.mutations import RowMutationEntry
    from google.rpc import code_pb2, status_pb2

    flush_interval = 0.1
    on_batch_done = mock.Mock()

    cell_value = uuid.uuid4().hex.encode()
    row_key, mutation = await self._create_row_and_mutation(
        target, temp_rows, new_value=cell_value
    )
    entry = RowMutationEntry(row_key, [mutation])
    # one entry in the batch -> one status expected in the callback argument
    expected_statuses = [status_pb2.Status(code=code_pb2.OK)]

    async with target.mutations_batcher(flush_interval=flush_interval) as batcher:
        batcher._user_batch_completed_callback = on_batch_done
        await batcher.append(entry)
        await CrossSync.yield_to_event_loop()
        # the entry is staged until the flush timer fires
        assert len(batcher._staged_entries) == 1
        await CrossSync.sleep(flush_interval + 0.1)
        assert len(batcher._staged_entries) == 0
        on_batch_done.assert_called_once_with(expected_statuses)
        # ensure cell is updated
        stored_value = await self._retrieve_cell_value(target, row_key)
        assert stored_value == cell_value

@pytest.mark.usefixtures("client")
@pytest.mark.usefixtures("target")
@CrossSync.Retry(
Expand Down
28 changes: 28 additions & 0 deletions tests/system/data/test_system_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,34 @@ def test_mutations_batcher_timer_flush(self, client, target, temp_rows):
assert len(batcher._staged_entries) == 0
assert self._retrieve_cell_value(target, row_key) == new_value

@pytest.mark.usefixtures("client")
@pytest.mark.usefixtures("target")
@CrossSync._Sync_Impl.Retry(
    predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
)
def test_mutations_batcher_completed_callback(self, client, target, temp_rows):
    """The batch-completed callback set on the batcher should be invoked exactly
    once, with a single OK status, when the timer flushes one staged entry."""
    import mock

    from google.cloud.bigtable.data.mutations import RowMutationEntry
    from google.rpc import code_pb2, status_pb2

    flush_interval = 0.1
    on_batch_done = mock.Mock()

    cell_value = uuid.uuid4().hex.encode()
    row_key, mutation = self._create_row_and_mutation(
        target, temp_rows, new_value=cell_value
    )
    entry = RowMutationEntry(row_key, [mutation])
    # one entry in the batch -> one status expected in the callback argument
    expected_statuses = [status_pb2.Status(code=code_pb2.OK)]

    with target.mutations_batcher(flush_interval=flush_interval) as batcher:
        batcher._user_batch_completed_callback = on_batch_done
        batcher.append(entry)
        CrossSync._Sync_Impl.yield_to_event_loop()
        # the entry is staged until the flush timer fires
        assert len(batcher._staged_entries) == 1
        CrossSync._Sync_Impl.sleep(flush_interval + 0.1)
        assert len(batcher._staged_entries) == 0
        on_batch_done.assert_called_once_with(expected_statuses)
        stored_value = self._retrieve_cell_value(target, row_key)
        assert stored_value == cell_value

@pytest.mark.usefixtures("client")
@pytest.mark.usefixtures("target")
@CrossSync._Sync_Impl.Retry(
Expand Down
Loading
Loading